J. Taroni 2018
caret
# devtools::install_github('topepo/caret/pkg/caret',
# ref = "6546939345fe10649cefcbfee55d58fb682bc902")
# devtools::install_version("e1071", version = "1.6-8")
# magrittr pipe, bound here so that dplyr itself does not have to be attached
`%>%` <- dplyr::`%>%`
# where this notebook writes its plots and its results
plot.dir <- file.path("plots", "25")
results.dir <- file.path("results", "25")
# create both locations (including parents); no warning if they already exist
invisible(lapply(c(plot.dir, results.dir), dir.create,
                 recursive = TRUE, showWarnings = FALSE))
# covariate table for the RTX samples; no col_types are supplied, so readr
# guesses the type of every column
covariate.df <- readr::read_tsv(
  file = file.path("data", "rtx", "RTX_full_covariates.tsv")
)
Parsed with column specification:
cols(
.default = col_character(),
barcode = col_integer(),
AGE = col_integer(),
bcells = col_double(),
HGB = col_double(),
`Platelet Count` = col_double(),
WBC = col_double(),
Lymphs = col_double(),
Neutrophils = col_double(),
Eosinophils = col_double(),
Tscore = col_integer()
)
See spec(...) for full column specifications.
This is gene-level expression data that has been vst-transformed and filtered to only genes that are in the recount2 PLIER model.
# expression values for the RTX samples, restored from a serialized R object
exprs <- readRDS(file = file.path("data", "rtx", "VST_blind_filtered.RDS"))
B
The multiPLIER approach
# B matrix for the RTX samples from the recount2 (multiPLIER) side
recount.b <- readRDS(file = file.path("data", "rtx", "RTX_recount2_B.RDS"))
# PLIER model stored for the RTX data; only its B element is used below
rtx.plier <- readRDS(file = file.path("data", "rtx", "RTX_PLIER_model.RDS"))
rtx.b <- rtx.plier[["B"]]
First, we’ll change the sample names to match the barcodes in the covariates. The first six characters of the current column/sample names should correspond to a barcode.
# in the expression data: keep only the leading six characters of each name
colnames(exprs) <- substring(colnames(exprs), first = 1, last = 6)
# prints TRUE when every trimmed name equals the barcode in the same position
all(colnames(exprs) == covariate.df$barcode)
[1] TRUE
# in the recount B data: keep only the leading six characters of each name
colnames(recount.b) <- substring(colnames(recount.b), first = 1, last = 6)
# prints TRUE when every trimmed name equals the barcode in the same position
all(colnames(recount.b) == covariate.df$barcode)
[1] TRUE
# in the RTX B: keep only the leading six characters of each name
colnames(rtx.b) <- substring(colnames(rtx.b), first = 1, last = 6)
# prints TRUE when every trimmed name equals the barcode in the same position
all(colnames(rtx.b) == covariate.df$barcode)
[1] TRUE
The mainclass
column in covariate.df
is what we are interested in predicting; it records whether a patient is a nonresponder or a responder to treatment (responders are further divided into tolerant or nontolerant depending on, I believe, long-term outcome). (We’ll exclude samples with NA
in this column.)
We’ll want to try and predict this from baseline samples (covariate.df$timepoint == "BL"
). We will not be adjusting for covariates at this point. The earlier publications on this trial suggest that the majority of covariates have no significant association with response.
Let’s take a look at the sample size and class balance.
# cross-tabulate response class (rows) against timepoint (columns)
table(covariate.df[["mainclass"]], covariate.df[["timepoint"]])
BL M18 M6
Nonresponder 14 10 7
Nontolerant 10 3 10
Tolerant 12 12 12
We can see that there are 36 labeled baseline samples (37 in total; one has NA for its class label and is not counted in the table) and that the three classes (Nonresponder
, Nontolerant
, and Tolerant
) are pretty balanced. If we use these three classes, we can likely use a metric like total accuracy to evaluate performance. Also, the small sample size lends itself to leave-one-out cross-validation (LOOCV).
# Not every baseline sample has a response label (one is NA), so keep only
# the baseline rows in which the barcode and the class label are both present
baseline.covariate.df <- covariate.df %>%
  dplyr::select(barcode, timepoint, mainclass) %>%
  dplyr::filter(timepoint == "BL", !is.na(barcode), !is.na(mainclass))
