Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nCompiler/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export(nSeq)
export(nSerialize)
export(nSolve)
export(nSvd)
export(nSwitch)
export(nType)
export(nTypeBasic)
export(nTypeList)
Expand Down
10 changes: 10 additions & 0 deletions nCompiler/R/Rexecution.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ make_nAs_output_dims <- function(input_dims, output_nDim) {
}
}

#' @export
nSwitch <- function(paramID, IDoptions = NULL, ...) {
dotsList <- eval(substitute(alist(...)))
if(is.null(IDoptions)) IDoptions <- seq_along(dotsList)
if(length(IDoptions) != length(dotsList)) stop("length of IDoptions must match number of cases provided in ...")
iUse <- which(IDoptions == paramID)
if(length(iUse) > 0) eval(dotsList[[iUse[1] ]], envir = parent.frame())
invisible(NULL)
}

#' @export
nAs <- function(object, type) {
ttype <- nCaptureType(type)
Expand Down
1 change: 1 addition & 0 deletions nCompiler/R/changeKeywords.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ nKeyWords <- list(as = 'nAs',
c = 'nC',
rep = 'nRep',
seq = 'nSeq',
switch = 'nSwitch',
eigen = 'nEigen',
diag = 'nDiag',
Diagonal = 'nDiagonal', # mirror Matrix::Diagonal to
Expand Down
17 changes: 17 additions & 0 deletions nCompiler/R/compile_aaa_operatorLists.R
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,23 @@ updateOperatorDef('rt_nonstandard', 'matchDef', val = function(n, df = 1, mu = 0
updateOperatorDef('runif', 'matchDef', val = function(n, min = 0, max = 1) {})
updateOperatorDef('rweibull', 'matchDef', val = function(n, shape, scale = 1) {})

assignOperatorDef(
c('nSwitch'),
list(
matchDef = function(expr, IDs, ...) {},
compileArgs = c("IDs"),
simpleTransformations = list(
handler = 'Switch'
),
labelAbstractTypes = list(
handler = 'Switch'
),
cppOutput = list(
handler = 'Switch'
)
)
)

assignOperatorDef(
c('length'), # methods here are for Eigen objects and may be overloaded for user nClasses.
list(
Expand Down
1 change: 1 addition & 0 deletions nCompiler/R/compile_exprClass.R
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ exprClass_put_args_in_order <- function(def, expr,
aux_compileArgs <- if(!is.null(expr$aux[["compileArgs"]])) expr$aux[["compileArgs"]] else list()
for(CA_name in compileArgs) {
if(CA_name %in% names(expr$args)) {
# should we use aux_compileArgs[CA_name] <- list(expr$args[[CA_name]]$Rexpr) to preserve NULL values?
aux_compileArgs[[CA_name]] <- expr$args[[CA_name]]$Rexpr
removeArg(expr, CA_name)
}
Expand Down
22 changes: 21 additions & 1 deletion nCompiler/R/compile_generateCpp.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
## Section for outputting C++ code from an exprClass object ##
##############################################################

# This should be moved into opDefs or handlers
nimCppKeywordsThatFillSemicolon <- c(
'{',
'for',
ifOrWhile,
'nimSwitch',
'nSwitch',
# 'cppLiteral',
'cppComment')

Expand Down Expand Up @@ -206,6 +207,25 @@ inGenCppEnv(
}
)

inGenCppEnv(
Switch <- function(code, symTab) {
IDs <- code$aux$compileArgs$IDs
numChoices <- length(code$args)-1
if(numChoices <= 0) return('')
choicesCode <- vector('list', numChoices)
choiceValues <- IDs
if(length(choiceValues) != numChoices) stop(paste0('number of switch choices does not match number of indices.'))
for(i in 1:numChoices) {
if(code$args[[i+1]]$name != '{')
bracketedCode <- insertExprClassLayer(code, i+1, '{')
choicesCode[[i]] <- list(paste0('case ',choiceValues[i],':'),
compile_generateCpp(code$args[[i+1]], symTab, showBracket=FALSE), 'break;')
}
ans <- list(paste('switch(',code$args[[1]]$name,') {'), choicesCode, '}')
ans
}
)

inGenCppEnv(
Generic_nClass_method_ref <- function(code, symTab) {
paste0('nCompiler::nBind(&', compile_generateCpp(code$args[[2]]), '::',
Expand Down
80 changes: 8 additions & 72 deletions nCompiler/R/compile_labelAbstractTypes.R
Original file line number Diff line number Diff line change
Expand Up @@ -256,78 +256,14 @@ inLabelAbstractTypesEnv(
}
)

# inLabelAbstractTypesEnv(
# nList_doubleBracket <- function(code, symTab, auxEnv, handlingInfo) {
# browser()
# inserts <- NULL
# if(length(inserts) == 0) NULL else inserts
# }
# )

# nCompiler:::inLabelAbstractTypesEnv(
# nClassBuilder <- function(code, symTab, auxEnv, handlingInfo) {
# this_builder <- code$aux$cachedOpInfo$obj_internals
# Rexpr <- code$Rexpr
# args <- as.list(Rexpr)[-1]
# args2 <- c(args, .ID=TRUE)
# ID <- do.call(this_builder, args2)
# NCgen <- NULL
# for(already_built in auxEnv$nClassBuilder_built) {
# if(identical(ID, NCinternals(already_built)$classID)) {
# NCgen <- already_built
# break
# }
# }
# if(is.null(NCgen)) {
# NCgen <- do.call(this_builder, args)
# auxEnv$nClassBuilder_built <- c(auxEnv$nClassBuilder_built, list(NCgen))
# }

# newSym <- symbolNCgenerator$new(name = ID,
# type = ID,
# NCgenerator = NCgen)
# code$type <- newSym
# auxEnv$needed_nClasses <- c(auxEnv$needed_nClasses, NCgen)
# NULL
# }
# )

# inLabelAbstractTypesEnv(
# CheckOverload <- function(code, symTab, auxEnv, handlingInfo) {
# if(length(code$args) == 0) return(NULL)
# arg1 <- code$args[[1]]
# if(inherits(arg1$type, "symbolNC")) {
# overload <- NC_find_overload(arg1$type$NCgenerator, code$name, "labelAbstractTypes", inherits=TRUE)
# if(!is.null(overload)) {
# if(is.function(overload))
# ans <- overload(code, symTab, auxEnv, handlingInfo)
# else
# ans <- eval(call(overload, code, symTab, auxEnv, handlingInfo),
# envir = labelAbstractTypesEnv)
# return(ans)
# }
# }
# NULL
# }
# )

# inLabelAbstractTypesEnv(
# recurse_labelAbstractTypes_overloaded <- function(code, symTab, auxEnv, handlingInfo) {
# useArgs <- rep(FALSE, length(code$args))
# useArgs[1] <- TRUE
# inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv,
# handlingInfo, useArgs = useArgs)
# inserts2 <- CheckOverload(code, symTab, auxEnv, handlingInfo)
# handled <- TRUE
# if(is.null(inserts2)) {
# inserts2 <- recurse_labelAbstractTypes(code, symTab, auxEnv,
# handlingInfo, useArgs = !useArgs)
# handled <- FALSE
# }
# if(isTRUE(inserts2)) inserts2 <- NULL
# list(inserts = c(inserts, inserts2), handled = handled)
# }
# )
inLabelAbstractTypesEnv(
Switch <- function(code, symTab, auxEnv, handlingInfo) {
inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv,
handlingInfo)
code$type <- "NA" # should never be looked at because Switch has no return type
if(length(inserts) == 0) NULL else inserts
}
)

nCompiler:::inLabelAbstractTypesEnv(
DoubleBracket <- function(code, symTab, auxEnv, handlingInfo) {
Expand Down
13 changes: 13 additions & 0 deletions nCompiler/R/compile_simpleTransformations.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,17 @@ inSimpleTransformationsEnv(
drop_arg <- !isFALSE(drop_arg)
code$aux$compileArgs$drop <- drop_arg
}
)

inSimpleTransformationsEnv(
Switch <- function(code, symTab, auxEnv, handlingInfo) {
if(length(code$args) < 2) stop("nSwitch must have at least 2 arguments: the value to check and at least one option.")
IDs <- code$aux$compileArgs$IDs
if(is.null(IDs)) IDs <- 1:(length(code$args) - 1)
else IDs <- eval(IDs, envir = auxEnv$where)
if(!is.numeric(IDs)) stop("IDs for nSwitch must be numeric.")
if(length(IDs) != length(code$args) - 1) stop("Number of IDs for nSwitch must match the number of options.")
code$aux$compileArgs$IDs <- IDs
if(code$caller$name != "{") stop("nSwitch can not be used within an expression. It does not return anything.")
}
)
52 changes: 52 additions & 0 deletions nCompiler/tests/testthat/specificOp_tests/test-switch.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
library(nCompiler)
library(testthat)

test_that("nSwitch works", {
foo <- nFunction(
function(opt = integerScalar()) {
x <- 0;
switch(opt, 1:2, x<-1, x<-2)
return(x)
},
returnType = 'numericScalar'
)
# check that my_IDs is found by scoping
# and that an option can give a {} set of code
layer <- function() {
my_IDs <- 5:6
foo2 <- nFunction(
function(opt = integerScalar()) {
x <- 0;
switch(opt, my_IDs, {x<-x+1; x<-x+4}, x<-6)
return(x)
},
returnType = 'numericScalar'
)
foo2
}
foo2 <- layer()

comp <- nCompile(foo, foo2)
expect_equal(foo(2), 2)
expect_equal(foo2(5), 5)
expect_equal(foo2(6), 6)
expect_equal(foo2(4), 0)

expect_equal(comp$foo(2), 2)
expect_equal(comp$foo2(5), 5)
expect_equal(comp$foo2(6), 6)
expect_equal(comp$foo2(4), 0)

foo_error1 <-
foo <- nFunction(
function(opt = integerScalar()) {
x <- 0;
switch(opt, 1:3, x<-1, x<-2)
return(x)
},
returnType = 'numericScalar'
)
cat("expecting two error messages about number of IDs for nSwitch not matching number of options:")
expect_error(nCompile(foo_error1))
expect_error(foo_error1(3))
})
Loading