KNN Demonstration
library(tidyverse)   # dplyr, ggplot2, purrr, ...
library(tidymodels)  # rsample, parsnip, recipes, tune, yardstick, dials
library(janitor)     # clean_names()
library(skimr)       # skim()

patients <- read.csv("breast-cancer.csv") %>% clean_names() %>% mutate(class = factor(class))
glimpse(patients)
## Rows: 683
## Columns: 11
## $ id <int> 1000025, 1002945, 1015425, 1016277, 101702…
## $ clump_thickness <int> 5, 5, 3, 6, 4, 8, 1, 2, 2, 4, 1, 2, 5, 1, …
## $ uniformity_of_cell_size <int> 1, 4, 1, 8, 1, 10, 1, 1, 1, 2, 1, 1, 3, 1,…
## $ uniformity_of_cell_shape <int> 1, 4, 1, 8, 1, 10, 1, 2, 1, 1, 1, 1, 3, 1,…
## $ marginal_adhesion <int> 1, 5, 1, 1, 3, 8, 1, 1, 1, 1, 1, 1, 3, 1, …
## $ single_epithelial_cell_size <int> 2, 7, 2, 3, 2, 7, 2, 2, 2, 2, 1, 2, 2, 2, …
## $ bare_nuclei <int> 1, 10, 2, 4, 1, 10, 10, 1, 1, 1, 1, 1, 3, …
## $ bland_chromatin <int> 3, 3, 3, 3, 3, 9, 3, 3, 1, 2, 3, 2, 4, 3, …
## $ normal_nucleoli <int> 1, 2, 1, 7, 1, 7, 1, 1, 1, 1, 1, 1, 4, 1, …
## $ mitoses <int> 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 1, 1, 1, 1, …
## $ class <fct> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, …
patients %>% count(class)
## # A tibble: 2 × 2
## class n
## <fct> <int>
## 1 0 444
## 2 1 239
patients %>% skim()

Name | patients |
Number of rows | 683 |
Number of columns | 11 |
Column type frequency: | |
factor | 1 |
numeric | 10 |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
class | 0 | 1 | FALSE | 2 | 0: 444, 1: 239 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
id | 0 | 1 | 1076720.23 | 620644.05 | 63375 | 877617 | 1171795 | 1238705 | 13454352 | ▇▁▁▁▁ |
clump_thickness | 0 | 1 | 4.44 | 2.82 | 1 | 2 | 4 | 6 | 10 | ▇▇▇▃▃ |
uniformity_of_cell_size | 0 | 1 | 3.15 | 3.07 | 1 | 1 | 1 | 5 | 10 | ▇▂▁▁▂ |
uniformity_of_cell_shape | 0 | 1 | 3.22 | 2.99 | 1 | 1 | 1 | 5 | 10 | ▇▂▁▁▁ |
marginal_adhesion | 0 | 1 | 2.83 | 2.86 | 1 | 1 | 1 | 4 | 10 | ▇▂▁▁▁ |
single_epithelial_cell_size | 0 | 1 | 3.23 | 2.22 | 1 | 2 | 2 | 4 | 10 | ▇▂▂▁▁ |
bare_nuclei | 0 | 1 | 3.54 | 3.64 | 1 | 1 | 1 | 6 | 10 | ▇▁▁▁▂ |
bland_chromatin | 0 | 1 | 3.45 | 2.45 | 1 | 2 | 3 | 5 | 10 | ▇▅▁▂▁ |
normal_nucleoli | 0 | 1 | 2.87 | 3.05 | 1 | 1 | 1 | 4 | 10 | ▇▁▁▁▁ |
mitoses | 0 | 1 | 1.60 | 1.73 | 1 | 1 | 1 | 1 | 10 | ▇▁▁▁▁ |
ggplot(data=patients, aes(x=bland_chromatin, y=single_epithelial_cell_size, color=class)) +
geom_point(position="jitter")
These two quantities look nicely separated and could be useful for prediction! Let’s use them to build a k-nearest neighbors (KNN) model.
What does KNN do?
Suppose we are diagnosing a new patient, and we get readings on bland_chromatin and single_epithelial_cell_size, say 3 and 5, respectively.
# 80/20 train/test split, stratified on class so both sets keep the same
# benign/malignant mix (consider set.seed() first for reproducibility)
patients_split <- initial_split(patients, prop = 0.80, strata = class)
patients_train <- training(patients_split)
patients_test <- testing(patients_split)
patients_train %>%
mutate(
dist = sqrt((bland_chromatin-3)^2+(single_epithelial_cell_size-5)^2)
) %>%
slice_min(dist, n=5, with_ties=TRUE) # There are a bunch of ties!
## id clump_thickness uniformity_of_cell_size uniformity_of_cell_shape
## 1 242970 5 7 7
## 2 718641 1 1 1
## 3 1116132 6 3 4
## 4 1171845 8 6 4
## 5 832226 3 4 4
## marginal_adhesion single_epithelial_cell_size bare_nuclei bland_chromatin
## 1 1 5 8 3
## 2 1 5 1 3
## 3 1 5 2 3
## 4 3 5 9 3
## 5 10 5 1 3
## normal_nucleoli mitoses class dist
## 1 4 1 0 0
## 2 1 1 0 0
## 3 9 1 1 0
## 4 1 1 1 0
## 5 3 1 1 0
What class should we predict here? What if we change the number of neighbors? What if we change the point?
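With k = 5 and an unweighted (“rectangular”) vote, the prediction is simply the most common class among the neighbors. Reading the classes 0, 0, 1, 1, 1 off the table above, a minimal base-R sketch of that vote:

# classes of the 5 nearest neighbors, read off the table above
neighbor_classes <- c(0, 0, 1, 1, 1)
# unweighted majority vote: the most frequent class wins
names(which.max(table(neighbor_classes)))  # "1", i.e. predict the malignant class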
Building a KNN Model
Specify the model:
knn_model <- nearest_neighbor(weight_func = "rectangular", neighbors = 3) %>%
set_engine("kknn") %>%
set_mode("classification")
Fit the model to the training data. You will need to install the kknn package.
knn_fit <- knn_model %>%
fit(class ~ bland_chromatin + single_epithelial_cell_size, data = patients_train)
Evaluate the model on the test set (recall that we may also use augment here):
patients_pred <- augment(knn_fit, new_data = patients_test)

# Some performance metrics for classification
conf_mat(patients_pred, truth = class, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 85 9
## 1 4 39
my_metrics <- metric_set(sens, spec, accuracy)
my_metrics(patients_pred, truth = class, estimate = .pred_class)
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.955
## 2 spec binary 0.812
## 3 accuracy binary 0.905
# make a column that shows whether or not patient was misclassified
patients_pred <- patients_pred %>% mutate(misclassified = if_else(class != .pred_class, TRUE, FALSE))

# where are the misclassified patients?
ggplot(data=patients_pred, aes(x=bland_chromatin, y=single_epithelial_cell_size)) +
  geom_point(
    data=patients_train,
    aes(x=bland_chromatin, y=single_epithelial_cell_size, shape=class),
    position="jitter") +
  geom_point(
    aes(color=misclassified, shape=class),
    position="jitter")
head(patients_pred)
## # A tibble: 6 × 15
## .pred_class .pred_0 .pred_1 id clump_thickness uniformity_of_cell_size
## <fct> <dbl> <dbl> <int> <int> <int>
## 1 1 0.333 0.667 1002945 5 4
## 2 0 0.667 0.333 1041801 5 3
## 3 0 0.667 0.333 1044572 8 7
## 4 0 1 0 1108370 9 5
## 5 0 1 0 1112209 8 10
## 6 0 1 0 1169049 7 3
## # ℹ 9 more variables: uniformity_of_cell_shape <int>, marginal_adhesion <int>,
## # single_epithelial_cell_size <int>, bare_nuclei <int>,
## # bland_chromatin <int>, normal_nucleoli <int>, mitoses <int>, class <fct>,
## # misclassified <lgl>
How many neighbors should we use?
knn_fit <- knn_model %>%
# change this to whatever you want and compare with your neighbors
set_args(neighbors=5) %>%
# Can we try fitting on all the predictors?
fit(class ~ . -id, data = patients_train)
patients_pred <- augment(knn_fit, new_data=patients_test)
#conf_mat(patients_pred, truth = class, estimate = .pred_class)
my_metrics(patients_pred, truth = class, estimate = .pred_class)
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.978
## 2 spec binary 0.979
## 3 accuracy binary 0.978
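How do we know 5 is a good choice? One rough way to compare a few values is to refit and score each one; here is a sketch using the objects defined above. (Evaluating repeatedly on the same test set leaks information, so the cross-validation tools below are the better way to do this.)

# Sketch only: test-set accuracy for several values of k
purrr::map_dfr(c(1, 3, 5, 7, 9), function(k) {
  knn_model %>%
    set_args(neighbors = k) %>%
    fit(class ~ . - id, data = patients_train) %>%
    augment(new_data = patients_test) %>%
    accuracy(truth = class, estimate = .pred_class) %>%
    mutate(neighbors = k)
})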
Weighted KNN
Should all the nearest neighbors get an equal vote? Shouldn’t the closest neighbors have a bigger say? This information can be incorporated using weight functions, which give more weight to closer observations and less to ones further away. See this paper for more details.
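For intuition, here is a toy sketch of kernel weighting (illustrative only, not kknn's exact internals): a kernel turns each neighbor's distance into a weight, and the weights decide each class's share of the vote.

# hypothetical distances and classes for 3 neighbors
dists   <- c(0.2, 0.5, 1.4)
classes <- factor(c(1, 1, 0))
# gaussian kernel: weight decreases as distance grows
w <- exp(-dists^2 / 2)
# weighted vote share for each class; the larger share wins
tapply(w, classes, sum) / sum(w)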
knn_recipe <- recipe(class ~ ., data=patients_train) %>%
update_role(id, new_role="id") %>%
step_normalize(all_numeric_predictors())
knn_model <- nearest_neighbor(neighbors=5, weight_func="gaussian") %>% # What are the default arguments?
set_engine("kknn") %>%
set_mode("classification")
knn_wf <- workflow() %>%
add_recipe(knn_recipe) %>%
add_model(knn_model)
knn_fit <- knn_wf %>% fit(data = patients_train)
patients_pred <- augment(knn_fit, new_data=patients_test)
conf_mat(patients_pred, truth = class, estimate = .pred_class)
## Truth
## Prediction 0 1
## 0 87 1
## 1 2 47
my_metrics(patients_pred, truth = class, estimate = .pred_class)
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.978
## 2 spec binary 0.979
## 3 accuracy binary 0.978
# make a column that shows whether or not patient was misclassified
patients_pred <- patients_pred %>% mutate(misclassified = if_else(class != .pred_class, TRUE, FALSE))
# where are the misclassified patients?
ggplot(data=patients_pred, aes(x=bland_chromatin, y=single_epithelial_cell_size))+
geom_point(
data=patients_train,
aes(x=bland_chromatin, y=single_epithelial_cell_size, shape=class),
position="jitter") +
geom_point(
aes(color=misclassified, shape=class),
position="jitter")
Resampling
V-fold Cross-validation
The function vfold_cv is used to create the cross-validation folds. What are the default arguments?
patients_folds <- vfold_cv(patients_train)  # defaults: v = 10, no repeats
Repeated V-fold Cross-validation
We can also set up the folds to repeat many times. How many estimates do we have now?
patients_folds <- vfold_cv(patients_train, v = 10, repeats = 5)  # 50 resamples in total
Fitting the resamples
Now our workflow object can fit the model to each analysis set and compute metrics on each assessment set. You can also pass fit_resamples a metric set; otherwise it will choose automatically.
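For example, we could ask for the same metrics we used earlier (a sketch; in the run below we just accept the defaults):

# Sketch: pass an explicit metric set to fit_resamples
knn_wf %>% fit_resamples(
  resamples = patients_folds,
  metrics = metric_set(sens, spec, accuracy)
)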
# This might take a while!
patients_res <- knn_wf %>% fit_resamples(resamples = patients_folds)
patients_res
## # Resampling results
## # 10-fold cross-validation repeated 5 times
## # A tibble: 50 × 5
## splits id id2 .metrics .notes
## <list> <chr> <chr> <list> <list>
## 1 <split [491/55]> Repeat1 Fold01 <tibble [3 × 4]> <tibble [0 × 3]>
## 2 <split [491/55]> Repeat1 Fold02 <tibble [3 × 4]> <tibble [0 × 3]>
## 3 <split [491/55]> Repeat1 Fold03 <tibble [3 × 4]> <tibble [0 × 3]>
## 4 <split [491/55]> Repeat1 Fold04 <tibble [3 × 4]> <tibble [0 × 3]>
## 5 <split [491/55]> Repeat1 Fold05 <tibble [3 × 4]> <tibble [0 × 3]>
## 6 <split [491/55]> Repeat1 Fold06 <tibble [3 × 4]> <tibble [0 × 3]>
## 7 <split [492/54]> Repeat1 Fold07 <tibble [3 × 4]> <tibble [0 × 3]>
## 8 <split [492/54]> Repeat1 Fold08 <tibble [3 × 4]> <tibble [0 × 3]>
## 9 <split [492/54]> Repeat1 Fold09 <tibble [3 × 4]> <tibble [0 × 3]>
## 10 <split [492/54]> Repeat1 Fold10 <tibble [3 × 4]> <tibble [0 × 3]>
## # ℹ 40 more rows
patients_res %>% collect_metrics()
## # A tibble: 3 × 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.967 50 0.00392 Preprocessor1_Model1
## 2 brier_class binary 0.0267 50 0.00253 Preprocessor1_Model1
## 3 roc_auc binary 0.989 50 0.00217 Preprocessor1_Model1
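The mean column above averages over all 50 resamples (10 folds × 5 repeats). To see the individual per-resample estimates instead, collect_metrics accepts summarize = FALSE:

# one row per metric per resample (3 metrics × 50 resamples here)
patients_res %>% collect_metrics(summarize = FALSE)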
Tuning hyperparameters
Parameters that we want to mark for tuning can take the tune() tag as an argument. (Do you remember what dist_power is?)
knn_model <- nearest_neighbor(neighbors = tune(), dist_power = tune()) %>%
set_engine("kknn") %>%
set_mode("classification")
knn_wf <- workflow() %>%
add_recipe(knn_recipe) %>%
add_model(knn_model)
We can take a look at the default parameter sets below:
knn_param <- knn_wf %>% extract_parameter_set_dials()
knn_param %>% extract_parameter_dials("neighbors")
## # Nearest Neighbors (quantitative)
## Range: [1, 15]
knn_param %>% extract_parameter_dials("dist_power")
## Minkowski Distance Order (quantitative)
## Range: [0.1, 2]
In our example, we said we were interested in dist powers between 1 and 2, and in 3, 5, 7, or 9 nearest neighbors. Here are two ways to create a regular grid like this:
# Way 1: Using the crossing function
# Creates a lot of combinations
my_grid <- crossing(
neighbors = c(3, 5, 7, 9),
dist_power = c(1, 1.25, 1.5, 1.75, 2)
)
ggplot(my_grid, aes(neighbors, dist_power)) + geom_point()
# Way 2: Using the grid_regular function
knn_param <- knn_param %>% update(
neighbors = neighbors(c(3,9)),
dist_power = dist_power(c(1,2))
)
my_grid <- grid_regular(
# the parameter set:
knn_param,
# how many divisions to make at each level. Could also be an integer to give the same number of levels for each parameter
levels = c(neighbors=4, dist_power=5)
)
ggplot(my_grid, aes(neighbors, dist_power)) + geom_point()
We use the tune_grid function in a similar way to how we used fit_resamples.
# Will take a while!
knn_tune <- knn_wf %>%
tune_grid(
patients_folds, # the CV set
grid=my_grid, # the grid of parameter combinations to try
metrics=metric_set(accuracy) # the metrics you'd like to compute
)
The book lists a bunch of ways to take a look at the different parameter combinations using plots! Here we’ll just use the show_best() function to get a look at the best-performing combinations. “Best” is quantified according to the specified metric.
knn_tune %>% show_best(n = 20)
## Warning in show_best(., n = 20): No value of `metric` was given; "accuracy"
## will be used.
## # A tibble: 20 × 8
## neighbors dist_power .metric .estimator mean n std_err .config
## <int> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 9 1 accuracy binary 0.971 50 0.00331 Preprocessor1_M…
## 2 7 1 accuracy binary 0.971 50 0.00348 Preprocessor1_M…
## 3 5 1 accuracy binary 0.970 50 0.00346 Preprocessor1_M…
## 4 5 1.5 accuracy binary 0.968 50 0.00351 Preprocessor1_M…
## 5 9 1.25 accuracy binary 0.968 50 0.00373 Preprocessor1_M…
## 6 5 1.75 accuracy binary 0.967 50 0.00383 Preprocessor1_M…
## 7 7 1.25 accuracy binary 0.967 50 0.00390 Preprocessor1_M…
## 8 7 1.75 accuracy binary 0.967 50 0.00386 Preprocessor1_M…
## 9 7 1.5 accuracy binary 0.967 50 0.00382 Preprocessor1_M…
## 10 9 1.75 accuracy binary 0.967 50 0.00385 Preprocessor1_M…
## 11 5 1.25 accuracy binary 0.967 50 0.00385 Preprocessor1_M…
## 12 3 1 accuracy binary 0.966 50 0.00338 Preprocessor1_M…
## 13 5 2 accuracy binary 0.966 50 0.00386 Preprocessor1_M…
## 14 7 2 accuracy binary 0.966 50 0.00386 Preprocessor1_M…
## 15 9 2 accuracy binary 0.966 50 0.00386 Preprocessor1_M…
## 16 9 1.5 accuracy binary 0.966 50 0.00391 Preprocessor1_M…
## 17 3 1.25 accuracy binary 0.961 50 0.00389 Preprocessor1_M…
## 18 3 1.75 accuracy binary 0.959 50 0.00419 Preprocessor1_M…
## 19 3 1.5 accuracy binary 0.958 50 0.00411 Preprocessor1_M…
## 20 3 2 accuracy binary 0.957 50 0.00452 Preprocessor1_M…
Now that we have a best model, we can update our workflow and fit it to the training set. We can use select_* to get the best-performing parameter set:
select_best(knn_tune, metric = "accuracy")
## # A tibble: 1 × 3
## neighbors dist_power .config
## <int> <dbl> <chr>
## 1 9 1 Preprocessor1_Model04
knn_wf <- knn_wf %>%
finalize_workflow(
parameters = select_best(knn_tune, metric="accuracy")
# could also set the parameters "by hand" here, e.g.
# parameters = list(neighbors = 9, dist_power = 1)
)
final_results <- knn_wf %>%
fit(patients_train) %>%
augment(new_data=patients_test)
# event_level="second": treat the second factor level (1, malignant) as the event
my_metrics(final_results, truth=class, estimate=.pred_class, event_level="second")
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.979
## 2 spec binary 0.978
## 3 accuracy binary 0.978
Should we always choose the best performing model?
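Not necessarily! A common alternative is the simplest model whose performance is within one standard error of the best, which tune supports via select_by_one_std_err(). A sketch, sorting by desc(neighbors) because larger neighborhoods give smoother, simpler decision boundaries:

# Sketch: simplest model within one standard error of the best accuracy
select_by_one_std_err(knn_tune, desc(neighbors), metric = "accuracy")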
Without specifying a grid:
# Recreate the workflow, since we finalized it above
# once it is "finalized" there are no parameters to tune
knn_wf <- workflow() %>%
add_recipe(knn_recipe) %>%
add_model(knn_model)
knn_tune <- knn_wf %>%
tune_grid(
patients_folds, # the CV set
grid=10, # the number of parameter combinations
metrics=metric_set(accuracy) # the metrics you'd like to compute
)
knn_tune %>% show_best(n = 10)
## Warning in show_best(., n = 10): No value of `metric` was given; "accuracy"
## will be used.
Specifying other types of grids
Here are some other types of grids. Maximum entropy is the default for tune_grid, I believe.
my_grid <- grid_latin_hypercube(
# the parameter set:
knn_param,
# how many points in the grid
size = 13
)
## Warning: `grid_latin_hypercube()` was deprecated in dials 1.3.0.
## ℹ Please use `grid_space_filling()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
my_grid <- grid_max_entropy(
# the parameter set:
knn_param,
# how many points in the grid
size = 13
)
## Warning: `grid_max_entropy()` was deprecated in dials 1.3.0.
## ℹ Please use `grid_space_filling()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
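As the warnings suggest, newer versions of dials fold both of these into grid_space_filling(); a sketch:

# Space-filling designs replace the deprecated functions above
my_grid <- grid_space_filling(
  knn_param,
  size = 13
)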