# we only want the baseline samples with a class label
baseline.samples <- baseline.covariate.df$barcode
# Look each labelled barcode up in the column names so that the rows of the
# transposed matrix follow the order of baseline.covariate.df (whose
# mainclass column supplies the labels) instead of relying on the expression
# columns already being in that order.
exprs.baseline.idx <- match(baseline.samples, colnames(exprs))
# stop early if any labelled baseline sample has no expression column
stopifnot(!anyNA(exprs.baseline.idx))
# drop = FALSE keeps a matrix (and its sample names) even for a single sample
baseline.exprs <- t(exprs[, exprs.baseline.idx, drop = FALSE])
# samples x genes
dim(baseline.exprs)
[1] 36 6690
# Look each labelled baseline barcode up in the column names so that the rows
# of the transposed matrix follow the order of baseline.samples (and hence of
# the class labels) instead of relying on the columns already being ordered.
recount.baseline.idx <- match(baseline.samples, colnames(recount.b))
# stop early if any labelled baseline sample is missing from the recount B
stopifnot(!anyNA(recount.baseline.idx))
# drop = FALSE keeps a matrix (and its sample names) even for a single sample
recount.baseline.b <- t(recount.b[, recount.baseline.idx, drop = FALSE])
# samples x latent variables
dim(recount.baseline.b)
[1] 36 987
# Look each labelled baseline barcode up in the column names so that the rows
# of the transposed matrix follow the order of baseline.samples (and hence of
# the class labels) instead of relying on the columns already being ordered.
rtx.baseline.idx <- match(baseline.samples, colnames(rtx.b))
# stop early if any labelled baseline sample is missing from the RTX B
stopifnot(!anyNA(rtx.baseline.idx))
# drop = FALSE keeps a matrix (and its sample names) even for a single sample
rtx.baseline.b <- t(rtx.b[, rtx.baseline.idx, drop = FALSE])
# samples x latent variables
dim(rtx.baseline.b)
[1] 36 23
# prints TRUE when the recount B rows are in the same order as the labels
all(baseline.covariate.df$barcode == rownames(recount.baseline.b))
[1] TRUE
# prints TRUE when the expression rows are in the same order as the labels
all(baseline.covariate.df$barcode == rownames(baseline.exprs))
[1] TRUE
# prints TRUE when the RTX B rows are in the same order as the labels
all(baseline.covariate.df$barcode == rownames(rtx.baseline.b))
[1] TRUE
set.seed(12345)
# Multinomial model on the baseline expression values, with lambda chosen by
# misclassification error; one fold per sample makes this leave-one-out CV.
exprs.results <- glmnet::cv.glmnet(
  x = baseline.exprs,
  y = baseline.covariate.df$mainclass,
  family = "multinomial",
  type.measure = "class",
  nfolds = nrow(baseline.exprs)
)
Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
# keep the cross-validated fit so that it can be reloaded later
saveRDS(exprs.results,
        file = file.path(results.dir, "expression_cv.glmnet.RDS"))
# NOTE(review): the labels are predicted for the same samples the model was
# fit on, so the statistics below describe the fit to the training data
# rather than held-out (LOOCV) performance.
exprs.predicted.labels <- stats::predict(exprs.results,
                                         newx = baseline.exprs,
                                         s = exprs.results$lambda.1se,
                                         type = "class")
