K Means Clustering

Intro to Clustering

To get a feel for clustering, we'll create some artificial data and cluster it. The code below generates three distinct clusters in a two-variable space, which should be fairly easy for k-means to identify.

library(tidyverse)   # tibble, purrr, dplyr, ggplot2
library(tidymodels)  # recipes, workflows, broom
library(janitor)     # clean_names(), used in the NBA example below

set.seed(465)

centers <- tibble(
  cluster = factor(1:3), 
  num_points = c(100, 150, 50),  # number of points in each cluster
  x1 = c(5, 0, -3),              # x1 coordinate of cluster center
  x2 = c(-1, 1, -2)              # x2 coordinate of cluster center
)

labelled_points <- 
  centers %>%
  mutate(
    x1 = map2(num_points, x1, rnorm),
    x2 = map2(num_points, x2, rnorm)
  ) %>% 
  select(-num_points) %>% 
  unnest(cols = c(x1, x2))

ggplot(labelled_points, aes(x1, x2, color = cluster)) +
  geom_point(alpha = 0.3)

K Means using kmeans

The k-means model specification is in the tidyclust library. To specify a k-means model in tidymodels, simply choose a value of num_clusters:

library(tidyclust)

kmeans_spec <- k_means(num_clusters = 3)
kmeans_spec
## K Means Cluster Specification (partition)
## 
## Main Arguments:
##   num_clusters = 3
## 
## Computational engine: stats
# note that you don't need to provide the outcome variable, because there isn't one!
kmeans_rec <- recipe(~., data = labelled_points) %>%
  # we don't want the true cluster label as a predictor,
  # but we'll use it later, so just change its role
  update_role(cluster, new_role = "label") %>%
  # k-means uses distances, so we'll normalize the predictors
  step_normalize(all_numeric())

kmeans_wf <- workflow() %>%
  add_model(kmeans_spec) %>%
  add_recipe(kmeans_rec)

set.seed(465)
kmeans_fit <- kmeans_wf %>% fit(data=labelled_points)
kmeans_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: k_means()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 1 Recipe Step
## 
## • step_normalize()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## K-means clustering with 3 clusters of sizes 57, 144, 99
## 
## Cluster means:
##           x1         x2
## 1 -1.2634211 -1.1237364
## 2 -0.3463178  0.8056294
## 3  1.2311592 -0.5248248
## 
## Clustering vector:
##   [1] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
##  [38] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 3 3 3 3 3 3 3 3 3 3 3
##  [75] 3 3 3 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 3 3 3 3 3 3 3 2 2 2 2 2 1 3 2 2 2 2
## [112] 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
## [149] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1
## [186] 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2
## [223] 2 2 2 1 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1
## [260] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1
## [297] 1 1 1 1
## 
## Within cluster sum of squares by cluster:
## [1] 33.81562 63.49849 49.66139
##  (between_SS / total_SS =  75.4 %)
## 
## Available components:
## 
## [1] "cluster"      "centers"      "totss"        "withinss"     "tot.withinss"
## [6] "betweenss"    "size"         "iter"         "ifault"

The raw model output is hard to read, but it gives the center of each cluster and the within-cluster sum of squares (the quantity k-means tries to minimize). To get a nicer view, we can use tidy() and glance():

tidy(kmeans_fit)
## # A tibble: 3 × 5
##       x1     x2  size withinss cluster
##    <dbl>  <dbl> <int>    <dbl> <fct>  
## 1 -1.26  -1.12     57     33.8 1      
## 2 -0.346  0.806   144     63.5 2      
## 3  1.23  -0.525    99     49.7 3
glance(kmeans_fit)
## # A tibble: 1 × 4
##   totss tot.withinss betweenss  iter
##   <dbl>        <dbl>     <dbl> <int>
## 1   598         147.      451.     3
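
tidyclust also provides extractor functions that return the same pieces directly. A minimal sketch (assuming a recent version of tidyclust, which exports both of these):

# centroid of each cluster, one row per cluster
extract_centroids(kmeans_fit)
# cluster assignment for each training observation
extract_cluster_assignment(kmeans_fit)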

To get the predicted labels, we can use augment(). I'll also plot them to see how well the algorithm did.

# Try changing the number of clusters to see what happens!
# kmeans_fit <- workflow() %>%
#   add_model(
#     k_means(num_clusters = 5)
#     ) %>%
#   add_recipe(kmeans_rec) %>% 
#   fit(data=labelled_points)
clustered_points <- kmeans_fit %>% augment(labelled_points)
clustered_points
## # A tibble: 300 × 4
##    cluster    x1      x2 .pred_cluster
##    <fct>   <dbl>   <dbl> <fct>        
##  1 1        6.12  1.11   Cluster_1    
##  2 1        6.73 -2.12   Cluster_1    
##  3 1        7.15 -2.46   Cluster_1    
##  4 1        3.62 -0.366  Cluster_1    
##  5 1        6.23 -2.09   Cluster_1    
##  6 1        5.32  0.149  Cluster_1    
##  7 1        4.22 -1.02   Cluster_1    
##  8 1        4.61  0.0801 Cluster_1    
##  9 1        4.71 -1.98   Cluster_1    
## 10 1        5.25 -2.86   Cluster_1    
## # … with 290 more rows
plot <- clustered_points %>% 
  ggplot(aes(x1, x2)) + 
  geom_point(aes(color=.pred_cluster, shape=cluster), alpha=.5)
plot
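
Since k-means numbers its clusters arbitrarily, the predicted labels won't necessarily match the labels we generated. A quick cross-tabulation shows how the assignments line up with the true clusters:

clustered_points %>% count(cluster, .pred_cluster)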

Choosing the number of clusters

We generated simulated data that had 3 pre-defined clusters. But if we didn't know how many clusters there were, how could we choose? Typically, we run the clustering for many different values of K, plot the total within-group sum of squares (SSW) for each, and choose using the "elbow method": we look for where the plot of SSW flattens out, i.e., the elbow.

set.seed(465)
ssw <- c()
# perform clustering for many different k's:
for (k in 1:8) {
  kmeans_fit <- workflow() %>%
    add_model(
    k_means(num_clusters = k)
    ) %>%
    add_recipe(kmeans_rec) %>% 
    fit(data=labelled_points)
  ssw <- c(ssw, glance(kmeans_fit)$tot.withinss) # get SSW
} 

ggplot(data=NULL, aes(y=ssw, x=1:8)) + 
  geom_line() + geom_point() +
  labs(x="K: number of clusters", y="Total Within Group SS")

As k changes from 1 to 2 and 2 to 3, the sum of squares decreases a lot. It decreases much less after that, and hence flattens out after about 3 or 4 clusters. This is not an exact science. Remember, we typically won’t know how many clusters there are supposed to be!
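
As an aside, the accumulate-in-a-loop pattern above can also be written functionally with purrr; a sketch of the same computation:

set.seed(465)
ssw <- map_dbl(1:8, function(k) {
  workflow() %>%
    add_model(k_means(num_clusters = k)) %>%
    add_recipe(kmeans_rec) %>%
    fit(data = labelled_points) %>%
    glance() %>%
    pull(tot.withinss)  # total SSW for this k
})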

NBA Example Background and EDA

Basketball teams traditionally assign players to 5 positions: Center, Power Forward, Small Forward, Shooting Guard, and Point Guard. But anyone who watches basketball today knows that these traditional positions don't mean much anymore. Here we'll look at the stats of different players from the most recent season and see whether we can group players by their role on the team.

This data is from Basketball Reference. All stats are per 36 minutes played, a standard rate adjustment (an NBA game is 48 minutes, so 36 minutes is roughly a full-time starter's workload).
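
For reference, a per-36 stat is just a player's total rescaled to a 36-minute rate. A hypothetical sketch (raw_totals, pts_total, and mp are stand-ins for a raw totals table; the file we load below is already per-36):

# hypothetical: convert a season total to a per-36-minute rate
raw_totals %>%
  mutate(pts_per36 = pts_total / mp * 36)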

# I need to clean the names since variable names can't start with a number. 
# clean_names adds an x before those
nba <- read_csv("nba22-23.csv") %>% clean_names()
## Rows: 539 Columns: 30
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (4): Player, Pos, Tm, Player-additional
## dbl (26): Rk, Age, G, GS, MP, FG, FGA, FG%, 3P, 3PA, 3P%, 2P, 2PA, 2P%, FT, ...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

I want to exclude players who didn't play much this season, as their stats might skew the results. Looking at the histogram of minutes played, I can see that there is a big chunk of players who played 250 minutes or less. I'll arbitrarily choose 360 minutes (ten of the 36-minute units the stats are scaled to) as my cutoff. We should have 393 players remaining.

nba %>% ggplot(aes(x=mp)) + geom_histogram(bins=20)

nba <- nba %>% filter(mp >= 360)

Let’s see how the traditional positions are distributed:

nba %>% ggplot(aes(pos)) + geom_bar()

It looks more or less even. Notice that some players are listed with more than one position. Normally I would clean that up (one option is sketched below), but since we aren't going to use this information, I won't worry about it for now.
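
If we did want to clean this up, one option is to keep only the first listed position; a minimal sketch using stringr (loaded with the tidyverse):

# keep only the primary (first-listed) position, e.g. "PF-SF" -> "PF"
nba %>% mutate(pos = str_extract(pos, "^[^-]+"))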

Traditionally, a point guard is a small player that will carry the ball up the court, get a lot of assists, steals, and potentially shoot three point shots. Let’s see how that plays out here:

nba %>% ggplot(aes(ast, stl, color=pos)) + geom_point(position="jitter")

A center, on the other hand, is typically a tall player that will rebound, block shots, score close to the basket and take a lot of free-throws.

nba %>% ggplot(aes(blk, trb, color=pos)) + geom_point(position="jitter")

What about other stats? This one is a mess—there is very little separation between the traditional positions.

# remember the x is there because variable names can't begin with numbers in R
nba %>% ggplot(aes(x3pa, stl, color=pos)) + geom_point(position="jitter")

Clustering based on Offensive Stats

Let’s see if we can cluster players that have similar offensive statistics. We could add in other stats here, or consider defensive stats to determine overall play styles, but this is meant to be a simple example.

set.seed(465) # for reproducibility
nba_off <- nba %>% select(player, pos, pts, ast, tov, orb, ft, x2p, x3p) 

First, let’s see if there is a big difference in these stats based on traditional position:

nba_off %>% 
  select(-player) %>%
  group_by(pos) %>%
  summarize(across(everything(), mean)) %>%
  filter(pos %in% c("C", "PF", "SF", "SG", "PG"))
## # A tibble: 5 × 8
##   pos     pts   ast   tov   orb    ft   x2p   x3p
##   <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 C      16.1  2.54  2.00 3.48   2.84  5.64 0.678
## 2 PF     15.3  2.73  1.76 1.89   2.30  4.07 1.64 
## 3 PG     16.7  6.24  2.38 0.837  2.71  3.95 2.01 
## 4 SF     15.5  2.77  1.66 1.37   2.27  3.62 1.98 
## 5 SG     16.6  3.51  1.90 0.893  2.35  3.64 2.33

There are some differences, but many of the categories are similar, e.g., scoring seems to be spread evenly between the traditional positions.

Next, we’ll perform the k-means clustering. Notice again that we need to scale the variables.

set.seed(465)
nba_rec <- recipe(~., data=nba_off) %>%
  update_role(player, new_role="player") %>%
  update_role(pos, new_role="position") %>%
  step_normalize(all_numeric())

nba_spec <- k_means(num_clusters=6)

nba_wf <- workflow() %>%
  add_model(nba_spec) %>%
  add_recipe(nba_rec)

nba_clusters <- nba_wf %>%
  fit(data=nba_off) %>%
  augment(nba_off)
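
As a quick sanity check that the normalization behaved as expected, we can prep and bake the recipe and confirm the numeric columns now have mean roughly 0 and standard deviation 1:

nba_rec %>%
  prep() %>%
  bake(new_data = NULL) %>%
  summarize(across(where(is.numeric), list(mean = mean, sd = sd)))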

Let’s take a look at some of the players in each cluster and explore this.

nba_clusters %>% 
  group_by(.pred_cluster) %>%
  select(-c(player, pos)) %>%
  summarize(across(everything(), mean))
## # A tibble: 6 × 8
##   .pred_cluster   pts   ast   tov   orb    ft   x2p   x3p
##   <fct>         <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Cluster_1      14.8  2.18  1.89 3.95   2.47  5.73 0.297
## 2 Cluster_2      13.2  2.66  1.34 0.939  1.35  2.13 2.51 
## 3 Cluster_3      14.0  2.63  1.69 1.81   2     3.85 1.43 
## 4 Cluster_4      15.1  6.88  2.56 0.842  2.27  3.95 1.63 
## 5 Cluster_5      28.6  6.45  3.21 1.47   6.97  8.46 1.56 
## 6 Cluster_6      22.5  4.31  2.53 1.02   3.99  5.42 2.56
for (k in 1:6) {
  print(
    nba_clusters %>%
      filter(.pred_cluster == paste("Cluster", k, sep = "_")) %>%
      slice_sample(n = 5)
  )
}
## # A tibble: 5 × 10
##   player            pos     pts   ast   tov   orb    ft   x2p   x3p .pred_clus…¹
##   <chr>             <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>       
## 1 Jonas Valančiūnas C      20.4   2.6   2.9   4.1   3.4   7.5   0.7 Cluster_1   
## 2 Chris Boucher     PF     16.8   0.7   1     3.7   3     4.7   1.5 Cluster_1   
## 3 Bruno Fernando    C      13.6   2.8   2.2   4.8   3.1   5.2   0   Cluster_1   
## 4 Kevon Looney      C      10.6   3.8   0.8   5     1.7   4.4   0   Cluster_1   
## 5 Clint Capela      C      16.2   1.2   1.1   5.4   1.6   7.3   0   Cluster_1   
## # … with abbreviated variable name ¹​.pred_cluster
## # A tibble: 5 × 10
##   player            pos     pts   ast   tov   orb    ft   x2p   x3p .pred_clus…¹
##   <chr>             <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>       
## 1 Doug McDermott    SF     17.9   2.5   1.6   0.7   1.5   3.2   3.4 Cluster_2   
## 2 Kevin Knox        PF-SF  15.7   1.2   1.9   1.3   1.7   3.3   2.5 Cluster_2   
## 3 De'Anthony Melton PG     13.1   3.3   1.7   1.2   1.2   2     2.6 Cluster_2   
## 4 Danuel House Jr.  SF     12     2     1.2   0.4   2     2.5   1.7 Cluster_2   
## 5 JT Thor           PF      9.7   1.3   1.6   1.3   1.2   2.1   1.4 Cluster_2   
## # … with abbreviated variable name ¹​.pred_cluster
## # A tibble: 5 × 10
##   player            pos     pts   ast   tov   orb    ft   x2p   x3p .pred_clus…¹
##   <chr>             <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>       
## 1 David Roddy       PF     13.3   1.7   1.6   1.4   1.2   3.3   1.9 Cluster_3   
## 2 Isaiah Stewart    C      14.4   1.8   1.7   2.9   2.8   3.3   1.7 Cluster_3   
## 3 Haywood Highsmith SF      8.8   1.6   1.7   2.1   0.5   2.1   1.4 Cluster_3   
## 4 Keita Bates-Diop  SF     16     2.6   1.4   1.6   3     4.5   1.4 Cluster_3   
## 5 Lamar Stevens     PF     10.5   1.1   0.9   1.5   1.3   3.2   1   Cluster_3   
## # … with abbreviated variable name ¹​.pred_cluster
## # A tibble: 5 × 10
##   player         pos     pts   ast   tov   orb    ft   x2p   x3p .pred_cluster
##   <chr>          <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>        
## 1 Ty Jerome      SG     13.6   6     1.3   0.3   1.7   3.7   1.5 Cluster_4    
## 2 Victor Oladipo SG     14.6   4.8   2.8   0.5   1.9   2.9   2.3 Cluster_4    
## 3 John Wall      PG     18.4   8.5   3.8   0.7   3.7   5     1.6 Cluster_4    
## 4 Reggie Jackson PG     15.1   5.1   2.5   0.5   1.5   3.6   2.1 Cluster_4    
## 5 Ish Smith      PG      9.8   9     4     0.5   0.2   4.5   0.2 Cluster_4    
## # A tibble: 5 × 10
##   player           pos     pts   ast   tov   orb    ft   x2p   x3p .pred_cluster
##   <chr>            <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>        
## 1 LeBron James     PF     29.3   6.9   3.3   1.2   4.6   9     2.2 Cluster_5    
## 2 Zion Williamson  PF     28.4   5     3.7   2.2   6.7  10.5   0.3 Cluster_5    
## 3 Domantas Sabonis C      19.9   7.5   3     3.3   4.3   7.2   0.4 Cluster_5    
## 4 Joel Embiid      C      34.4   4.3   3.6   1.8  10.4  10.4   1   Cluster_5    
## 5 Jimmy Butler     SF     24.7   5.7   1.7   2.4   7.9   7.5   0.6 Cluster_5    
## # A tibble: 5 × 10
##   player           pos     pts   ast   tov   orb    ft   x2p   x3p .pred_cluster
##   <chr>            <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>        
## 1 Kyrie Irving     PG     26.1   5.3   2.1   0.9   4     6.5   3   Cluster_6    
## 2 Jordan Poole     PG     24.5   5.4   3.7   0.5   5.3   4.9   3.1 Cluster_6    
## 3 Stephen Curry    PG     30.6   6.5   3.3   0.7   4.8   5.3   5.1 Cluster_6    
## 4 Kevin Porter Jr. PG     20.1   6     3.3   1.4   3.7   4.4   2.5 Cluster_6    
## 5 Klay Thompson    SG     23.8   2.6   1.9   0.6   1.8   3.9   4.8 Cluster_6
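
We can also cross-tabulate the predicted clusters against the traditional positions to see how much the two groupings overlap:

nba_clusters %>% count(.pred_cluster, pos, sort = TRUE)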

Choosing the number of clusters

I chose 6 clusters arbitrarily. Since it’s not clear exactly what the right number should be, let’s try a few different values and see.

set.seed(465)
ssw <- c()
K_max <- 12
for (k in 1:K_max) {
  nba_clusts <- workflow() %>%
    add_model(
      k_means(num_clusters = k)
    ) %>%
    add_recipe(nba_rec) %>%
    fit(data=nba_off) 
  ssw <- c(ssw, glance(nba_clusts)$tot.withinss) # get SSW
} 

ggplot(data=NULL, aes(y=ssw, x=1:K_max)) + geom_line() +
  labs(x = "K: number of clusters", y = "Total Within Group SS")

# To get a better picture, we can look at how much the SSW changed from the previous clustering
tibble(k=2:(K_max+1), ssw_diff = ssw - lead(ssw)) %>%
  filter(k<K_max+1) %>%
  ggplot(aes(k, ssw_diff)) + geom_line() +
  labs(x = "K", y = "Drop in SSW vs. K-1 clusters", title = "A Clearer Picture")

To me this looks like 4 or 5 is the right number, which aligns somewhat with what we saw previously: a few of the six groups didn't look all that different!
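
tidyclust can also automate this kind of search with tune_cluster(), the clustering analogue of tune_grid(). A sketch, assuming a tidyclust version that exports tune_cluster() and cluster_metric_set():

nba_tune_wf <- workflow() %>%
  add_model(k_means(num_clusters = tune())) %>%
  add_recipe(nba_rec)

tune_res <- tune_cluster(
  nba_tune_wf,
  resamples = vfold_cv(nba_off, v = 5),   # resample to stabilize the SSW estimates
  grid = tibble(num_clusters = 1:K_max),
  metrics = cluster_metric_set(sse_within_total)
)
collect_metrics(tune_res)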

set.seed(465) # for reproducibility

nba_spec <- k_means(num_clusters=4)

nba_clusters4 <- workflow() %>%
  add_model(nba_spec) %>%
  add_recipe(nba_rec) %>%
  fit(data=nba_off) %>%
  augment(nba_off)

nba_clusters4 %>% 
  group_by(.pred_cluster) %>%
  summarize(across(where(is.numeric), mean))
## # A tibble: 4 × 8
##   .pred_cluster   pts   ast   tov   orb    ft   x2p   x3p
##   <fct>         <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Cluster_1      14.7  2.21  1.87  3.67  2.39  5.51 0.449
## 2 Cluster_2      13.8  2.55  1.42  1.18  1.62  2.73 2.23 
## 3 Cluster_3      14.4  5.81  2.34  1.06  2.11  3.91 1.47 
## 4 Cluster_4      24.1  5.30  2.85  1.14  4.83  6.21 2.28
nba_clusters4 %>% filter(.pred_cluster == "Cluster_2") %>% head(n=10)
## # A tibble: 10 × 10
##    player           pos     pts   ast   tov   orb    ft   x2p   x3p .pred_clus…¹
##    <chr>            <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>       
##  1 Ochai Agbaji     SG     13.9   2     1.2   1.3   1.7   2.5   2.4 Cluster_2   
##  2 Santi Aldama     PF     14.9   2.1   1.3   1.8   2.3   3.3   2   Cluster_2   
##  3 Grayson Allen    SG     13.7   3     1.3   1.1   2.1   1.8   2.7 Cluster_2   
##  4 OG Anunoby       SF     17     2     2     1.4   2.1   4.2   2.1 Cluster_2   
##  5 Mo Bamba         C      15.1   2.2   1.4   2.7   1.9   3.2   2.2 Cluster_2   
##  6 Harrison Barnes  PF     16.6   1.7   1.2   1.2   4.7   3.3   1.8 Cluster_2   
##  7 Will Barton      SG     13.7   4.1   1.7   0.5   1.3   2.6   2.4 Cluster_2   
##  8 Keita Bates-Diop SF     16     2.6   1.4   1.6   3     4.5   1.4 Cluster_2   
##  9 Nicolas Batum    PF     10.1   2.5   1.1   1.3   0.7   0.8   2.6 Cluster_2   
## 10 Darius Bazley    PF     13.1   2.2   1.6   2     2.2   3.7   1.2 Cluster_2   
## # … with abbreviated variable name ¹​.pred_cluster

There are a number of ways we could explore this further! This data is high-dimensional, so the cluster structure is much harder to see in any single scatterplot than it was with our simulated two-variable data.

nba_clusters4 %>% ggplot(aes(pts, ast, color=.pred_cluster)) + geom_point()

If you know the NBA, take a look at the players in each cluster and see if the similarities make sense!
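
One way to look at several dimensions at once is a scatterplot matrix; a sketch using the GGally package (an extra dependency, assuming it is installed):

library(GGally)

# pairwise views of a few offensive stats, colored by predicted cluster
nba_clusters4 %>%
  ggpairs(columns = c("pts", "ast", "orb", "x3p"),
          mapping = aes(color = .pred_cluster))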

Clustering with PCA

If we have time: the idea is to reduce the predictors to principal components first and then cluster in the reduced space, which also gives us a convenient two-dimensional view of the clusters.

set.seed(465) # for reproducibility

# add PCA on top of the normalization; step_pca keeps 5 components by default
pca_rec <- nba_rec %>% step_pca(all_numeric())

pca_predclusters4 <- workflow() %>%
  add_model(
    k_means(num_clusters=4)
  ) %>%
  add_recipe(pca_rec) %>%
  fit(data=nba_off) %>%
  predict(nba_off)

pca_clusters4 <- bind_cols(
  bake(prep(pca_rec), new_data=NULL),
  pca_predclusters4)

pca_clusters4 %>% ggplot(aes(PC1, PC2, color=.pred_cluster)) + geom_point()
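
To see how much of the variance those components capture, we can tidy the prepped PCA step; a sketch (the number = 2 assumes step_pca is the second step in pca_rec):

pca_prep <- prep(pca_rec)
# percent of total variance explained by each principal component
tidy(pca_prep, number = 2, type = "variance") %>%
  filter(terms == "percent variance")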