User-specific Models: PRISM

Thomas Jemielita

One advantage of PRISM is its flexibility: each step of the algorithm can be adjusted, and user-created functions/models can be supplied. This facilitates faster testing and experimentation. First, let’s simulate the continuous data again.

library(StratifiedMedicine)
dat_ctns = generate_subgrp_data(family="gaussian")
Y = dat_ctns$Y
X = dat_ctns$X # 50 covariates, 46 are noise variables, X1 and X2 are truly predictive
A = dat_ctns$A # binary treatment, 1:1 randomized 

Before illustrating how to implement user-specific models in PRISM, let’s highlight the key outputs at each step.

Key Outputs by Model

Model    Required Outputs       Description
filter   filter.vars            Variables that pass filter
ple      list(mod, pred.fun)    Model fit(s) and prediction function
submod   list(mod, pred.fun)    Model fit(s) and prediction function
param    param.dat              Parameter estimates (overall and subgroups)

For the filter model, the only required output is a vector of the variable names that pass the filter (for example, covariates with non-zero coefficients in an elastic net model). For the patient-level estimate (ple) model and subgroup model (submod), the required outputs are the model fit(s) and an associated prediction function; the prediction function can also be replaced with pre-computed predictions (details below). Lastly, for parameter estimation (param), the only required output is “param.dat”, a data frame of parameter estimates/SEs/CIs for the overall population and the identified subgroups (if any). Template functions for each step can be generated using the model_template() function within the StratifiedMedicine package.

Filter Model (filter)

The template filter function is:

model_template(type="filter")
#> function (Y, A, X, ...) 
#> {
#> }
#> <bytecode: 0x000000001b0a3b40>
#> <environment: 0x000000001b0a4320>

Note that the filter uses the observed data (Y,A,X), which are required inputs, and outputs an object called “filter.vars”, which must contain the names of the variables that pass the filtering step. For example, consider the lasso:

filter_lasso = function(Y, A, X, lambda="lambda.min", family="gaussian", ...){
  require(glmnet)
  ## Construct the model matrix (no intercept column) ##
  X = model.matrix(~. -1, data = X)

  ##### Fit the lasso (alpha=1) for the outcome Y #####
  set.seed(6134)
  if (family=="survival") { family = "cox" }
  mod <- cv.glmnet(x = X, y = Y, nlambda = 100, alpha=1, family=family)

  ### Extract variables with non-zero coefficients at the chosen lambda ###
  VI <- coef(mod, s = lambda)[,1]
  VI = VI[-1] # drop the intercept
  filter.vars = names(VI[VI!=0])
  return( list(filter.vars=filter.vars) )
}
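
As a quick standalone check, the filter can be run directly on the simulated data (a minimal sketch; “flt” is an illustrative name):

## Sketch: apply the lasso filter directly to the simulated data ##
flt = filter_lasso(Y=Y, A=A, X=X, lambda="lambda.min")
flt$filter.vars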

While not required, the function also includes a lambda argument, which controls how many variables remain after filtering (“lambda.min” keeps more, “lambda.1se” keeps fewer). This can be adjusted through the “filter.hyper” argument in PRISM.
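
For instance, a minimal sketch of changing lambda through PRISM (assuming, as with the other hyper arguments, that “filter.hyper” accepts a named list of hyperparameters):

## Sketch: pass lambda="lambda.1se" to filter_lasso via filter.hyper ##
res_f = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_lasso",
              filter.hyper = list(lambda="lambda.1se"))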

Patient-Level Estimates (ple)

The template ple function is:

model_template(type="ple")
#> function (Y, A, X, Xtest, ...) 
#> {
#> }
#> <bytecode: 0x000000001b0a3d00>
#> <environment: 0x0000000018b69428>

For the “ple” model, the only required arguments are the observed data (Y,A,X) and Xtest. By default, if Xtest is not provided in PRISM, the training X is used instead. The only required outputs are mod (fitted model(s)) and either a prediction function or pre-computed predictions for the training/test sets (mu_train, mu_test). However, certain features in PRISM, such as the heat map plots, cannot be used without a prediction function. In the example below, treatment-specific random forest models are fit with hyperparameter “mtry” (the number of variables randomly selected at each split); this can be altered through the “ple.hyper” argument in PRISM. Notably, certain default plots and parameter functions require the ple predictions to be named “mu0”, “mu1”, and “PLE”.

ple_ranger_mtry = function(Y, A, X, Xtest, mtry=5, ...){
  require(ranger)
  ## Split data by treatment ##
  train0 = data.frame(Y=Y[A==0], X[A==0,])
  train1 = data.frame(Y=Y[A==1], X[A==1,])
  # Trt 0 #
  mod0 <- ranger(Y ~ ., data = train0, seed=1, mtry = mtry)
  # Trt 1 #
  mod1 <- ranger(Y ~ ., data = train1, seed=2, mtry = mtry)
  mod = list(mod0=mod0, mod1=mod1)
  ## Prediction function: returns mu0, mu1, and PLE = mu1 - mu0 ##
  pred.fun <- function(mod, X){
    mu1_hat <- predict(mod$mod1, X)$predictions
    mu0_hat <- predict(mod$mod0, X)$predictions
    mu_hat <- data.frame(mu1 = mu1_hat, mu0 = mu0_hat, PLE = mu1_hat-mu0_hat)
    return(mu_hat)
  }
  res = list(mod=mod, pred.fun=pred.fun)
  return( res )
}
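
As a standalone check (a sketch; “ple_fit” is an illustrative name), the fitted models and prediction function can be used directly:

## Sketch: fit the treatment-specific forests and predict on the training X ##
ple_fit = ple_ranger_mtry(Y=Y, A=A, X=X, Xtest=X, mtry=5)
head( ple_fit$pred.fun(ple_fit$mod, X) )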

Subgroup Identification (submod)

The template submod function is:

model_template(type="submod")
#> function (Y, A, X, Xtest, mu_train, ...) 
#> {
#> }
#> <bytecode: 0x000000001b0a3ef8>
#> <environment: 0x000000002070cfa8>

For the “submod” model, the only required arguments are the observed data (Y,A,X) and Xtest; “mu_train” (the ple predictions) can also be passed through. The only required outputs are mod (fitted model(s)) and either a prediction function or pre-computed subgroup assignments for the training/test sets (Subgrps.train, Subgrps.test). In the example below, consider a modified version of “submod_lmtree” that searches for predictive effects only. By default, “submod_lmtree” searches for prognostic and/or predictive effects.

submod_lmtree_pred = function(Y, A, X, Xtest, mu_train, ...){
  require(partykit)
  ## Fit Model ##
  mod <- lmtree(Y~A | ., data = X, parm=2) # parm=2: test instability of the treatment coefficient only (predictive effects) #
  ## Prediction function: returns the terminal node (subgroup) for each patient ##
  pred.fun <- function(mod, X=NULL, type="subgrp"){
    Subgrps <- as.numeric( predict(mod, type="node", newdata = X) )
    return( list(Subgrps=Subgrps) )
  }
  ## Return Results ##
  return( list(mod=mod, pred.fun=pred.fun) )
}
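
As with the ple step, the subgroup model can be checked on its own (a sketch; “sub_fit” and “Subgrps_train” are illustrative names; mu_train is unused here, so NULL is passed):

## Sketch: fit the subgroup model and tabulate the training assignments ##
sub_fit = submod_lmtree_pred(Y=Y, A=A, X=X, Xtest=X, mu_train=NULL)
Subgrps_train = sub_fit$pred.fun(sub_fit$mod, X)$Subgrps
table(Subgrps_train)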

Parameter Estimation (param)

The template param function is:

model_template(type="param")
#> function (Y, A, X, mu_hat, Subgrps, alpha_ovrl, alpha_s, ...) 
#> {
#> }
#> <bytecode: 0x000000001b0a4160>
#> <environment: 0x000000002c9b54e0>

For the parameter model, the key arguments are the observed data (Y,A,X), mu_hat (the ple predictions), Subgrps (subgroup assignments), and alpha_ovrl/alpha_s (overall and subgroup alpha levels). The only required output is “param.dat”, which contains the parameter estimates and variability metrics. For all PRISM functionality to work, param.dat should contain the column names “est” (parameter estimate), “SE” (standard error), and “LCL”/“UCL” (lower and upper confidence limits). Including an “estimand” column for labeling purposes is also recommended. In the example below, M-estimation models are fit for each subgroup and for the overall population; alternatively, a single M-estimation model could have been fit.


### Robust Linear Regression: E(Y|A=1) - E(Y|A=0) ###
param_rlm = function(Y, A, X, mu_hat, Subgrps, alpha_ovrl, alpha_s, ...){
  require(MASS)
  indata = data.frame(Y=Y,A=A, X)
  mod.ovrl = rlm(Y ~ A , data=indata)
  param.dat0 = data.frame( Subgrps=0, N = dim(indata)[1],
                           estimand = "E(Y|A=1)-E(Y|A=0)",
                           est = summary(mod.ovrl)$coefficients[2,1],
                           SE = summary(mod.ovrl)$coefficients[2,2] )
  param.dat0$LCL = with(param.dat0, est-qt(1-alpha_ovrl/2, N-1)*SE)
  param.dat0$UCL = with(param.dat0, est+qt(1-alpha_ovrl/2, N-1)*SE)
  param.dat0$pval = with(param.dat0, 2*pt(-abs(est/SE), df=N-1) )

  ## Subgroup-Specific Estimates ##
  looper = function(s){
    rlm.mod = tryCatch( rlm(Y ~ A , data=indata[Subgrps==s,]),
                       error = function(e) "param error" )
    n.s = dim(indata[Subgrps==s,])[1]
    est = summary(rlm.mod)$coefficients[2,1]
    SE = summary(rlm.mod)$coefficients[2,2]
    # Subgroup CIs use the subgroup alpha level (alpha_s), not alpha_ovrl #
    LCL = est-qt(1-alpha_s/2, n.s-1)*SE
    UCL = est+qt(1-alpha_s/2, n.s-1)*SE
    pval = 2*pt(-abs(est/SE), df=n.s-1)
    return( c(est, SE, LCL, UCL, pval) )
  }
  S_levels = as.numeric( names(table(Subgrps)) )
  S_N = as.numeric( table(Subgrps) )
  param.dat = lapply(S_levels, looper)
  param.dat = do.call(rbind, param.dat)
  param.dat = data.frame( S = S_levels, N=S_N, estimand="E(Y|A=1)-E(Y|A=0)", param.dat)
  colnames(param.dat) = c("Subgrps", "N", "estimand", "est", "SE", "LCL", "UCL", "pval")
  param.dat = rbind( param.dat0, param.dat)
  return( param.dat )
}
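
Finally, param_rlm can also be run directly, reusing the subgroup assignments Subgrps_train from the submod sketch above (mu_hat is unused by this function, so NULL is passed):

## Sketch: overall and subgroup-specific robust estimates ##
param_rlm(Y=Y, A=A, X=X, mu_hat=NULL, Subgrps=Subgrps_train,
          alpha_ovrl=0.05, alpha_s=0.05)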

Putting it All Together

Finally, let’s input these user-specific functions into PRISM:


res_user1 = PRISM(Y=Y, A=A, X=X, family="gaussian", filter="filter_lasso", 
             ple = "ple_ranger_mtry", submod = "submod_lmtree_pred",
             param="param_rlm")
## variables that remain after filtering ##
res_user1$filter.vars
#>  [1] "X1"  "X2"  "X3"  "X5"  "X7"  "X8"  "X10" "X12" "X16" "X18" "X24"
#> [12] "X26" "X31" "X40" "X46" "X50"
## Subgroup model: lmtree searching for predictive only ##
plot(res_user1$submod.fit$mod)

## Parameter estimates/inference
res_user1$param.dat
#>   Subgrps   N          estimand          est         SE         LCL
#> 1       0 800 E(Y|A=1)-E(Y|A=0)  0.229913662 0.08114137  0.07063823
#> 2       2 426 E(Y|A=1)-E(Y|A=0) -0.004165125 0.10352507 -0.20765001
#> 3       3 374 E(Y|A=1)-E(Y|A=0)  0.506836561 0.11525208  0.28021130
#>         UCL         pval alpha
#> 1 0.3891891 4.720217e-03  0.05
#> 2 0.1993198 9.679263e-01  0.05
#> 3 0.7334618 1.429483e-05  0.05
plot(res_user1, type="forest")

Conclusion

Overall, each step of PRISM is customizable, allowing for fast experimentation and improvement of individual steps. The main consideration when customizing a step is conforming to its required inputs and outputs, as described above.