Add script to generate dataset of embeddings and perplexities. Add script to generate t-SNE plot for embedding and perplexity visualization.
Browse files- get_embeddings_and_perplexity.py +47 -0
- tsne_plot.py +66 -0
get_embeddings_and_perplexity.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import kenlm
|
3 |
+
from datasets import load_dataset
|
4 |
+
from tqdm import tqdm
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
+
|
9 |
+
|
10 |
+
TOTAL_SENTENCES = 20000
|
11 |
+
def pp(log_score, length):
|
12 |
+
return 10.0 ** (-log_score / length)
|
13 |
+
|
14 |
+
|
15 |
+
embedder = "distiluse-base-multilingual-cased-v1"
|
16 |
+
embedder_model = SentenceTransformer(embedder)
|
17 |
+
embedding_shape = embedder_model.encode(["foo"])[0].shape[0]
|
18 |
+
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
19 |
+
model = kenlm.Model("es.arpa.bin")
|
20 |
+
mc4 = load_dataset("mc4", "es", streaming=True)
|
21 |
+
count = 0
|
22 |
+
embeddings = []
|
23 |
+
lenghts = []
|
24 |
+
perplexities = []
|
25 |
+
sentences = []
|
26 |
+
|
27 |
+
for sample in tqdm(mc4["train"].shuffle(buffer_size=100_000), total=416057992):
|
28 |
+
lines = sample["text"].split("\n")
|
29 |
+
for line in lines:
|
30 |
+
count += 1
|
31 |
+
log_score = model.score(line)
|
32 |
+
length = len(line.split()) + 1
|
33 |
+
embedding = embedder_model.encode([line])[0]
|
34 |
+
embeddings.append(embedding.tolist())
|
35 |
+
perplexities.append(pp(log_score, length))
|
36 |
+
lenghts.append(length)
|
37 |
+
sentences.append(line)
|
38 |
+
if count == TOTAL_SENTENCES:
|
39 |
+
break
|
40 |
+
if count == TOTAL_SENTENCES:
|
41 |
+
embeddings = np.array(embeddings)
|
42 |
+
df = pd.DataFrame({"sentence": sentences, "length": lenghts, "perplexity": perplexities})
|
43 |
+
for dim in range(embedding_shape):
|
44 |
+
df[f"dim_{dim}"] = embeddings[:, dim]
|
45 |
+
df.to_csv("mc4-es-perplexity-sentences.tsv", index=None, sep="\t")
|
46 |
+
print("DONE!")
|
47 |
+
break
|
tsne_plot.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
from typing import Any, Optional
|
4 |
+
|
5 |
+
import bokeh
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from bokeh.models import ColumnDataSource, HoverTool
|
9 |
+
from bokeh.plotting import figure, output_file, save
|
10 |
+
from bokeh.transform import factor_cmap
|
11 |
+
from bokeh.palettes import Cividis256 as Pallete
|
12 |
+
from bokeh.resources import CDN
|
13 |
+
from bokeh.embed import file_html
|
14 |
+
from sklearn.manifold import TSNE
|
15 |
+
|
16 |
+
|
17 |
+
logging.basicConfig(level = logging.INFO)
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
SEED = 0
|
20 |
+
|
21 |
+
def get_tsne_embeddings(embeddings: np.ndarray, perplexity: int=30, n_components: int=2, init: str='pca', n_iter: int=5000, random_state: int=SEED) -> np.ndarray:
|
22 |
+
tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
|
23 |
+
return tsne.fit_transform(embeddings)
|
24 |
+
|
25 |
+
def draw_interactive_scatter_plot(texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray) -> Any:
|
26 |
+
# Normalize values to range between 0-255, to assign a color for each value
|
27 |
+
max_value = values.max()
|
28 |
+
min_value = values.min()
|
29 |
+
values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int).astype(str)
|
30 |
+
values_color_set = sorted(values_color)
|
31 |
+
|
32 |
+
values_list = values.astype(str).tolist()
|
33 |
+
values_set = sorted(values_list)
|
34 |
+
|
35 |
+
source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, perplexity=values_list))
|
36 |
+
hover = HoverTool(tooltips=[('Sentence', '@text{safe}'), ('Perplexity', '@perplexity')])
|
37 |
+
p = figure(plot_width=1200, plot_height=1200, tools=[hover], title='Sentences')
|
38 |
+
p.circle(
|
39 |
+
'x', 'y', size=10, source=source, fill_color=factor_cmap('perplexity', palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set))
|
40 |
+
return p
|
41 |
+
|
42 |
+
def generate_plot(tsv: str, output_file_name: str, sample: Optional[int]):
|
43 |
+
logger.info("Loading dataset in memory")
|
44 |
+
df = pd.read_csv(tsv, sep="\t")
|
45 |
+
if sample:
|
46 |
+
df = df.sample(sample, random_state=SEED)
|
47 |
+
logger.info(f"Dataset contains {df.shape[0]} sentences")
|
48 |
+
embeddings = df[sorted([col for col in df.columns if col.startswith("dim")], key=lambda x: int(x.split("_")[-1]))].values
|
49 |
+
logger.info(f"Running t-SNE")
|
50 |
+
tsne_embeddings = get_tsne_embeddings(embeddings)
|
51 |
+
logger.info(f"Generating figure")
|
52 |
+
plot = draw_interactive_scatter_plot(df["sentence"].values, tsne_embeddings[:, 0], tsne_embeddings[:, 1], df["perplexity"].values)
|
53 |
+
output_file(output_file_name)
|
54 |
+
save(plot)
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
parser = argparse.ArgumentParser(description="Embeddings t-SNE plot")
|
61 |
+
parser.add_argument("--tsv", type=str, help="Path to tsv file with columns 'text', 'perplexity' and N 'dim_<i> columns for each embdeding dimension.'")
|
62 |
+
parser.add_argument("--output_file", type=str, help="Path to the output HTML file for the interactive plot.", default="perplexity_colored_embeddings.html")
|
63 |
+
parser.add_argument("--sample", type=int, help="Number of sentences to use", default=None)
|
64 |
+
|
65 |
+
args = parser.parse_args()
|
66 |
+
generate_plot(args.tsv, args.output_file, args.sample)
|