From 9c35a97ddfa0df660c5a08c64327cff09003b450 Mon Sep 17 00:00:00 2001 From: perrydv Date: Fri, 5 Jun 2026 14:45:20 -0700 Subject: [PATCH] support switch (which becomes nSwitch) --- nCompiler/NAMESPACE | 1 + nCompiler/R/Rexecution.R | 10 +++ nCompiler/R/changeKeywords.R | 1 + nCompiler/R/compile_aaa_operatorLists.R | 17 ++++ nCompiler/R/compile_exprClass.R | 1 + nCompiler/R/compile_generateCpp.R | 22 ++++- nCompiler/R/compile_labelAbstractTypes.R | 80 ++----------------- nCompiler/R/compile_simpleTransformations.R | 13 +++ .../testthat/specificOp_tests/test-switch.R | 52 ++++++++++++ 9 files changed, 124 insertions(+), 73 deletions(-) create mode 100644 nCompiler/tests/testthat/specificOp_tests/test-switch.R diff --git a/nCompiler/NAMESPACE b/nCompiler/NAMESPACE index f0294dee..d8ae30ed 100644 --- a/nCompiler/NAMESPACE +++ b/nCompiler/NAMESPACE @@ -92,6 +92,7 @@ export(nSeq) export(nSerialize) export(nSolve) export(nSvd) +export(nSwitch) export(nType) export(nTypeBasic) export(nTypeList) diff --git a/nCompiler/R/Rexecution.R b/nCompiler/R/Rexecution.R index 416963ef..fabae62a 100644 --- a/nCompiler/R/Rexecution.R +++ b/nCompiler/R/Rexecution.R @@ -29,6 +29,16 @@ make_nAs_output_dims <- function(input_dims, output_nDim) { } } +#' @export +nSwitch <- function(paramID, IDoptions = NULL, ...) { + dotsList <- eval(substitute(alist(...))) + if(is.null(IDoptions)) IDoptions <- seq_along(dotsList) + if(length(IDoptions) != length(dotsList)) stop("length of IDoptions must match number of cases provided in ...") + iUse <- which(IDoptions == paramID) + if(length(iUse) > 0) eval(dotsList[[iUse[1] ]], envir = parent.frame()) + invisible(NULL) +} + #' @export nAs <- function(object, type) { ttype <- nCaptureType(type) diff --git a/nCompiler/R/changeKeywords.R b/nCompiler/R/changeKeywords.R index a461bd3c..326ce6f3 100644 --- a/nCompiler/R/changeKeywords.R +++ b/nCompiler/R/changeKeywords.R @@ -15,6 +15,7 @@ nKeyWords <- list(as = 'nAs', c = 'nC', rep = 'nRep', seq = 'nSeq', + switch = 'nSwitch', eigen = 'nEigen', diag = 'nDiag', Diagonal = 'nDiagonal', # mirror Matrix::Diagonal to diff --git a/nCompiler/R/compile_aaa_operatorLists.R b/nCompiler/R/compile_aaa_operatorLists.R index eccfdf3f..86cc80ee 100644 --- a/nCompiler/R/compile_aaa_operatorLists.R +++ b/nCompiler/R/compile_aaa_operatorLists.R @@ -1122,6 +1122,23 @@ updateOperatorDef('rt_nonstandard', 'matchDef', val = function(n, df = 1, mu = 0 updateOperatorDef('runif', 'matchDef', val = function(n, min = 0, max = 1) {}) updateOperatorDef('rweibull', 'matchDef', val = function(n, shape, scale = 1) {}) +assignOperatorDef( + c('nSwitch'), + list( + matchDef = function(expr, IDs, ...) {}, + compileArgs = c("IDs"), + simpleTransformations = list( + handler = 'Switch' + ), + labelAbstractTypes = list( + handler = 'Switch' + ), + cppOutput = list( + handler = 'Switch' + ) + ) +) + assignOperatorDef( c('length'), # methods here are for Eigen objects and may be overloaded for user nClasses. list( diff --git a/nCompiler/R/compile_exprClass.R b/nCompiler/R/compile_exprClass.R index a9fff29b..49565624 100644 --- a/nCompiler/R/compile_exprClass.R +++ b/nCompiler/R/compile_exprClass.R @@ -595,6 +595,7 @@ exprClass_put_args_in_order <- function(def, expr, aux_compileArgs <- if(!is.null(expr$aux[["compileArgs"]])) expr$aux[["compileArgs"]] else list() for(CA_name in compileArgs) { if(CA_name %in% names(expr$args)) { + # should we use aux_compileArgs[CA_name] <- list(expr$args[[CA_name]]$Rexpr) to preserve NULL values? aux_compileArgs[[CA_name]] <- expr$args[[CA_name]]$Rexpr removeArg(expr, CA_name) } diff --git a/nCompiler/R/compile_generateCpp.R b/nCompiler/R/compile_generateCpp.R index 1f426ff9..b691c97f 100644 --- a/nCompiler/R/compile_generateCpp.R +++ b/nCompiler/R/compile_generateCpp.R @@ -2,11 +2,12 @@ ## Section for outputting C++ code from an exprClass object ## ############################################################## +# This should be moved into opDefs or handlers nimCppKeywordsThatFillSemicolon <- c( '{', 'for', ifOrWhile, - 'nimSwitch', + 'nSwitch', # 'cppLiteral', 'cppComment') @@ -206,6 +207,25 @@ inGenCppEnv( } ) +inGenCppEnv( + Switch <- function(code, symTab) { + IDs <- code$aux$compileArgs$IDs + numChoices <- length(code$args)-1 + if(numChoices <= 0) return('') + choicesCode <- vector('list', numChoices) + choiceValues <- IDs + if(length(choiceValues) != numChoices) stop(paste0('number of switch choices does not match number of indices.')) + for(i in 1:numChoices) { + if(code$args[[i+1]]$name != '{') + bracketedCode <- insertExprClassLayer(code, i+1, '{') + choicesCode[[i]] <- list(paste0('case ',choiceValues[i],':'), + compile_generateCpp(code$args[[i+1]], symTab, showBracket=FALSE), 'break;') + } + ans <- list(paste('switch(',code$args[[1]]$name,') {'), choicesCode, '}') + ans + } +) + inGenCppEnv( Generic_nClass_method_ref <- function(code, symTab) { paste0('nCompiler::nBind(&', compile_generateCpp(code$args[[2]]), '::', diff --git a/nCompiler/R/compile_labelAbstractTypes.R b/nCompiler/R/compile_labelAbstractTypes.R index 0baf05cc..0bd156d7 100644 --- a/nCompiler/R/compile_labelAbstractTypes.R +++ b/nCompiler/R/compile_labelAbstractTypes.R @@ -256,78 +256,14 @@ inLabelAbstractTypesEnv( } ) -# inLabelAbstractTypesEnv( -# nList_doubleBracket <- function(code, symTab, auxEnv, handlingInfo) { -# browser() -# inserts <- NULL -# if(length(inserts) == 0) NULL else inserts -# } -# ) - -# nCompiler:::inLabelAbstractTypesEnv( -# nClassBuilder <- function(code, symTab, auxEnv, handlingInfo) { -# this_builder <- code$aux$cachedOpInfo$obj_internals -# Rexpr <- code$Rexpr -# args <- as.list(Rexpr)[-1] -# args2 <- c(args, .ID=TRUE) -# ID <- do.call(this_builder, args2) -# NCgen <- NULL -# for(already_built in auxEnv$nClassBuilder_built) { -# if(identical(ID, NCinternals(already_built)$classID)) { -# NCgen <- already_built -# break -# } -# } -# if(is.null(NCgen)) { -# NCgen <- do.call(this_builder, args) -# auxEnv$nClassBuilder_built <- c(auxEnv$nClassBuilder_built, list(NCgen)) -# } - -# newSym <- symbolNCgenerator$new(name = ID, -# type = ID, -# NCgenerator = NCgen) -# code$type <- newSym -# auxEnv$needed_nClasses <- c(auxEnv$needed_nClasses, NCgen) -# NULL -# } -# ) - -# inLabelAbstractTypesEnv( -# CheckOverload <- function(code, symTab, auxEnv, handlingInfo) { -# if(length(code$args) == 0) return(NULL) -# arg1 <- code$args[[1]] -# if(inherits(arg1$type, "symbolNC")) { -# overload <- NC_find_overload(arg1$type$NCgenerator, code$name, "labelAbstractTypes", inherits=TRUE) -# if(!is.null(overload)) { -# if(is.function(overload)) -# ans <- overload(code, symTab, auxEnv, handlingInfo) -# else -# ans <- eval(call(overload, code, symTab, auxEnv, handlingInfo), -# envir = labelAbstractTypesEnv) -# return(ans) -# } -# } -# NULL -# } -# ) - -# inLabelAbstractTypesEnv( -# recurse_labelAbstractTypes_overloaded <- function(code, symTab, auxEnv, handlingInfo) { -# useArgs <- rep(FALSE, length(code$args)) -# useArgs[1] <- TRUE -# inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, -# handlingInfo, useArgs = useArgs) -# inserts2 <- CheckOverload(code, symTab, auxEnv, handlingInfo) -# handled <- TRUE -# if(is.null(inserts2)) { -# inserts2 <- recurse_labelAbstractTypes(code, symTab, auxEnv, -# handlingInfo, useArgs = !useArgs) -# handled <- FALSE -# } -# if(isTRUE(inserts2)) inserts2 <- NULL -# list(inserts = c(inserts, inserts2), handled = handled) -# } -# ) +inLabelAbstractTypesEnv( + Switch <- function(code, symTab, auxEnv, handlingInfo) { + inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, + handlingInfo) + code$type <- "NA" # should never be looked at because Switch has no return type + if(length(inserts) == 0) NULL else inserts + } +) nCompiler:::inLabelAbstractTypesEnv( DoubleBracket <- function(code, symTab, auxEnv, handlingInfo) { diff --git a/nCompiler/R/compile_simpleTransformations.R b/nCompiler/R/compile_simpleTransformations.R index 1693c638..171bbb8e 100644 --- a/nCompiler/R/compile_simpleTransformations.R +++ b/nCompiler/R/compile_simpleTransformations.R @@ -143,4 +143,17 @@ inSimpleTransformationsEnv( drop_arg <- !isFALSE(drop_arg) code$aux$compileArgs$drop <- drop_arg } +) + +inSimpleTransformationsEnv( + Switch <- function(code, symTab, auxEnv, handlingInfo) { + if(length(code$args) < 2) stop("nSwitch must have at least 2 arguments: the value to check and at least one option.") + IDs <- code$aux$compileArgs$IDs + if(is.null(IDs)) IDs <- 1:(length(code$args) - 1) + else IDs <- eval(IDs, envir = auxEnv$where) + if(!is.numeric(IDs)) stop("IDs for nSwitch must be numeric.") + if(length(IDs) != length(code$args) - 1) stop("Number of IDs for nSwitch must match the number of options.") + code$aux$compileArgs$IDs <- IDs + if(code$caller$name != "{") stop("nSwitch can not be used within an expression. It does not return anything.") + } ) \ No newline at end of file diff --git a/nCompiler/tests/testthat/specificOp_tests/test-switch.R b/nCompiler/tests/testthat/specificOp_tests/test-switch.R new file mode 100644 index 00000000..307543d0 --- /dev/null +++ b/nCompiler/tests/testthat/specificOp_tests/test-switch.R @@ -0,0 +1,52 @@ +library(nCompiler) +library(testthat) + +test_that("nSwitch works", { + foo <- nFunction( + function(opt = integerScalar()) { + x <- 0; + switch(opt, 1:2, x<-1, x<-2) + return(x) + }, + returnType = 'numericScalar' + ) + # check that my_IDs is found by scoping + # and that an option can give a {} set of code + layer <- function() { + my_IDs <- 5:6 + foo2 <- nFunction( + function(opt = integerScalar()) { + x <- 0; + switch(opt, my_IDs, {x<-x+1; x<-x+4}, x<-6) + return(x) + }, + returnType = 'numericScalar' + ) + foo2 + } + foo2 <- layer() + + comp <- nCompile(foo, foo2) + expect_equal(foo(2), 2) + expect_equal(foo2(5), 5) + expect_equal(foo2(6), 6) + expect_equal(foo2(4), 0) + + expect_equal(comp$foo(2), 2) + expect_equal(comp$foo2(5), 5) + expect_equal(comp$foo2(6), 6) + expect_equal(comp$foo2(4), 0) + + foo_error1 <- + foo <- nFunction( + function(opt = integerScalar()) { + x <- 0; + switch(opt, 1:3, x<-1, x<-2) + return(x) + }, + returnType = 'numericScalar' + ) + cat("expecting two error messages about number of IDs for nSwitch not matching number of options:") + expect_error(nCompile(foo_error1)) + expect_error(foo_error1(3)) +})