# Build the prediction factor on the levels of the observed classes: a class
# that is never predicted then still has a (zero-count) row in the confusion
# matrix, rather than leaving it to caret to reconcile mismatched levels.
exprs.observed.class <- as.factor(baseline.covariate.df$mainclass)
caret::confusionMatrix(
  data = factor(as.vector(exprs.predicted.labels),
                levels = levels(exprs.observed.class)),
  reference = exprs.observed.class
)
Confusion Matrix and Statistics
Reference
Prediction Nonresponder Nontolerant Tolerant
Nonresponder 13 1 0
Nontolerant 0 9 0
Tolerant 1 0 12
Overall Statistics
Accuracy : 0.9444
95% CI : (0.8134, 0.9932)
No Information Rate : 0.3889
P-Value [Acc > NIR] : 2.763e-12
Kappa : 0.9157
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: Nonresponder Class: Nontolerant Class: Tolerant
Sensitivity 0.9286 0.9000 1.0000
Specificity 0.9545 1.0000 0.9583
Pos Pred Value 0.9286 1.0000 0.9231
Neg Pred Value 0.9545 0.9630 1.0000
Prevalence 0.3889 0.2778 0.3333
Detection Rate 0.3611 0.2500 0.3333
Detection Prevalence 0.3889 0.2500 0.3611
Balanced Accuracy 0.9416 0.9500 0.9792
B
# Multinomial model on the recount2 (multiPLIER) B values, with lambda chosen
# by misclassification error; one fold per sample makes this leave-one-out CV.
recount.b.results <- glmnet::cv.glmnet(
  x = recount.baseline.b,
  y = baseline.covariate.df$mainclass,
  family = "multinomial",
  type.measure = "class",
  nfolds = nrow(recount.baseline.b)
)
Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
# keep the cross-validated fit so that it can be reloaded later
saveRDS(recount.b.results,
        file = file.path(results.dir, "recount2_B_cv.glmnet.RDS"))
# NOTE(review): the labels are predicted for the same samples the model was
# fit on, so the statistics below describe the fit to the training data
# rather than held-out (LOOCV) performance.
recount.b.predicted.labels <- stats::predict(recount.b.results,
                                             newx = recount.baseline.b,
                                             s = recount.b.results$lambda.1se,
                                             type = "class")
# Build the prediction factor on the levels of the observed classes. When the
# model predicts only a subset of the classes, as.factor() alone would give a
# factor with fewer levels than the reference and leave it to caret to
# reconcile them; setting the levels explicitly keeps every class in the
# confusion matrix regardless of which ones were predicted.
recount.b.observed.class <- as.factor(baseline.covariate.df$mainclass)
caret::confusionMatrix(
  data = factor(as.vector(recount.b.predicted.labels),
                levels = levels(recount.b.observed.class)),
  reference = recount.b.observed.class
)
Levels are not in the same order for reference and data. Refactoring data to match.
Confusion Matrix and Statistics
Reference
Prediction Nonresponder Nontolerant Tolerant
Nonresponder 14 10 12
Nontolerant 0 0 0
Tolerant 0 0 0
Overall Statistics
Accuracy : 0.3889
95% CI : (0.2314, 0.5654)
No Information Rate : 0.3889
P-Value [Acc > NIR] : 0.5628
Kappa : 0
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: Nonresponder Class: Nontolerant Class: Tolerant
Sensitivity 1.0000 0.0000 0.0000
Specificity 0.0000 1.0000 1.0000
Pos Pred Value 0.3889 NaN NaN
Neg Pred Value NaN 0.7222 0.6667
Prevalence 0.3889 0.2778 0.3333
Detection Rate 0.3889 0.0000 0.0000
Detection Prevalence 1.0000 0.0000 0.0000
Balanced Accuracy 0.5000 0.5000 0.5000
B
# Multinomial model on the RTX PLIER B values, with lambda chosen by
# misclassification error; one fold per sample makes this leave-one-out CV.
rtx.b.results <- glmnet::cv.glmnet(
  x = rtx.baseline.b,
  y = baseline.covariate.df$mainclass,
  family = "multinomial",
  type.measure = "class",
  nfolds = nrow(rtx.baseline.b)
)
Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
# keep the cross-validated fit so that it can be reloaded later
saveRDS(rtx.b.results,
        file = file.path(results.dir, "RTX_B_cv.glmnet.RDS"))
# NOTE(review): the labels are predicted for the same samples the model was
# fit on, so the statistics below describe the fit to the training data
# rather than held-out (LOOCV) performance.
rtx.b.predicted.labels <- stats::predict(rtx.b.results,
                                         newx = rtx.baseline.b,
                                         s = rtx.b.results$lambda.1se,
                                         type = "class")
