如题所述,下面是最小复现代码
library(xgboost)
library(survival)
library(riskRegression)
library(prodlim)
# 1. Simulate Survival Data
set.seed(123)
n <- 500
dat <- data.frame(
time = rexp(n, rate = 0.2),
status = rbinom(n, 1, 0.7), # 1=event, 0=censored
x1 = rnorm(n),
x2 = rnorm(n)
)
# Split into Train/Test
train_idx <- sample(1:n, n * 0.7)
train_data <- dat[train_idx, ]
test_data <- dat[-train_idx, ]
# 2. Prepare Matrices for XGBoost
dtrain <- xgb.DMatrix(data = as.matrix(train_data[, c("x1", "x2")]),
label = ifelse(train_data$status == 1, train_data$time, -train_data$time)) # Convention for survival:cox
dtest <- xgb.DMatrix(data = as.matrix(test_data[, c("x1", "x2")]))
# 3. Train XGBoost (Cox Objective)
params <- list(
objective = "survival:cox",
eval_metric = "cox-nloglik",
eta = 0.05
)
xgb_model <- xgb.train(params = params, data = dtrain, nrounds = 100)
# 4. Predict Risk Scores (Hazard Ratios)
# Note: XGBoost survival:cox predicts the Hazard Ratio (HR) by default
train_pred_hr <- predict(xgb_model, dtrain)
test_pred_hr <- predict(xgb_model, dtest)
# first method to calculate survival probabilities
# A. Estimate Baseline Hazard using the training predictions
# We create a Cox model where the coefficients are fixed (using offset)
# We take log() because Cox expects linear predictors, but XGBoost gave us HRs (exp)
cox_fit <- coxph(Surv(time, status) ~ offset(log(train_pred_hr)), data = train_data)
# B. Get Baseline S0(t) specifically where HR=1 (Offset=0)
# We provide newdata where the predictor results in log(HR)=0
dummy_data <- data.frame(train_pred_hr = 1.0)
base_surv <- survfit(cox_fit, newdata = dummy_data)
报错信息
> base_surv <- survfit(cox_fit, newdata = dummy_data)
Error in outer(fit$cumhaz, c(x2)) - fit$xbar : 非整合陈列
配置xfun::session_info()
> xfun::session_info()
R version 4.1.1 (2021-08-10)
Platform: x86_64-redhat-linux-gnu (64-bit)
Running under: CentOS Linux 8, RStudio 2021.9.0.351
Locale:
LC_CTYPE=zh_CN.UTF-8 LC_NUMERIC=C LC_TIME=zh_CN.UTF-8 LC_COLLATE=zh_CN.UTF-8 LC_MONETARY=zh_CN.UTF-8
LC_MESSAGES=zh_CN.UTF-8 LC_PAPER=zh_CN.UTF-8 LC_NAME=C LC_ADDRESS=C LC_TELEPHONE=C
LC_MEASUREMENT=zh_CN.UTF-8 LC_IDENTIFICATION=C
Package version:
backports_1.4.1 base64enc_0.1-3 caret_6.0-90 checkmate_2.0.0 class_7.3-19
cli_3.6.5 cluster_2.1.4 cmprsk_2.2-10 codetools_0.2-19 colorspace_2.1-0
compiler_4.1.1 conquer_1.2.1 cpp11_0.4.7 data.table_1.14.8 digest_0.6.33
doParallel_1.0.17 dplyr_1.1.4 e1071_1.7.9 ellipsis_0.3.2 evaluate_0.21
fansi_1.0.4 farver_2.1.1 fastmap_1.1.1 foreach_1.5.1 foreign_0.8-84
Formula_1.2-5 future_1.33.2 future.apply_1.8.1 generics_0.1.3 ggplot2_3.5.0
globals_0.16.3 glue_1.8.0 gower_1.0.1 graphics_4.1.1 grDevices_4.1.1
grid_4.1.1 gridExtra_2.3 gtable_0.3.3 highr_0.11 Hmisc_4.6-0
htmlTable_2.3.0 htmltools_0.5.2 htmlwidgets_1.5.4 ipred_0.9-12 isoband_0.2.7
iterators_1.0.14 jpeg_0.1-10 jsonlite_1.8.7 KernSmooth_2.23.22 knitr_1.49
labeling_0.4.2 lattice_0.21-8 latticeExtra_0.6-29 lava_1.7.2.1 lifecycle_1.0.4
listenv_0.9.0 lubridate_1.9.2 magrittr_2.0.3 MASS_7.3-60 Matrix_1.3-4
MatrixModels_0.5-0 matrixStats_1.0.0 methods_4.1.1 mets_1.3.2 mgcv_1.8.36
ModelMetrics_1.2.2.2 multcomp_1.4-17 munsell_0.5.0 mvtnorm_1.2-2 nlme_3.1-152
nnet_7.3-19 numDeriv_2016.8-1.1 parallel_4.1.1 parallelly_1.36.0 pillar_1.9.0
pkgconfig_2.0.3 plotrix_3.8.2 plyr_1.8.6 png_0.1-8 polspline_1.1.23
pROC_1.18.0 prodlim_2019.11.13 progressr_0.9.0 proxy_0.4.27 Publish_2023.1.17
purrr_1.0.1 quantreg_5.86 R6_2.5.1 ranger_0.16.0 RColorBrewer_1.1-2
Rcpp_1.0.7 RcppArmadillo_0.10.7.0.0 RcppEigen_0.3.3.9.1 recipes_0.1.17 reshape2_1.4.4
riskRegression_2023.09.08 rlang_1.1.6 rms_6.2-0 rpart_4.1-15 rstudioapi_0.14
sandwich_3.0-1 scales_1.3.0 SparseM_1.81 splines_4.1.1 SQUAREM_2021.1
stats_4.1.1 stats4_4.1.1 stringi_1.7.5 stringr_1.5.0 survival_3.2-13
TH.data_1.1-0 tibble_3.2.1 tidyr_1.3.2 tidyselect_1.2.1 timechange_0.2.0
timeDate_3043.102 timereg_2.0.5 tools_4.1.1 utf8_1.2.2 utils_4.1.1
vctrs_0.6.5 viridis_0.6.2 viridisLite_0.4.0 withr_2.5.0 xfun_0.50
xgboost_1.7.11.1 yaml_2.2.1 zoo_1.8-9