This article was first published on T. Moudiki's Webpage - Python, and kindly contributed to python-bloggers.
Want to share your content on python-bloggers? click here.
The model presented here is a frequentist (conformalized) version of the Bayesian one presented last week in #152. It is implemented in learningmachine, in both Python and R. Model explanations are given as sensitivity analyses.
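For readers who have not met RVFL (random vector functional link) models before, here is a minimal NumPy sketch of the general idea: the standardized inputs are augmented with a few hidden features built from random, untrained weights, and the output weights come from a ridge regression. The function names, the tanh activation and the Gaussian random weights are assumptions made for illustration; learningmachine's actual implementation may differ.

```python
import numpy as np

def rvfl_fit(X, y, nb_hidden=3, reg_lambda=1.0, seed=123):
    """Toy RVFL: ridge regression on [Z, tanh(Z W)] (illustrative only)."""
    rng = np.random.default_rng(seed)
    x_mean, x_std = X.mean(axis=0), X.std(axis=0) + 1e-12
    Z = (X - x_mean) / x_std                           # standardized inputs
    W = rng.standard_normal((X.shape[1], nb_hidden))   # random weights, never trained
    H = np.hstack([Z, np.tanh(Z @ W)])                 # direct link + hidden features
    y_mean = y.mean()
    # ridge solution: (H'H + lambda I)^{-1} H'(y - mean(y))
    beta = np.linalg.solve(H.T @ H + reg_lambda * np.eye(H.shape[1]),
                           H.T @ (y - y_mean))
    return dict(W=W, beta=beta, x_mean=x_mean, x_std=x_std, y_mean=y_mean)

def rvfl_predict(model, X):
    Z = (X - model["x_mean"]) / model["x_std"]
    H = np.hstack([Z, np.tanh(Z @ model["W"])])
    return model["y_mean"] + H @ model["beta"]

# toy data (invented for this sketch)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 4))
y_demo = X_demo[:, 0] - 2.0 * X_demo[:, 1] + 0.1 * rng.normal(size=200)
model = rvfl_fit(X_demo[:150], y_demo[:150])
pred_demo = rvfl_predict(model, X_demo[150:])
rmse_demo = float(np.sqrt(np.mean((pred_demo - y_demo[150:]) ** 2)))
```

Only the ridge step involves any fitting, which is why such models are cheap to train and to update.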
0 – install packages
For R
utils::install.packages(c("rmarkdown", "reticulate", "remotes"))
Installing packages into '/cloud/lib/x86_64-pc-linux-gnu-library/4.4' (as 'lib' is unspecified)
remotes::install_github("thierrymoudiki/bayesianrvfl")
Skipping install of 'bayesianrvfl' from a github remote, the SHA1 (a8e9e78a) has not changed since last install. Use `force = TRUE` to force installation
Skipping install of 'learningmachine' from a github remote, the SHA1 (6b930284) has not changed since last install. Use `force = TRUE` to force installation
library("learningmachine")
Loading required package: randtoolbox
Loading required package: rngWELL
This is randtoolbox. For an overview, type 'help("randtoolbox")'.
Loading required package: tseries
Registered S3 method overwritten by 'quantmod':
  method            from
  as.zoo.data.frame zoo
Loading required package: memoise
Loading required package: foreach
Loading required package: skimr
Loading required package: snow
Loading required package: doSNOW
Loading required package: iterators
For Python
pip install matplotlib
pip install git+https://github.com/Techtonique/learningmachine_python.git
Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/Techtonique/learningmachine_python.git
  Cloning https://github.com/Techtonique/learningmachine_python.git to /tmp/pip-req-build-37h1oa6g
  Running command git clone --filter=blob:none --quiet https://github.com/Techtonique/learningmachine_python.git /tmp/pip-req-build-37h1oa6g
  Resolved https://github.com/Techtonique/learningmachine_python.git to commit 3ec7ca96df71add6218d55db9ef5d8eb40275877
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Requirement already satisfied: numpy in /cloud/python/lib/python3.8/site-packages (from learningmachine==2.2.2) (1.24.4)
Requirement already satisfied: pandas in /cloud/python/lib/python3.8/site-packages (from learningmachine==2.2.2) (2.0.3)
Requirement already satisfied: rpy2>=3.4.5 in /cloud/python/lib/python3.8/site-packages (from learningmachine==2.2.2) (3.5.16)
Requirement already satisfied: scikit-learn in /cloud/python/lib/python3.8/site-packages (from learningmachine==2.2.2) (1.3.2)
Requirement already satisfied: scipy in /cloud/python/lib/python3.8/site-packages (from learningmachine==2.2.2) (1.10.1)
Requirement already satisfied: jinja2 in /cloud/python/lib/python3.8/site-packages (from rpy2>=3.4.5->learningmachine==2.2.2) (3.1.4)
Requirement already satisfied: tzlocal in /cloud/python/lib/python3.8/site-packages (from rpy2>=3.4.5->learningmachine==2.2.2) (5.2)
Requirement already satisfied: backports.zoneinfo in /cloud/python/lib/python3.8/site-packages (from rpy2>=3.4.5->learningmachine==2.2.2) (0.2.1)
Requirement already satisfied: cffi>=1.15.1 in /cloud/python/lib/python3.8/site-packages (from rpy2>=3.4.5->learningmachine==2.2.2) (1.17.0)
Requirement already satisfied: tzdata>=2022.1 in /cloud/python/lib/python3.8/site-packages (from pandas->learningmachine==2.2.2) (2024.1)
Requirement already satisfied: python-dateutil>=2.8.2 in /cloud/python/lib/python3.8/site-packages (from pandas->learningmachine==2.2.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /cloud/python/lib/python3.8/site-packages (from pandas->learningmachine==2.2.2) (2024.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /cloud/python/lib/python3.8/site-packages (from scikit-learn->learningmachine==2.2.2) (3.5.0)
Requirement already satisfied: joblib>=1.1.1 in /cloud/python/lib/python3.8/site-packages (from scikit-learn->learningmachine==2.2.2) (1.4.2)
Requirement already satisfied: pycparser in /cloud/python/lib/python3.8/site-packages (from cffi>=1.15.1->rpy2>=3.4.5->learningmachine==2.2.2) (2.22)
Requirement already satisfied: six>=1.5 in /cloud/python/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas->learningmachine==2.2.2) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /cloud/python/lib/python3.8/site-packages (from jinja2->rpy2>=3.4.5->learningmachine==2.2.2) (2.1.5)
[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: /opt/python/3.8.17/bin/python3.8 -m pip install --upgrade pip
1 – Python code version
import packages
import numpy as np
import warnings
import learningmachine as lm
/cloud/python/lib/python3.8/site-packages/rpy2/rinterface_lib/embedded.py:276: UserWarning: R was initialized outside of rpy2 (R_NilValue != NULL). Trying to use it nevertheless. warnings.warn(msg) R was initialized outside of rpy2 (R_NilValue != NULL). Trying to use it nevertheless.
plotting function
import matplotlib.pyplot as plt  # needed by plot_func below

warnings.filterwarnings('ignore')

split_color = 'green'
split_color2 = 'orange'
local_color = 'gray'

def plot_func(x,
              y,
              y_u=None,
              y_l=None,
              pred=None,
              shade_color=split_color,
              method_name="",
              title=""):

    fig = plt.figure()

    # test set observations
    plt.plot(x, y, 'k.', alpha=.3, markersize=10,
             fillstyle='full', label=u'Test set observations')

    # shaded band between the lower and upper bounds
    if (y_u is not None) and (y_l is not None):
        plt.fill(np.concatenate([x, x[::-1]]),
                 np.concatenate([y_u, y_l[::-1]]),
                 alpha=.3, fc=shade_color, ec='None',
                 label = method_name + ' Prediction interval')

    # point predictions
    if pred is not None:
        plt.plot(x, pred, 'k--', lw=2, alpha=0.9,
                 label=u'Predicted value')

    #plt.ylim([-2.5, 7])
    plt.xlabel('$X$')
    plt.ylabel('$Y$')
    plt.legend(loc='upper right')
    plt.title(title)

    plt.show()
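import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from time import time

fit_obj = lm.Regressor(method="rvfl", pi_method = "kdejackknifeplus", nb_hidden = 3)
data = fetch_california_housing()
X, y = data.data[:600], data.target[:600]
X = pd.DataFrame(X, columns=data.feature_names)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)
start = time()
fit_obj.fit(X_train, y_train, reg_lambda=1)
# the printed output below implies the following steps (timing, prediction, coverage)
print("Elapsed time: ", time() - start)
preds = fit_obj.predict(X_test)
print("preds: ", preds)
print("coverage rate: ", np.mean((preds.lower <= y_test) * (preds.upper >= y_test)))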
Regressor(method='rvfl', nb_hidden=3, pi_method='kdejackknifeplus')Elapsed time: 11.298886775970459preds: DescribeResult(preds=array([2.49357776, 1.43407668, 1.5975387 , 1.12443032, 2.82375138, 2.4580493 , 3.54353746, 1.8909777 , 2.05285337, 3.77584837, 1.83513261, 1.07393557, 1.94194497, 0.92784638, 1.05871262, 2.22243527, 1.65294959, 1.91187611, 1.77770079, 1.25768064, 1.87622468, 1.74061527, 2.92453382, 1.66174853, 1.6614039 , 1.77202212, 5.07081658, 2.38262843, 2.53746616, 1.363203 , 3.20583396, 1.86606062, 1.70192252, 1.07269937, 1.31627684, 2.03799042, 3.26290966, 3.97461689, 1.13406842, 2.47295003, 1.5368609 , 1.09332542, 1.4092314 , 2.60417043, 5.89174275, 1.50619443, 2.4628304 , 2.19871466, 2.1841939 , 1.48221595, 1.87870123, 0.77989849, 1.44449574, 2.38624907, 1.29244221, 0.78829802, 1.41881675, 1.74509957, 1.11005633, 4.01926467, 2.05504522, 1.71159601, 3.2461924 , 2.83470994, 2.19983398, 1.59497106, 1.70886701, 1.90077656, 1.45410688, 1.66886891, 1.75110881, 3.05357402, 3.6834588 , 1.08916238, 1.84916719, 1.80851509, 2.15596533, 1.97719052, 1.92875913, 3.38678466, 1.32564179, 1.81436977, 1.63403928, 1.18568231, 1.7315052 , 0.8090071 , 2.47389329, 1.19407051, 2.23409297, 2.30256486, 5.37829418, 1.38890835, 2.00637569, 1.99204335, 1.85314134, 2.86190699, 1.91334581, 1.24205971, 3.20558421, 4.26259962, 2.01732404, 3.95679516, 1.39072672, 1.65839735, 2.30550471, 2.3517598 , 2.13887703, 1.25020289, 1.88503591, 1.16169133, 2.00430928, 3.26703677, 1.3242707 , 2.01948611, 2.43700375, 1.69810225, 1.69708568, 2.18234363, 2.12271321, 0.94102412]), sims=array([[2.52197294, 2.65764086, 2.40986397, ..., 2.78250935, 2.00710776, 2.45581687], [1.73158936, 1.64068277, 1.30405511, ..., 1.5192879 , 1.7531569 , 1.96604158], [1.9742429 , 1.39726812, 1.32734606, ..., 1.22639871, 1.46987265, 1.40392566], ..., [2.0119616 , 2.40848185, 2.09426227, ..., 2.62052065, 2.51099616, 2.36614074], [1.35463496, 2.3932347 , 2.05260157, ..., 2.40050845, 1.99334353, 1.87698499], [0.48279083, 
0.94856269, 0.97790733, ..., 1.0927788 , 1.23311912, 0.40108719]]), lower=array([ 1.93689669e+00, 6.14177968e-01, 1.04195634e+00, 3.75503359e-01, 2.02144872e+00, 1.82276306e+00, 2.64944272e+00, 1.34667642e+00, 1.29086670e+00, 3.05297704e+00, 1.17765182e+00, 4.30461766e-01, 1.37895403e+00, 2.98779763e-01, 4.40822194e-01, 1.55654853e+00, 9.38628380e-01, 1.21829098e+00, 9.84558834e-01, 6.49532691e-01, 9.01772850e-01, 1.19276603e+00, 2.17680551e+00, 1.09038146e+00, 1.16893453e+00, 9.31502331e-01, 4.18264443e+00, 1.49607918e+00, 1.57691294e+00, 3.28849227e-01, 2.39003022e+00, 1.07840109e+00, 8.10677053e-01, 5.57945108e-01, 5.89483726e-01, 1.40646500e+00, 2.38769699e+00, 3.33592642e+00, 4.92576112e-01, 1.64897773e+00, 8.70602613e-01, 3.28203503e-01, 5.98006128e-01, 1.85519350e+00, 5.28848929e+00, 8.98852307e-01, 1.62711360e+00, 1.55280346e+00, 1.49136612e+00, 8.73686102e-01, 1.36286310e+00, -1.81572146e-01, 7.77535746e-01, 1.55053396e+00, 7.13006914e-01, 4.64261897e-03, 8.01311589e-01, 7.48462573e-01, 4.37894150e-01, 3.20953201e+00, 1.01569966e+00, 1.04846962e+00, 2.60967933e+00, 1.88654294e+00, 1.14477050e+00, 1.14136267e+00, 1.18280741e+00, 1.25494957e+00, 8.98454204e-01, 2.10925704e-01, 1.05478650e+00, 2.35858707e+00, 2.94164760e+00, 1.67571043e-01, 1.33878742e+00, 1.25839350e+00, 1.53951656e+00, 9.05624320e-01, 1.17061173e+00, 2.54655499e+00, 6.47466317e-01, 1.20289748e+00, 1.13571370e+00, 4.64266018e-01, 1.14215867e+00, 1.49252246e-01, 1.82103747e+00, -2.49955629e-01, 1.61755751e+00, 1.65686043e+00, 4.78967402e+00, 4.81742186e-01, 1.44596520e+00, 1.32448131e+00, 8.89972080e-01, 2.34075771e+00, 1.24303256e+00, 7.03972518e-01, 2.60285187e+00, 3.53153738e+00, 1.35196678e+00, 3.50221519e+00, 4.47336205e-01, 9.67171390e-01, 1.76976985e+00, 1.59672116e+00, 1.28269514e+00, 6.40395915e-01, 1.29953005e+00, 5.27235568e-01, 1.46471898e+00, 1.66067091e+00, 4.41349172e-01, 1.39726429e+00, 1.80795176e+00, 1.07103425e+00, 5.76820176e-01, 1.21573235e+00, 1.34041046e+00, 
1.73301081e-01]), upper=array([2.99397933, 1.96132754, 1.9741853 , 1.56915801, 3.40435554, 2.91464536, 4.17157769, 2.35189282, 2.54248989, 4.34702931, 2.28681813, 1.58798973, 2.46838793, 1.4842451 , 1.52209298, 2.75057218, 2.25719594, 2.41607558, 2.30447366, 1.71324602, 2.33541568, 2.24194145, 3.37712022, 2.16161502, 2.16743013, 2.32044303, 5.7525094 , 2.95770978, 3.05001003, 1.96251894, 3.65612204, 2.41472665, 2.15530985, 1.58579045, 1.84162227, 2.53564354, 3.82321955, 4.44042528, 1.60388757, 2.96240096, 2.12835682, 1.5469312 , 1.86880196, 3.16319094, 6.30956862, 1.96174807, 2.99380439, 2.88775395, 2.60191973, 1.93653757, 2.19505169, 1.30622414, 1.85682241, 2.95891228, 1.79928438, 1.29625491, 2.09220303, 2.36290631, 1.53658861, 4.7506125 , 2.67337766, 2.17030365, 3.75023808, 3.35541687, 2.77520135, 2.08450402, 2.16371497, 2.38837026, 1.87832477, 2.16419849, 2.36599111, 3.50414737, 4.29523451, 1.53640485, 2.2309006 , 2.33145818, 2.66451633, 2.49759326, 2.43879385, 3.90232458, 1.84661217, 2.27974655, 2.06887939, 1.71836616, 2.15279379, 1.31422259, 2.90745441, 1.76813 , 2.71020272, 2.81272411, 5.92280218, 1.86562672, 2.5298048 , 2.51531422, 2.39877546, 3.36143692, 2.46371693, 1.76912589, 3.67435993, 4.72940884, 2.48373427, 4.37798834, 2.15801521, 2.11703737, 3.27506479, 2.88194859, 2.6982464 , 1.80432322, 2.45259713, 1.66864527, 2.49172073, 3.85706896, 1.95383092, 2.49247639, 2.9572688 , 2.24464142, 2.27349167, 2.61578328, 2.64429111, 1.49723436]))coverage rate: 0.8
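The printed result has four fields: preds, sims, lower and upper. As a rough illustration of how simulated predictions such as sims can lead to lower and upper bounds, the sketch below resamples residuals, adds Gaussian noise to them (which amounts to drawing from a Gaussian kernel density estimate of the residuals), adds the draws to the point predictions and takes empirical quantiles. This is a toy under stated assumptions, not learningmachine's "kdejackknifeplus" code: the function name, the residuals, the normal-reference bandwidth and the 95% level are all chosen here for illustration.

```python
import numpy as np

def kde_residual_intervals(preds, residuals, n_sims=250, level=0.95, seed=123):
    """Simulate predictions by sampling residuals from a Gaussian KDE (toy sketch)."""
    rng = np.random.default_rng(seed)
    n = residuals.size
    # normal-reference rule-of-thumb bandwidth for a Gaussian kernel
    h = 1.06 * residuals.std(ddof=1) * n ** (-1 / 5)
    # a draw from a Gaussian KDE = a resampled residual + N(0, h^2) noise
    resampled = rng.choice(residuals, size=(preds.size, n_sims))
    noise = h * rng.standard_normal((preds.size, n_sims))
    sims = preds[:, None] + resampled + noise
    alpha = 1 - level
    lower = np.quantile(sims, alpha / 2, axis=1)
    upper = np.quantile(sims, 1 - alpha / 2, axis=1)
    return sims, lower, upper

# invented inputs: three point predictions and 100 calibration residuals
toy_preds = np.array([2.5, 1.4, 1.6])
toy_residuals = np.random.default_rng(0).normal(scale=0.5, size=100)
toy_sims, toy_lower, toy_upper = kde_residual_intervals(toy_preds, toy_residuals)
```

Each row of toy_sims plays the role of one row of sims above: a cloud of plausible values around one point prediction.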
plot_func(x = range(len(y_test)), y = y_test, y_u=preds.upper, y_l=preds.lower, pred=preds.preds, method_name="before update", title="")
# update
fit_obj.update(X_test.iloc[0,:], y_test[0])
Regressor(method='rvfl', nb_hidden=3, pi_method='kdejackknifeplus')
coverage rate: 0.8067226890756303
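The coverage rate is the share of test observations that fall inside their prediction interval. The small helper below (the name coverage_rate and the toy numbers are ours) makes the computation explicit. As a sanity check on the printed values, 0.8 equals 96/120 and 0.8067226890756303 equals 96/119, which would be consistent with 96 covered observations both before and after the update, assuming the second rate was computed on the 119 remaining test points.

```python
import numpy as np

def coverage_rate(y_true, lower, upper):
    """Fraction of observations lying inside [lower, upper]."""
    y_true, lower, upper = map(np.asarray, (y_true, lower, upper))
    return float(np.mean((lower <= y_true) & (y_true <= upper)))

# invented toy example: the second observation falls outside its interval
y_toy = [1.0, 2.0, 3.0, 4.0, 5.0]
lower_toy = [0.5, 2.5, 2.0, 3.0, 4.9]
upper_toy = [1.5, 3.5, 4.0, 5.0, 6.0]
rate_toy = coverage_rate(y_toy, lower_toy, upper_toy)  # 4 of 5 covered
```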
plot_func(x = range(len(y_test[:-1])), y = y_test[:-1], y_u=preds.upper, y_l=preds.lower, pred=preds.preds, method_name="after update", title="")
2 – R code version
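X <- as.matrix(mtcars[,-1])
y <- mtcars$mpg
set.seed(123)
(index_train <- base::sample.int(n = nrow(X),
                                 size = floor(0.7*nrow(X)),
                                 replace = FALSE))
# train/test split used in the rest of this section
X_train <- X[index_train, ]
y_train <- y[index_train]
X_test <- X[-index_train, ]
y_test <- y[-index_train]
dim(X_train)
dim(X_test)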
[1] 31 15 19 14  3 10 18 22 11  5 20 29 23 30  9 28  8 27  7 32 26 17
[1] 22 10
[1] 10 10
obj <- learningmachine::Regressor$new(method = "rvfl",
                                      nb_hidden = 50L,
                                      pi_method = "splitconformal")
obj$get_type()
[1] "regression"
[1] "Regressor"
t0 <- proc.time()[3]
obj$fit(X_train, y_train, reg_lambda = 0.01)
cat("Elapsed: ", proc.time()[3] - t0, "s \n")
Elapsed: 0.01 s n
print(obj$predict(X_test))
$preds
          Mazda RX4       Mazda RX4 Wag      Hornet 4 Drive             Valiant
          21.350888           19.789387           13.106761            9.695310
         Merc 450SE          Merc 450SL Lincoln Continental       Toyota Corona
          11.131161           12.568682            2.044672           19.289805
         Camaro Z28    Pontiac Firebird
          14.847878           12.282272

$lower
          Mazda RX4       Mazda RX4 Wag      Hornet 4 Drive             Valiant
         12.3508879          10.7893873           4.1067608           0.6953102
         Merc 450SE          Merc 450SL Lincoln Continental       Toyota Corona
          2.1311611           3.5686817          -6.9553279          10.2898053
         Camaro Z28    Pontiac Firebird
          5.8478777           3.2822719

$upper
          Mazda RX4       Mazda RX4 Wag      Hornet 4 Drive             Valiant
           30.35089            28.78939            22.10676            18.69531
         Merc 450SE          Merc 450SL Lincoln Continental       Toyota Corona
           20.13116            21.56868            11.04467            28.28981
         Camaro Z28    Pontiac Firebird
           23.84788            21.28227
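In the output above, every interval is the point prediction plus or minus 9 (for instance 21.350888 - 9 = 12.350888 for the Mazda RX4). A constant half-width is what basic split conformal prediction produces: the half-width is a quantile of the absolute residuals computed on a held-out calibration set. Here is a minimal sketch of that idea, written in Python to keep a single language for the added examples. The function name, the toy calibration data and the 90% level are assumptions for illustration, not learningmachine's code.

```python
import numpy as np

def split_conformal_halfwidth(y_calib, pred_calib, level=0.9):
    """Half-width of a split conformal interval from calibration residuals."""
    scores = np.sort(np.abs(np.asarray(y_calib, dtype=float)
                            - np.asarray(pred_calib, dtype=float)))
    n = scores.size
    # finite-sample corrected rank: the ceil((n + 1) * level)-th smallest score
    k = int(np.ceil((n + 1) * level))
    if k > n:
        return float("inf")  # too few calibration points for this level
    return float(scores[k - 1])

# invented calibration set: absolute residuals are 1, 2, ..., 19
y_calib = np.arange(1.0, 20.0)
pred_calib = np.zeros(19)
half_width = split_conformal_halfwidth(y_calib, pred_calib, level=0.9)

# the same half-width is applied to every new prediction
new_preds = np.array([21.35, 19.79, 13.11])
lower_new, upper_new = new_preds - half_width, new_preds + half_width
```

Because the half-width does not depend on the new input, the band has the same width everywhere, as in the R output.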
obj$summary(X_test, y=y_test, show_progress=FALSE)
$R_squared
[1] -1.505856

$R_squared_adj
[1] 23.55271

$Residuals
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
 -1.548   1.461   5.000   4.349   7.949   8.405

$Coverage_rate
[1] 100

$ttests
        estimate       lower        upper      p-value signif
cyl   137.649985   39.777048  235.5229227 1.115728e-02      *
disp   -2.406399   -4.650678   -0.1621204 3.825959e-02      *
hp     -0.527573   -1.402043    0.3468975 2.054686e-01
drat  707.372951  246.095138 1168.6507638 7.059500e-03     **
wt   -500.429007 -565.047979 -435.8100352 2.910469e-08    ***
qsec  -89.930939 -124.899691  -54.9621860 2.537870e-04    ***
vs    234.198406 -127.886990  596.2838006 1.774484e-01
am   -235.789718 -512.422513   40.8430776 8.592503e-02      .
gear   52.646721   -6.640614  111.9340567 7.547657e-02      .
carb  -17.100561  -87.819649   53.6185270 5.976705e-01

$effects
── Data Summary ────────────────────────
                           Values
Name                       effects
Number of rows             10
Number of columns          10
_______________________
Column type frequency:
  numeric                  10
________________________
Group variables            None

── Variable type: numeric ──────────────────────────────────────────────────────
   skim_variable     mean     sd       p0      p25      p50      p75     p100 hist
 1 cyl            138.    137.     -8.40    75.8     91.1     98.6    394.    ▂▇▁▁▂
 2 disp            -2.41    3.14   -8.46    -1.32    -1.08    -0.775   -0.300 ▂▁▁▁▇
 3 hp              -0.528   1.22   -3.40    -0.695   -0.188    0.0137   0.893 ▁▁▁▇▂
 4 drat           707.    645.     55.7    388.     482.     563.    1939.    ▅▇▁▁▃
 5 wt            -500.     90.3  -698.    -538.    -500.    -458.    -377.    ▂▁▆▇▃
 6 qsec           -89.9    48.9  -145.    -128.    -102.     -64.0      2.67  ▇▆▁▂▃
 7 vs             234.    506.   -121.     -13.2     36.8     53.2   1269.    ▇▁▁▁▂
 8 am            -236.    387.   -653.    -450.    -397.    -168.     519.    ▆▇▂▁▃
 9 gear            52.6    82.9  -107.      -4.69    66.2    112.     170.    ▂▅▅▅▇
10 carb           -17.1    98.9  -117.     -64.6    -60.6    -17.5    171.    ▇▂▁▁▂
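The $effects table has one row per test observation and one column per feature, and the $ttests estimates match the column means of $effects, so these quantities appear to be per-observation sensitivities of the prediction to each feature. One common way to obtain such sensitivities is a central finite difference of the fitted prediction function, followed by a one-sample t statistic on the per-observation effects. The sketch below shows that recipe on a toy prediction function; the function names, the step size and the toy model are assumptions, and this is not necessarily how learningmachine computes its summary.

```python
import numpy as np

def finite_difference_effects(predict, X, rel_step=1e-4):
    """Per-observation derivative of predict(X) with respect to each column."""
    X = np.asarray(X, dtype=float)
    effects = np.empty_like(X)
    for j in range(X.shape[1]):
        h = rel_step * np.maximum(np.abs(X[:, j]), 1.0)  # step scaled to the feature
        X_up, X_down = X.copy(), X.copy()
        X_up[:, j] += h
        X_down[:, j] -= h
        effects[:, j] = (predict(X_up) - predict(X_down)) / (2 * h)
    return effects

def t_statistics(effects):
    """Mean effect per feature and its one-sample t statistic (H0: mean = 0)."""
    n = effects.shape[0]
    mean = effects.mean(axis=0)
    se = effects.std(axis=0, ddof=1) / np.sqrt(n)
    return mean, mean / se

def toy_predict(X):
    # invented model: derivatives are 3 + x0 and -2 * x1
    return 3.0 * X[:, 0] + 0.5 * X[:, 0] ** 2 - X[:, 1] ** 2

X_toy = np.random.default_rng(0).normal(size=(10, 2))
effects_toy = finite_difference_effects(toy_predict, X_toy)
mean_effect, t_stat = t_statistics(effects_toy)
```

With only ten test rows, as here, such t statistics rest on very few observations and should be read with care.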
t0 <- proc.time()[3]
obj$fit(X_train, y_train)
cat("Elapsed: ", proc.time()[3] - t0, "s \n")
Elapsed: 0.128 s n
[1] 1
update RVFL model
previous_coefs <- drop(obj$model$coef)
newx <- X_test[1, ]
newy <- y_test[1]
new_X_test <- X_test[-1, ]
new_y_test <- y_test[-1]
t0 <- proc.time()[3]
obj$update(newx, newy)
cat("Elapsed: ", proc.time()[3] - t0, "s \n")
Elapsed: 0.242 s n
summary(previous_coefs)
     Min.   1st Qu.    Median      Mean   3rd Qu.      Max.
 -0.68212  -0.26567  -0.05157   0.00700   0.21046   2.19222
     Min.   1st Qu.    Median      Mean   3rd Qu.      Max.
-0.030666 -0.002610  0.004189  0.002917  0.011386  0.025243
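Updating a ridge-type model with one new observation does not require refitting from scratch. A standard technique is the Sherman-Morrison rank-one update of the inverse regularized Gram matrix (the recursive least squares recursion), sketched below for plain ridge regression on toy data. The post does not say whether learningmachine's update method uses exactly this formula, so treat the code and its function names as illustrative.

```python
import numpy as np

def ridge_fit(H, y, reg_lambda=0.01):
    """Ridge coefficients and P = (H'H + lambda I)^{-1}."""
    P = np.linalg.inv(H.T @ H + reg_lambda * np.eye(H.shape[1]))
    return P @ H.T @ y, P

def ridge_update(beta, P, h_new, y_new):
    """Rank-one (Sherman-Morrison) update after observing (h_new, y_new)."""
    Ph = P @ h_new
    gain = Ph / (1.0 + h_new @ Ph)
    beta_new = beta + gain * (y_new - h_new @ beta)  # correct by the prediction error
    P_new = P - np.outer(gain, Ph)                   # updated inverse Gram matrix
    return beta_new, P_new

# invented data: fit on 29 rows, then update with the 30th
rng = np.random.default_rng(42)
H_all = rng.normal(size=(30, 4))
y_all = H_all @ np.array([1.0, -2.0, 0.5, 0.0]) + 0.1 * rng.normal(size=30)
beta_old, P_old = ridge_fit(H_all[:-1], y_all[:-1])
beta_upd, P_upd = ridge_update(beta_old, P_old, H_all[-1], y_all[-1])
beta_full, P_full = ridge_fit(H_all, y_all)  # full refit, for comparison
```

The updated coefficients agree with a full refit on all 30 rows up to floating point error, while the update itself only costs a few matrix-vector products.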
obj$summary(new_X_test, y=new_y_test, show_progress=FALSE)
$R_squared
[1] -1.809339

$R_squared_adj
[1] 12.23735

$Residuals
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
 -1.168   2.513   5.541   5.058   8.185   8.703

$Coverage_rate
[1] 100

$ttests
         estimate       lower        upper      p-value signif
cyl   111.6701473   17.076928  206.2633669 2.615518e-02      *
disp   -1.7983224   -3.876380    0.2797349 8.106884e-02      .
hp     -0.4167545   -1.501658    0.6681495 4.015523e-01
drat  569.9102780  148.862037  990.9585186 1.420088e-02      *
wt   -504.1496696 -583.757006 -424.5423330 4.741273e-07    ***
qsec -107.9102921 -138.571336  -77.2492482 3.936777e-05    ***
vs    145.0280002 -173.164419  463.2204193 3.239468e-01
am   -319.6910568 -566.618653  -72.7634604 1.745263e-02      *
gear   57.7630332  -18.934712  134.4607782 1.206459e-01
carb  -42.9572292 -108.690903   22.7764447 1.702409e-01

$effects
── Data Summary ────────────────────────
                           Values
Name                       effects
Number of rows             9
Number of columns          10
_______________________
Column type frequency:
  numeric                  10
________________________
Group variables            None

── Variable type: numeric ──────────────────────────────────────────────────────
   skim_variable     mean     sd       p0      p25      p50      p75     p100 hist
 1 cyl            112.    123.    -13.5     64.5     93.6     93.9    426.    ▅▇▁▁▂
 2 disp            -1.80    2.70   -8.94    -1.41    -0.805   -0.689   -0.361 ▁▁▁▁▇
 3 hp              -0.417   1.41   -3.54    -0.679   -0.0942  -0.0556   1.19  ▂▁▂▇▃
 4 drat           570.    548.     36.8    371.     439.     501.    1972.    ▅▇▁▁▂
 5 wt            -504.    104.   -742.    -523.    -497.    -461.    -382.    ▂▁▂▇▃
 6 qsec          -108.     39.9  -152.    -143.    -115.     -93.0    -35.9   ▇▅▅▂▂
 7 vs             145.    414.   -116.     -23.9     51.1     81.2   1231.    ▇▁▁▁▁
 8 am            -320.    321.   -575.    -479.    -395.    -368.     465.    ▇▁▁▁▁
 9 gear            57.8    99.8  -113.       1.22    35.2    130.     196.    ▃▇▇▇▇
10 carb           -43.0    85.5  -129.     -79.6    -77.9    -22.5    165.    ▇▅▁▁▂
res <- obj$predict(X = new_X_test)
new_y_train <- c(y_train, newy)
plot(c(new_y_train, res$preds), type='l',
     main="",
     ylab="",
     ylim = c(min(c(res$upper, res$lower, y)),
              max(c(res$upper, res$lower, y))))
lines(c(new_y_train, res$upper), col="gray60")
lines(c(new_y_train, res$lower), col="gray60")
lines(c(new_y_train, res$preds), col = "red")
lines(c(new_y_train, new_y_test), col = "blue")
abline(v = length(y_train), lty=2, col="black")
[1] 1
update RVFL model (Pt.2)
newx <- X_test[2, ]
newy <- y_test[2]
new_X_test <- X_test[-c(1, 2), ]
new_y_test <- y_test[-c(1, 2)]
t0 <- proc.time()[3]
obj$update(newx, newy)
cat("Elapsed: ", proc.time()[3] - t0, "s \n")
Elapsed: 0.077 s n
obj$summary(new_X_test, y=new_y_test, show_progress=FALSE)
$R_squared
[1] -3.356623

$R_squared_adj
[1] 11.16545

$Residuals
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.
 -1.950   5.030   6.374   6.369   8.774  11.528

$Coverage_rate
[1] 75

$ttests
         estimate       lower        upper      p-value signif
cyl    40.8981137    6.878148   74.9180798 2.494779e-02      *
disp   -0.7335494   -1.206939   -0.2601595 8.026181e-03     **
hp     -0.8233606   -2.198927    0.5522055 1.998737e-01
drat  549.7206897  416.053783  683.3875968 2.570765e-05    ***
wt   -469.9351032 -535.877454 -403.9927527 6.344763e-07    ***
qsec -116.6183871 -156.767393  -76.4693814 2.380078e-04    ***
vs   -194.4213942 -288.046178 -100.7966103 1.732503e-03     **
am   -395.7216847 -562.762331 -228.6810387 8.143911e-04    ***
gear   53.0732573  -59.833653  165.9801679 3.030574e-01
carb  -25.9448064  -63.759959   11.8703467 1.487567e-01

$effects
── Data Summary ────────────────────────
                           Values
Name                       effects
Number of rows             8
Number of columns          10
_______________________
Column type frequency:
  numeric                  10
________________________
Group variables            None

── Variable type: numeric ──────────────────────────────────────────────────────
   skim_variable     mean     sd       p0      p25      p50      p75     p100 hist
 1 cyl             40.9    40.7   -40.5     23.9     56.3     69.9     77.8   ▂▂▂▁▇
 2 disp            -0.734   0.566  -1.64    -1.03    -0.571   -0.372   -0.139 ▅▁▂▇▅
 3 hp              -0.823   1.65   -3.99    -1.18    -0.974   -0.196    1.25  ▂▁▇▂▃
 4 drat           550.    160.    170.     549.     606.     642.     643.    ▁▁▁▁▇
 5 wt            -470.     78.9  -543.    -537.    -489.    -437.    -336.    ▇▅▂▁▅
 6 qsec          -117.     48.0  -179.    -143.    -131.     -99.1    -29.9   ▂▇▂▂▂
 7 vs            -194.    112.   -377.    -283.    -162.    -120.     -46.3   ▂▅▂▇▂
 8 am            -396.    200.   -719.    -481.    -357.    -319.     -67.7   ▃▁▇▂▂
 9 gear            53.1   135.   -143.     -23.9     16.5    172.     231.    ▂▅▅▁▇
10 carb           -25.9    45.2  -101.     -48.8    -23.8     -9.36    45.7   ▂▅▇▁▅
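The R_squared_adj values printed in the three summaries (23.55271, 12.23735 and 11.16545) are far above 1, which looks odd at first. A plausible explanation is that the test sets have no more rows (n = 10, 9, 8) than there are features (p = 10), so the denominator n - p - 1 in the usual formula 1 - (1 - R²)(n - 1)/(n - p - 1) is negative. The short check below reproduces the printed values from the printed R_squared values under that assumed formula; the function name is ours.

```python
def adjusted_r2(r2, n, p):
    """Usual adjusted R-squared; only meaningful when n > p + 1."""
    return 1.0 - (1.0 - r2) * (n - 1) / (n - p - 1)

# (printed R_squared, number of test rows, printed R_squared_adj)
printed = [(-1.505856, 10, 23.55271),
           (-1.809339, 9, 12.23735),
           (-3.356623, 8, 11.16545)]

# p = 10 features (all mtcars columns except mpg)
recomputed = [adjusted_r2(r2, n, 10) for r2, n, _ in printed]
max_gap = max(abs(a - row[2]) for a, row in zip(recomputed, printed))
```

With so few test rows, the adjusted value is not interpretable here; the plain R_squared and the coverage rate are the more informative numbers.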
res <- obj$predict(X = new_X_test)
new_y_train <- c(y_train, y_test[c(1, 2)])
plot(c(new_y_train, res$preds), type='l',
     main="",
     ylab="",
     ylim = c(min(c(res$upper, res$lower, y)),
              max(c(res$upper, res$lower, y))))
lines(c(new_y_train, res$upper), col="gray60")
lines(c(new_y_train, res$lower), col="gray60")
lines(c(new_y_train, res$preds), col = "red")
lines(c(new_y_train, new_y_test), col = "blue")
abline(v = length(y_train), lty=2, col="black")
[1] 0.75