
I am using the mlr package to predict from an SVM. If my validation set contains factor levels not present in my training data, the prediction fails, regardless of how I set fix.factors.prediction when making the SVM learner.

What is the proper way to handle this? Using e1071::svm() will return a response for new factor levels, but how can I do the same with mlr methods?

Example

library(mlr)
library(dplyr)

set.seed(575)
data(iris)

# Split data
train_set <- sample_frac(iris, 4/5)
valid_set <- setdiff(iris, train_set)

# Relabel all "setosa" rows in the training set as another species
train_set[train_set$Species == "setosa", "Species"] <-
  sample(c("virginica", "versicolor"),
         sum(train_set$Species == "setosa"), replace = TRUE)

# Fit model
iris_task <- makeRegrTask(data = train_set, target = "Petal.Width")

svm_lrn <- makeLearner("regr.svm", fix.factors.prediction = TRUE)

svm_mod <- train(svm_lrn, iris_task)

# Predict on new factor levels
predict(svm_mod, newdata = valid_set)

Error in (function (..., row.names = NULL, check.rows = FALSE, check.names = TRUE, : arguments imply differing number of rows: 29, 20

When using makeLearner("regr.svm", fix.factors.prediction = FALSE), I get the following error from the call to predict:

Error in scale.default(newdata[, object$scaled, drop = FALSE], center = object$x.scale$"scaled:center", : length of 'center' must equal the number of columns of 'x'
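Before predicting, it can help to list exactly which factor values in the validation set never occur in the training data. The sketch below uses base R only. `find_unseen_levels()` is a hypothetical helper, not an mlr or e1071 function. The deterministic relabelling is an assumption that stands in for the random split above.

```r
# Hypothetical helper (not part of mlr or e1071): for each factor column
# present in both data frames, return the values that occur in `newdata`
# but never occur in `train`. Observed values are compared rather than
# levels(), because a factor can carry levels that no row uses.
find_unseen_levels <- function(train, newdata) {
  shared <- intersect(names(train), names(newdata))
  factor_cols <- shared[vapply(train[shared], is.factor, logical(1))]
  unseen <- lapply(factor_cols, function(col) {
    setdiff(unique(as.character(newdata[[col]])),
            unique(as.character(train[[col]])))
  })
  names(unseen) <- factor_cols
  # Keep only the columns that actually have unseen values
  unseen[lengths(unseen) > 0]
}

# Deterministic stand-in for the split above: every "setosa" row of the
# training copy is relabelled, so that value never appears in training.
train <- iris
train$Species[train$Species == "setosa"] <- "versicolor"

find_unseen_levels(train, newdata = iris)
#> $Species
#> [1] "setosa"
```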

Things that do work

I can generate predictions when I subset the validation set to the factor levels present in the training set:

predict(svm_mod, newdata = valid_set %>% 
          filter(Species %in% train_set$Species))
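The filter() workaround above can be generalised to every factor column at once. This is a sketch under my own assumptions. `align_factor_levels()` is a hypothetical helper, not part of mlr. Rows whose value was never seen in training are dropped rather than predicted, because the model has no information about those levels.

```r
# Hypothetical helper (not part of mlr): re-encode every factor column of
# `newdata` with exactly the levels observed in `train`. Values that
# never occur in the training data become NA.
align_factor_levels <- function(newdata, train) {
  for (col in intersect(names(train), names(newdata))) {
    if (is.factor(train[[col]])) {
      observed <- levels(droplevels(train[[col]]))
      newdata[[col]] <- factor(as.character(newdata[[col]]), levels = observed)
    }
  }
  newdata
}

# Deterministic stand-in for the training set above (no "setosa" rows)
train <- iris
train$Species[train$Species == "setosa"] <- "versicolor"

aligned <- align_factor_levels(iris, train)
levels(aligned$Species)
#> [1] "versicolor" "virginica"

# Keep only fully observed rows. Note that complete.cases() also drops
# rows that already had NA values before alignment.
scorable <- aligned[complete.cases(aligned), ]
nrow(scorable)
#> [1] 100
```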

There is no error when I use a different learner:

nnet_lrn <- makeLearner("regr.nnet", fix.factors.prediction = TRUE)
nnet_mod <- train(nnet_lrn, iris_task)
predict(nnet_mod, newdata = valid_set)

Or when I call the same SVM directly from e1071:

e1071_mod <- 
  e1071::svm(Petal.Width ~ Sepal.Length + Sepal.Width +
               Petal.Length + Species, train_set)
predict(e1071_mod, newdata = valid_set)

Session info

R version 3.4.4 (2018-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 14.04.6 LTS

Matrix products: default
BLAS: /usr/lib/libblas/libblas.so.3.0
LAPACK: /usr/lib/lapack/liblapack.so.3.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] dplyr_0.8.0.1     mlr_2.14.0.9000   ParamHelpers_1.12

loaded via a namespace (and not attached):
 [1] parallelMap_1.4    Rcpp_1.0.1         pillar_1.4.1      
 [4] compiler_3.4.4     class_7.3-14       tools_3.4.4       
 [7] tibble_2.1.3       gtable_0.3.0       checkmate_1.9.3   
[10] lattice_0.20-38    pkgconfig_2.0.2    rlang_0.3.99.9003 
[13] Matrix_1.2-14      fastmatch_1.1-0    rstudioapi_0.8    
[16] yaml_2.2.0         parallel_3.4.4     e1071_1.7-1       
[19] nnet_7.3-12        grid_3.4.4         tidyselect_0.2.5  
[22] glue_1.3.1         data.table_1.12.2  R6_2.4.0          
[25] XML_3.98-1.20      survival_2.41-3    ggplot2_3.2.0.9000
[28] purrr_0.3.2        magrittr_1.5       backports_1.1.4   
[31] scales_1.0.0.9000  BBmisc_1.11        splines_3.4.4     
[34] assertthat_0.2.1   colorspace_1.3-2   stringi_1.4.3     
[37] lazyeval_0.2.2     munsell_0.5.0      crayon_1.3.4 
  • Can you change your code to a full reprex? It does not run in its current state. I'm using e1071::svm() myself with missing factors and never had problems so far. – pat-s Jun 12 at 8:30
  • @pat-s Sorry, I miscopied the line replacing setosa values. Everything else seems to run as I was expecting, though. Please let me know if you're still having trouble reproducing the example. – coletl Jun 12 at 12:54

Ok, this has been a little challenging. A few things upfront:

  • e1071::svm() cannot handle factor levels in newdata that are missing from the training data (Error in predict.svm: test data does not match model)
  • The manual execution of your example only runs because you did not drop the unused factor levels from train_set
  • The argument fix.factors.prediction did not do what it's supposed to. I posted a temporary fix in a branch of mlr-org/mlr (installed below). The fix is very dirty and just a proof of concept; I might clean it up.
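Here is a quick way to see the second point. Relabelling the "setosa" rows changes the values but not the factor's level set, so "setosa" stays attached as an unused level until droplevels() removes it. This base-R sketch relabels deterministically instead of with sample(). That is an assumption of mine, and it does not change the effect on the levels.

```r
# Relabel "setosa" deterministically (the question uses sample(); the
# effect on the factor's level set is the same).
train_set <- iris
train_set$Species[train_set$Species == "setosa"] <- "virginica"

# No row is "setosa" any more (its count in the table is 0) ...
counts <- table(train_set$Species)
counts

# ... but the level is still attached to the factor
levels_before <- levels(train_set$Species)
levels_before
#> [1] "setosa"     "versicolor" "virginica"

# droplevels() removes every level that no row uses
train_set <- droplevels(train_set)
levels(train_set$Species)
#> [1] "versicolor" "virginica"
```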

Proof that the manual e1071 execution fails once the unused levels are dropped:

library(mlr)
library(dplyr)

set.seed(575)
data(iris)

# Split data
train_set <- sample_frac(iris, 4 / 5)
valid_set <- setdiff(iris, train_set)

# Relabel all "setosa" rows in the training set as another species
train_set[train_set$Species == "setosa", "Species"] <-
  sample(c("virginica", "versicolor"),
    sum(train_set$Species == "setosa"), replace = TRUE)

# this is important: drop the now-unused "setosa" level
train_set <- droplevels(train_set)

e1071_mod <- e1071::svm(Petal.Width ~ Sepal.Length + Sepal.Width +
  Petal.Length + Species, train_set)
predict(e1071_mod, newdata = valid_set)
#> Error in scale.default(newdata[, object$scaled, drop = FALSE], center = object$x.scale$"scaled:center", : length of 'center' must equal the number of columns of 'x'

Created on 2019-06-13 by the reprex package (v0.3.0)
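The mismatch in this error comes down to the width of the design matrix. As far as I can tell from the e1071 sources, the formula method of svm() builds an intercept-free model.matrix() from the training data, and predict.svm() rebuilds one from newdata. Treat that as an assumption. If it holds, the number of dummy columns follows whatever levels each data frame's factor carries. This base-R sketch only counts columns and does not call e1071.

```r
# Predictors used in the question, without an intercept (assumption: this
# mirrors how e1071's formula interface builds its design matrix).
rhs <- ~ 0 + Sepal.Length + Sepal.Width + Petal.Length + Species

# Training data with "setosa" relabelled; the unused level is still attached
train_kept <- iris
train_kept$Species[train_kept$Species == "setosa"] <- "virginica"

# Same data after droplevels(), as in the reprex
train_dropped <- droplevels(train_kept)

ncol(model.matrix(rhs, train_kept))     # 3 numeric + 3 Species indicators
#> [1] 6
ncol(model.matrix(rhs, train_dropped))  # 3 numeric + 2 Species indicators
#> [1] 5
ncol(model.matrix(rhs, iris))           # validation-like data, all 3 levels
#> [1] 6
```

With the unused level kept, the training and validation matrices both have 6 columns. After droplevels() the training side has 5 columns while the validation data still produces 6. That lines up with the scale() complaint that the length of 'center' does not match the number of columns.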

Working example using the fix from that mlr branch:

remotes::install_github("mlr-org/[email protected]")
#> Downloading GitHub repo mlr-org/[email protected]
library(mlr)
library(dplyr)

set.seed(575)
data(iris)

# Split data
train_set <- sample_frac(iris, 4 / 5)
valid_set <- setdiff(iris, train_set)

# Relabel all "setosa" rows in the training set as another species
train_set[train_set$Species == "setosa", "Species"] <-
  sample(c("virginica", "versicolor"),
    sum(train_set$Species == "setosa"), replace = TRUE)

# this is important: drop the now-unused "setosa" level
train_set <- droplevels(train_set)

# Fit model
iris_task <- makeRegrTask(data = train_set, target = "Petal.Width")

svm_lrn <- makeLearner("regr.svm", fix.factors.prediction = TRUE)

svm_mod <- train(svm_lrn, iris_task)

# Predict on new factor levels
predict(svm_mod, newdata = valid_set)
#> Prediction: 30 observations
#> predict.type: response
#> threshold: 
#> time: 0.00
#>   truth  response
#> 1   0.3 0.2457751
#> 2   0.1 0.2730398
#> 3   0.2 0.2717464
#> 4   0.1 0.2717748
#> 5   0.1 0.2651599
#> 6   0.4 0.2582568
#> ... (#rows: 30, #cols: 2)

Created on 2019-06-13 by the reprex package (v0.3.0)