# Build the prediction factor on the levels of the observed classes: a class
# that is never predicted then still has a (zero-count) row in the confusion
# matrix, rather than leaving it to caret to reconcile mismatched levels.
rtx.b.observed.class <- as.factor(baseline.covariate.df$mainclass)
caret::confusionMatrix(
  data = factor(as.vector(rtx.b.predicted.labels),
                levels = levels(rtx.b.observed.class)),
  reference = rtx.b.observed.class
)
Confusion Matrix and Statistics
Reference
Prediction Nonresponder Nontolerant Tolerant
Nonresponder 14 0 0
Nontolerant 0 10 0
Tolerant 0 0 12
Overall Statistics
Accuracy : 1
95% CI : (0.9026, 1)
No Information Rate : 0.3889
P-Value [Acc > NIR] : 1.713e-15
Kappa : 1
Mcnemar's Test P-Value : NA
Statistics by Class:
Class: Nonresponder Class: Nontolerant Class: Tolerant
Sensitivity 1.0000 1.0000 1.0000
Specificity 1.0000 1.0000 1.0000
Pos Pred Value 1.0000 1.0000 1.0000
Neg Pred Value 1.0000 1.0000 1.0000
Prevalence 0.3889 0.2778 0.3333
Detection Rate 0.3889 0.2778 0.3333
Detection Prevalence 0.3889 0.2778 0.3333
Balanced Accuracy 1.0000 1.0000 1.0000
# Total accuracy and its confidence bounds for each model, computed from the
# predicted labels instead of being copied by hand from printed confusion
# matrices, so the table cannot drift out of sync with the fitted models.
observed.class <- as.factor(baseline.covariate.df$mainclass)
predicted.label.list <- list("Expression" = exprs.predicted.labels,
                             "RTX LVs" = rtx.b.predicted.labels,
                             "multiPLIER LVs" = recount.b.predicted.labels)
overall.stats <- vapply(
  predicted.label.list,
  function(predicted.labels) {
    # same levels as the observed classes, even if a class is never predicted
    predicted.class <- factor(as.vector(predicted.labels),
                              levels = levels(observed.class))
    conf.mat <- caret::confusionMatrix(data = predicted.class,
                                       reference = observed.class)
    conf.mat$overall[c("Accuracy", "AccuracyLower", "AccuracyUpper")]
  },
  FUN.VALUE = c(Accuracy = 0, AccuracyLower = 0, AccuracyUpper = 0)
)
# overall.stats has one row per statistic and one column per model
acc.df <- data.frame(Model = colnames(overall.stats),
                     Accuracy = overall.stats["Accuracy", ],
                     Lower = overall.stats["AccuracyLower", ],
                     Upper = overall.stats["AccuracyUpper", ],
                     row.names = NULL)
# one point per model at its accuracy, with a vertical range from Lower to
# Upper
accuracy.plot <- ggplot2::ggplot(
  data = acc.df,
  mapping = ggplot2::aes(x = Model, y = Accuracy, ymin = Lower, ymax = Upper)
) +
  ggplot2::geom_pointrange() +
  ggplot2::theme_bw() +
  ggplot2::ggtitle("Predicting response with LASSO") +
  ggplot2::theme(
    plot.title = ggplot2::element_text(hjust = 0.5, face = "bold"),
    text = ggplot2::element_text(size = 15)
  )
# display the figure, then write that same figure to the plot directory
accuracy.plot
ggplot2::ggsave(filename = file.path(plot.dir, "total_accuracy_CI.pdf"),
                plot = accuracy.plot)
Saving 7 x 7 in image
I wonder if the poor performance in the case of the multiPLIER LVs could be due to a smaller range of values.
# distribution of all expression values, pooled across samples and genes
summary(c(baseline.exprs))
Min. 1st Qu. Median Mean 3rd Qu. Max.
3.399 6.428 8.460 8.419 10.282 19.637
# distribution of all recount2 (multiPLIER) B values, pooled across samples
summary(c(recount.baseline.b))
Min. 1st Qu. Median Mean 3rd Qu. Max.
-0.4666653 -0.0223182 -0.0020393 -0.0009504 0.0183729 0.8213279
# distribution of all RTX B values, pooled across samples
summary(c(rtx.baseline.b))
Min. 1st Qu. Median Mean 3rd Qu. Max.
-2.612200 -0.470965 -0.123762 0.004056 0.327775 5.092082