File size: 1,605 Bytes
d06496c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import hdbscan
import umap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def load_data(path=None):
    """Load the embedding matrix from a ``.npy`` file.

    Args:
        path: Location of the ``.npy`` file (string or path-like). When
            omitted, ``data/top_cluster_embeddings.npy`` relative to the
            current working directory is used.

    Returns:
        The array stored in the file, as returned by ``numpy.load``.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    from pathlib import Path

    if path is None:
        # Build the default from components so the separator is correct on
        # every OS; a literal 'data\\...' only resolves on Windows.
        path = Path('data') / 'top_cluster_embeddings.npy'
    return np.load(Path(path))

def get_clusters(embeddings, *, n_neighbors=15, n_components=15,
                 min_cluster_size=30, random_state=None):
    """Cluster embeddings by reducing them with UMAP and running HDBSCAN.

    The embeddings are first projected to ``n_components`` dimensions with
    UMAP (cosine metric), then HDBSCAN (euclidean metric, 'eom' cluster
    selection) is fitted on the reduced points.

    Args:
        embeddings: Input data passed to ``umap.UMAP.fit_transform``.
        n_neighbors: UMAP neighbourhood size.
        n_components: Dimensionality of the UMAP projection used for
            clustering.
        min_cluster_size: Minimum cluster size passed to HDBSCAN.
        random_state: Seed forwarded to UMAP. UMAP is stochastic, so pass a
            fixed value to get the same projection (and therefore the same
            clustering input) on every run. ``None`` keeps UMAP's default
            unseeded behaviour.

    Returns:
        The ``labels_`` attribute of the fitted HDBSCAN model, one label per
        input row.
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        n_components=n_components,
        metric='cosine',
        random_state=random_state,
    )
    reduced = reducer.fit_transform(embeddings)

    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        metric='euclidean',
        cluster_selection_method='eom',
    )
    clusterer.fit(reduced)

    return clusterer.labels_

def get_2d_data_for_plotting(embeddings, *, n_neighbors=15, random_state=None):
    """Project embeddings to two dimensions with UMAP for visualisation.

    Args:
        embeddings: Input data passed to ``umap.UMAP.fit_transform``.
        n_neighbors: UMAP neighbourhood size.
        random_state: Seed forwarded to UMAP so the layout can be reproduced
            between runs. ``None`` keeps UMAP's default unseeded behaviour.

    Returns:
        The result of ``fit_transform`` with ``n_components=2``, i.e. one
        2-D coordinate per input row.
    """
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        n_components=2,
        metric='cosine',
        random_state=random_state,
    )
    return reducer.fit_transform(embeddings)

def plot_clusters(embeddings, cluster_labels, output_path=None):
    """Draw a 2-D scatter plot of the clustered embeddings and save it as PNG.

    The embeddings are projected to 2-D with ``get_2d_data_for_plotting``.
    Points labelled ``-1`` are drawn in grey as outliers; all other points
    are coloured by their cluster label using the 'hsv_r' colormap.

    Args:
        embeddings: Data to project and plot.
        cluster_labels: One label per row of ``embeddings``; ``-1`` marks an
            outlier.
        output_path: Where to write the image (string or path-like). When
            omitted, ``plots/clusters.png`` relative to the current working
            directory is used.

    Returns:
        None. The figure is written to ``output_path``.
    """
    from pathlib import Path

    if output_path is None:
        # Build the default from components so the separator is correct on
        # every OS; a literal 'plots\\...' only resolves on Windows.
        output_path = Path('plots') / 'clusters.png'
    output_path = Path(output_path)

    umap_data = get_2d_data_for_plotting(embeddings)
    result = pd.DataFrame(umap_data, columns=['x', 'y'])
    result['labels'] = cluster_labels

    is_outlier = result['labels'] == -1
    outliers = result.loc[is_outlier, :]
    clustered = result.loc[~is_outlier, :]

    # Draw on the explicit Axes instead of relying on pyplot's implicit
    # "current axes" state.
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.scatter(outliers.x, outliers.y, color='#BDBDBD', s=0.05)
    cluster_points = ax.scatter(
        clustered.x, clustered.y, c=clustered.labels, s=0.05, cmap='hsv_r'
    )
    fig.colorbar(cluster_points, ax=ax)

    # savefig does not create missing directories, so make sure the target
    # folder exists before writing.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=300)

def main():
    """Run the pipeline: load the embeddings, cluster them, save the plot."""
    data = load_data()
    labels = get_clusters(data)
    plot_clusters(data, labels)

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()