hgrif commited on
Commit
1431830
β€’
1 Parent(s): a3cb5cc

Revert "Split in app & library"

Browse files

This reverts commit a3cb5cc4a75e472669bb6f4837e80edaadfb220d.

app.py CHANGED
@@ -1,10 +1,17 @@
1
  import copy
 
 
 
 
 
 
2
 
 
 
 
3
  import streamlit as st
4
- from rhyme_with_ai.rhyme import query_rhyme_words
5
- from rhyme_with_ai.rhyme_generator import RhymeGenerator
6
- from rhyme_with_ai.utils import color_new_words, sanitize
7
- from transformers import TFAutoModelForMaskedLM, AutoTokenizer
8
 
9
 
10
  DEFAULT_QUERY = "Machines will take over the world soon"
@@ -93,6 +100,300 @@ def display_output(status_text, query, current_sentences, previous_sentences):
93
  query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
94
  )
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if __name__ == "__main__":
98
  main()
 
1
  import copy
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ import random
6
+ import string
7
+ from typing import List, Optional
8
 
9
+ import requests
10
+ import numpy as np
11
+ import tensorflow as tf
12
  import streamlit as st
13
+ from gazpacho import Soup, get
14
+ from transformers import AutoTokenizer, TFAutoModelForMaskedLM
 
 
15
 
16
 
17
  DEFAULT_QUERY = "Machines will take over the world soon"
 
100
  query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
101
  )
102
 
103
+ class TokenWeighter:
104
+ def __init__(self, tokenizer):
105
+ self.tokenizer_ = tokenizer
106
+ self.proba = self.get_token_proba()
107
+
108
+ def get_token_proba(self):
109
+ valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
110
+ return valid_token_mask
111
+
112
+ def _filter_short_partial(self, vocab):
113
+ valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
114
+ is_valid = np.zeros(len(vocab.keys()))
115
+ is_valid[valid_token_ids] = 1
116
+ return is_valid
117
+
118
+
119
+ class RhymeGenerator:
120
+ def __init__(
121
+ self,
122
+ model: TFAutoModelForMaskedLM,
123
+ tokenizer: AutoTokenizer,
124
+ token_weighter: TokenWeighter = None,
125
+ ):
126
+ """Generate rhymes.
127
+
128
+ Parameters
129
+ ----------
130
+ model : Model for masked language modelling
131
+ tokenizer : Tokenizer for model
132
+ token_weighter : Class that weighs tokens
133
+ """
134
+
135
+ self.model = model
136
+ self.tokenizer = tokenizer
137
+ if token_weighter is None:
138
+ token_weighter = TokenWeighter(tokenizer)
139
+ self.token_weighter = token_weighter
140
+ self._logger = logging.getLogger(__name__)
141
+
142
+ self.tokenized_rhymes_ = None
143
+ self.position_probas_ = None
144
+
145
+ # Easy access.
146
+ self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
147
+ self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
148
+ self.mask_token_id = self.tokenizer.mask_token_id
149
+
150
+ def start(self, query: str, rhyme_words: List[str]) -> None:
151
+ """Start the sentence generator.
152
+
153
+ Parameters
154
+ ----------
155
+ query : Seed sentence
156
+ rhyme_words : Rhyme words for next sentence
157
+ """
158
+ # TODO: What if no content?
159
+ self._logger.info("Got sentence %s", query)
160
+ tokenized_rhymes = [
161
+ self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
162
+ ]
163
+ # Make same length.
164
+ self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
165
+ tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
166
+ )
167
+ p = self.tokenized_rhymes_ == self.tokenizer.mask_token_id
168
+ self.position_probas_ = p / p.sum(1).reshape(-1, 1)
169
+
170
+ def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
171
+ """Initialize the rhymes.
172
+
173
+ * Tokenize input
174
+ * Append a comma if the sentence does not end in it (might add better predictions as it
175
+ shows the two sentence parts are related)
176
+ * Make second line as long as the original
177
+ * Add a period
178
+
179
+ Parameters
180
+ ----------
181
+ query : First line
182
+ rhyme_word : Last word for second line
183
+
184
+ Returns
185
+ -------
186
+ Tokenized rhyme lines
187
+ """
188
+
189
+ query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
190
+ rhyme_word_token_ids = self.tokenizer.encode(
191
+ rhyme_word, add_special_tokens=False
192
+ )
193
+
194
+ if query_token_ids[-1] != self.comma_token_id:
195
+ query_token_ids.append(self.comma_token_id)
196
+
197
+ magic_correction = len(rhyme_word_token_ids) + 1 # 1 for comma
198
+ return (
199
+ query_token_ids
200
+ + [self.tokenizer.mask_token_id] * (len(query_token_ids) - magic_correction)
201
+ + rhyme_word_token_ids
202
+ + [self.period_token_id]
203
+ )
204
+
205
+ def mutate(self):
206
+ """Mutate the current rhymes.
207
+
208
+ Returns
209
+ -------
210
+ Mutated rhymes
211
+ """
212
+ self.tokenized_rhymes_ = self._mutate(
213
+ self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
214
+ )
215
+
216
+ rhymes = []
217
+ for i in range(len(self.tokenized_rhymes_)):
218
+ rhymes.append(
219
+ self.tokenizer.convert_tokens_to_string(
220
+ self.tokenizer.convert_ids_to_tokens(
221
+ self.tokenized_rhymes_[i], skip_special_tokens=True
222
+ )
223
+ )
224
+ )
225
+ return rhymes
226
+
227
+ def _mutate(
228
+ self,
229
+ tokenized_rhymes: np.ndarray,
230
+ position_probas: np.ndarray,
231
+ token_id_probas: np.ndarray,
232
+ ) -> np.ndarray:
233
+
234
+ replacements = []
235
+ for i in range(tokenized_rhymes.shape[0]):
236
+ mask_idx, masked_token_ids = self._mask_token(
237
+ tokenized_rhymes[i], position_probas[i]
238
+ )
239
+ tokenized_rhymes[i] = masked_token_ids
240
+ replacements.append(mask_idx)
241
+
242
+ predictions = self._predict_masked_tokens(tokenized_rhymes)
243
+
244
+ for i, token_ids in enumerate(tokenized_rhymes):
245
+ replace_ix = replacements[i]
246
+ token_ids[replace_ix] = self._draw_replacement(
247
+ predictions[i], token_id_probas, replace_ix
248
+ )
249
+ tokenized_rhymes[i] = token_ids
250
+
251
+ return tokenized_rhymes
252
+
253
+ def _mask_token(self, token_ids, position_probas):
254
+ """Mask line and return index to update."""
255
+ token_ids = self._mask_repeats(token_ids, position_probas)
256
+ ix = self._locate_mask(token_ids, position_probas)
257
+ token_ids[ix] = self.mask_token_id
258
+ return ix, token_ids
259
+
260
+ def _locate_mask(self, token_ids, position_probas):
261
+ """Update masks or a random token."""
262
+ if self.mask_token_id in token_ids:
263
+ # Already masks present, just return the last.
264
+ # We used to return thee first but this returns worse predictions.
265
+ return np.where(token_ids == self.tokenizer.mask_token_id)[0][-1]
266
+ return np.random.choice(range(len(position_probas)), p=position_probas)
267
+
268
+ def _mask_repeats(self, token_ids, position_probas):
269
+ """Repeated tokens are generally of less quality."""
270
+ repeats = [
271
+ ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
272
+ ]
273
+ for ii in repeats:
274
+ if position_probas[ii] > 0:
275
+ token_ids[ii] = self.mask_token_id
276
+ if position_probas[ii + 1] > 0:
277
+ token_ids[ii + 1] = self.mask_token_id
278
+ return token_ids
279
+
280
+ def _predict_masked_tokens(self, tokenized_rhymes):
281
+ return self.model(tf.constant(tokenized_rhymes))[0]
282
+
283
+ def _draw_replacement(self, predictions, token_probas, replace_ix):
284
+ """Get probability, weigh and draw."""
285
+ # TODO (HG): Can't we softmax when calling the model?
286
+ probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
287
+ probas /= probas.sum()
288
+ return np.random.choice(range(len(probas)), p=probas)
289
+
290
+
291
+
292
+ def query_rhyme_words(sentence: str, n_rhymes: int, language:str="english") -> List[str]:
293
+ """Returns a list of rhyme words for a sentence.
294
+
295
+ Parameters
296
+ ----------
297
+ sentence : Sentence that may end with punctuation
298
+ n_rhymes : Maximum number of rhymes to return
299
+
300
+ Returns
301
+ -------
302
+ List[str] -- List of words that rhyme with the final word
303
+ """
304
+ last_word = find_last_word(sentence)
305
+ if language == "english":
306
+ return query_datamuse_api(last_word, n_rhymes)
307
+ elif language == "dutch":
308
+ return mick_rijmwoordenboek(last_word, n_rhymes)
309
+ else:
310
+ raise NotImplementedError(f"Unsupported language ({language}) expected 'english' or 'dutch'.")
311
+
312
+
313
+ def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
314
+ """Query the DataMuse API.
315
+
316
+ Parameters
317
+ ----------
318
+ word : Word to rhyme with
319
+ n_rhymes : Max rhymes to return
320
+
321
+ Returns
322
+ -------
323
+ Rhyme words
324
+ """
325
+ out = requests.get(
326
+ "https://api.datamuse.com/words", params={"rel_rhy": word}
327
+ ).json()
328
+ words = [_["word"] for _ in out]
329
+ if n_rhymes is None:
330
+ return words
331
+ return words[:n_rhymes]
332
+
333
+
334
+ @functools.lru_cache(maxsize=128, typed=False)
335
+ def mick_rijmwoordenboek(word: str, n_words: int):
336
+ url = f"https://rijmwoordenboek.nl/rijm/{word}"
337
+ html = get(url)
338
+ soup = Soup(html)
339
+
340
+ results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br />")
341
+
342
+ # clean up
343
+ results = [r.replace("\n", "").replace(" ", "") for r in results]
344
+
345
+ # filter html and empty strings
346
+ results = [r for r in results if ("<" not in r) and (len(r) > 0)]
347
+
348
+ return random.sample(results, min(len(results), n_words))
349
+
350
+
351
+ def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
352
+ """Color new words in strings with a span."""
353
+
354
+ def find_diff(new_, old_):
355
+ return [ii for ii, (n, o) in enumerate(zip(new_, old_)) if n != o]
356
+
357
+ new_words = new.split()
358
+ old_words = old.split()
359
+ forward = find_diff(new_words, old_words)
360
+ backward = find_diff(new_words[::-1], old_words[::-1])
361
+
362
+ if not forward or not backward:
363
+ # No difference
364
+ return new
365
+
366
+ start, end = forward[0], len(new_words) - backward[0]
367
+ return (
368
+ " ".join(new_words[:start])
369
+ + " "
370
+ + f'<span style="background-color: {color}">'
371
+ + " ".join(new_words[start:end])
372
+ + "</span>"
373
+ + " "
374
+ + " ".join(new_words[end:])
375
+ )
376
+
377
+
378
+ def find_last_word(s):
379
+ """Find the last word in a string."""
380
+ # Note: will break on \n, \r, etc.
381
+ alpha_only_sentence = "".join([c for c in s if (c.isalpha() or (c == " "))]).strip()
382
+ return alpha_only_sentence.split()[-1]
383
+
384
+
385
+ def pairwise(iterable):
386
+ """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
387
+ # https://stackoverflow.com/questions/5434891/iterate-a-list-as-pair-current-next-in-python
388
+ a, b = itertools.tee(iterable)
389
+ next(b, None)
390
+ return zip(a, b)
391
+
392
+
393
+ def sanitize(s):
394
+ """Remove punctuation from a string."""
395
+ return s.translate(str.maketrans("", "", string.punctuation))
396
+
397
 
398
  if __name__ == "__main__":
399
  main()
requirements.txt CHANGED
@@ -2,5 +2,4 @@ gazpacho
2
  numpy
3
  requests
4
  tensorflow
5
- transformers
6
- -e .
 
2
  numpy
3
  requests
4
  tensorflow
5
+ transformers
 
setup.cfg DELETED
@@ -1,17 +0,0 @@
1
- [aliases]
2
- test=pytest
3
-
4
- [flake8]
5
- max-line-length = 88
6
-
7
- [tool:pytest]
8
- addopts = --cov=src --cov-report=xml:test-coverage.xml --nunitxml test-output.xml -vv
9
-
10
- [bumpversion]
11
- current_version = 0.1
12
- commit = True
13
- tag = True
14
-
15
- [bumpversion:file:setup.py]
16
- search = version='{current_version}'
17
- replace = version='{new_version}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
setup.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- from setuptools import setup, find_packages
3
-
4
- with open("README.md") as readme_file:
5
- readme = readme_file.read()
6
-
7
- requirements = [
8
- "numpy",
9
- "pandas",
10
- "requests",
11
- "tensorflow",
12
- "transformers",
13
- ]
14
-
15
- extra_requirements = {
16
- "dev": [
17
- "black",
18
- "bump2version",
19
- "coverage",
20
- "gazpacho",
21
- "twine",
22
- "pre-commit",
23
- "pylint",
24
- "pytest",
25
- ]
26
- }
27
-
28
- setup_requirements = ["pytest-runner"]
29
-
30
- test_requirements = ["pytest", "pytest-cov", "pytest-nunit"]
31
-
32
- BUILD_ID = os.environ.get("BUILD_BUILDID", "0")
33
-
34
- setup(
35
- author="Rens Dimmendaal & Henk Griffioen",
36
- author_email="[email protected]",
37
- classifiers=[
38
- "Development Status :: 2 - Pre-Alpha",
39
- "Intended Audience :: Developers",
40
- "License :: OSI Approved :: MIT License",
41
- "Natural Language :: English",
42
- "Programming Language :: Python :: 3.7",
43
- ],
44
- description="Generate text",
45
- install_requires=requirements,
46
- extras_require=extra_requirements,
47
- long_description=readme,
48
- include_package_data=True,
49
- keywords="rhyme",
50
- name="rhyme_with_ai",
51
- packages=find_packages(include=["src"]),
52
- package_dir={"": "src"},
53
- setup_requires=setup_requirements,
54
- test_suite="tests",
55
- tests_require=test_requirements,
56
- version="0.1" + "." + BUILD_ID,
57
- zip_safe=False,
58
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rhyme_with_ai/__init__.py DELETED
File without changes
src/rhyme_with_ai/rhyme.py DELETED
@@ -1,67 +0,0 @@
1
- import functools
2
- import random
3
- from typing import List, Optional
4
-
5
- import requests
6
- from gazpacho import Soup, get
7
-
8
- from rhyme_with_ai.utils import find_last_word
9
-
10
-
11
- def query_rhyme_words(sentence: str, n_rhymes: int, language:str="english") -> List[str]:
12
- """Returns a list of rhyme words for a sentence.
13
-
14
- Parameters
15
- ----------
16
- sentence : Sentence that may end with punctuation
17
- n_rhymes : Maximum number of rhymes to return
18
-
19
- Returns
20
- -------
21
- List[str] -- List of words that rhyme with the final word
22
- """
23
- last_word = find_last_word(sentence)
24
- if language == "english":
25
- return query_datamuse_api(last_word, n_rhymes)
26
- elif language == "dutch":
27
- return mick_rijmwoordenboek(last_word, n_rhymes)
28
- else:
29
- raise NotImplementedError(f"Unsupported language ({language}) expected 'english' or 'dutch'.")
30
-
31
-
32
- def query_datamuse_api(word: str, n_rhymes: Optional[int] = None) -> List[str]:
33
- """Query the DataMuse API.
34
-
35
- Parameters
36
- ----------
37
- word : Word to rhyme with
38
- n_rhymes : Max rhymes to return
39
-
40
- Returns
41
- -------
42
- Rhyme words
43
- """
44
- out = requests.get(
45
- "https://api.datamuse.com/words", params={"rel_rhy": word}
46
- ).json()
47
- words = [_["word"] for _ in out]
48
- if n_rhymes is None:
49
- return words
50
- return words[:n_rhymes]
51
-
52
-
53
- @functools.lru_cache(maxsize=128, typed=False)
54
- def mick_rijmwoordenboek(word: str, n_words: int):
55
- url = f"https://rijmwoordenboek.nl/rijm/{word}"
56
- html = get(url)
57
- soup = Soup(html)
58
-
59
- results = soup.find("div", {"id": "rhymeResultsWords"}).html.split("<br />")
60
-
61
- # clean up
62
- results = [r.replace("\n", "").replace(" ", "") for r in results]
63
-
64
- # filter html and empty strings
65
- results = [r for r in results if ("<" not in r) and (len(r) > 0)]
66
-
67
- return random.sample(results, min(len(results), n_words))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rhyme_with_ai/rhyme_generator.py DELETED
@@ -1,181 +0,0 @@
1
- import logging
2
- from typing import List
3
-
4
- import numpy as np
5
- import tensorflow as tf
6
- from transformers import TFAutoModelForMaskedLM, AutoTokenizer
7
-
8
- from rhyme_with_ai.utils import pairwise
9
- from rhyme_with_ai.token_weighter import TokenWeighter
10
-
11
-
12
- class RhymeGenerator:
13
- def __init__(
14
- self,
15
- model: TFAutoModelForMaskedLM,
16
- tokenizer: AutoTokenizer,
17
- token_weighter: TokenWeighter = None,
18
- ):
19
- """Generate rhymes.
20
-
21
- Parameters
22
- ----------
23
- model : Model for masked language modelling
24
- tokenizer : Tokenizer for model
25
- token_weighter : Class that weighs tokens
26
- """
27
-
28
- self.model = model
29
- self.tokenizer = tokenizer
30
- if token_weighter is None:
31
- token_weighter = TokenWeighter(tokenizer)
32
- self.token_weighter = token_weighter
33
- self._logger = logging.getLogger(__name__)
34
-
35
- self.tokenized_rhymes_ = None
36
- self.position_probas_ = None
37
-
38
- # Easy access.
39
- self.comma_token_id = self.tokenizer.encode(",", add_special_tokens=False)[0]
40
- self.period_token_id = self.tokenizer.encode(".", add_special_tokens=False)[0]
41
- self.mask_token_id = self.tokenizer.mask_token_id
42
-
43
- def start(self, query: str, rhyme_words: List[str]) -> None:
44
- """Start the sentence generator.
45
-
46
- Parameters
47
- ----------
48
- query : Seed sentence
49
- rhyme_words : Rhyme words for next sentence
50
- """
51
- # TODO: What if no content?
52
- self._logger.info("Got sentence %s", query)
53
- tokenized_rhymes = [
54
- self._initialize_rhymes(query, rhyme_word) for rhyme_word in rhyme_words
55
- ]
56
- # Make same length.
57
- self.tokenized_rhymes_ = tf.keras.preprocessing.sequence.pad_sequences(
58
- tokenized_rhymes, padding="post", value=self.tokenizer.pad_token_id
59
- )
60
- p = self.tokenized_rhymes_ == self.tokenizer.mask_token_id
61
- self.position_probas_ = p / p.sum(1).reshape(-1, 1)
62
-
63
- def _initialize_rhymes(self, query: str, rhyme_word: str) -> List[int]:
64
- """Initialize the rhymes.
65
-
66
- * Tokenize input
67
- * Append a comma if the sentence does not end in it (might add better predictions as it
68
- shows the two sentence parts are related)
69
- * Make second line as long as the original
70
- * Add a period
71
-
72
- Parameters
73
- ----------
74
- query : First line
75
- rhyme_word : Last word for second line
76
-
77
- Returns
78
- -------
79
- Tokenized rhyme lines
80
- """
81
-
82
- query_token_ids = self.tokenizer.encode(query, add_special_tokens=False)
83
- rhyme_word_token_ids = self.tokenizer.encode(
84
- rhyme_word, add_special_tokens=False
85
- )
86
-
87
- if query_token_ids[-1] != self.comma_token_id:
88
- query_token_ids.append(self.comma_token_id)
89
-
90
- magic_correction = len(rhyme_word_token_ids) + 1 # 1 for comma
91
- return (
92
- query_token_ids
93
- + [self.tokenizer.mask_token_id] * (len(query_token_ids) - magic_correction)
94
- + rhyme_word_token_ids
95
- + [self.period_token_id]
96
- )
97
-
98
- def mutate(self):
99
- """Mutate the current rhymes.
100
-
101
- Returns
102
- -------
103
- Mutated rhymes
104
- """
105
- self.tokenized_rhymes_ = self._mutate(
106
- self.tokenized_rhymes_, self.position_probas_, self.token_weighter.proba
107
- )
108
-
109
- rhymes = []
110
- for i in range(len(self.tokenized_rhymes_)):
111
- rhymes.append(
112
- self.tokenizer.convert_tokens_to_string(
113
- self.tokenizer.convert_ids_to_tokens(
114
- self.tokenized_rhymes_[i], skip_special_tokens=True
115
- )
116
- )
117
- )
118
- return rhymes
119
-
120
- def _mutate(
121
- self,
122
- tokenized_rhymes: np.ndarray,
123
- position_probas: np.ndarray,
124
- token_id_probas: np.ndarray,
125
- ) -> np.ndarray:
126
-
127
- replacements = []
128
- for i in range(tokenized_rhymes.shape[0]):
129
- mask_idx, masked_token_ids = self._mask_token(
130
- tokenized_rhymes[i], position_probas[i]
131
- )
132
- tokenized_rhymes[i] = masked_token_ids
133
- replacements.append(mask_idx)
134
-
135
- predictions = self._predict_masked_tokens(tokenized_rhymes)
136
-
137
- for i, token_ids in enumerate(tokenized_rhymes):
138
- replace_ix = replacements[i]
139
- token_ids[replace_ix] = self._draw_replacement(
140
- predictions[i], token_id_probas, replace_ix
141
- )
142
- tokenized_rhymes[i] = token_ids
143
-
144
- return tokenized_rhymes
145
-
146
- def _mask_token(self, token_ids, position_probas):
147
- """Mask line and return index to update."""
148
- token_ids = self._mask_repeats(token_ids, position_probas)
149
- ix = self._locate_mask(token_ids, position_probas)
150
- token_ids[ix] = self.mask_token_id
151
- return ix, token_ids
152
-
153
- def _locate_mask(self, token_ids, position_probas):
154
- """Update masks or a random token."""
155
- if self.mask_token_id in token_ids:
156
- # Already masks present, just return the last.
157
- # We used to return thee first but this returns worse predictions.
158
- return np.where(token_ids == self.tokenizer.mask_token_id)[0][-1]
159
- return np.random.choice(range(len(position_probas)), p=position_probas)
160
-
161
- def _mask_repeats(self, token_ids, position_probas):
162
- """Repeated tokens are generally of less quality."""
163
- repeats = [
164
- ii for ii, ids in enumerate(pairwise(token_ids[:-2])) if ids[0] == ids[1]
165
- ]
166
- for ii in repeats:
167
- if position_probas[ii] > 0:
168
- token_ids[ii] = self.mask_token_id
169
- if position_probas[ii + 1] > 0:
170
- token_ids[ii + 1] = self.mask_token_id
171
- return token_ids
172
-
173
- def _predict_masked_tokens(self, tokenized_rhymes):
174
- return self.model(tf.constant(tokenized_rhymes))[0]
175
-
176
- def _draw_replacement(self, predictions, token_probas, replace_ix):
177
- """Get probability, weigh and draw."""
178
- # TODO (HG): Can't we softmax when calling the model?
179
- probas = tf.nn.softmax(predictions[replace_ix]).numpy() * token_probas
180
- probas /= probas.sum()
181
- return np.random.choice(range(len(probas)), p=probas)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rhyme_with_ai/token_weighter.py DELETED
@@ -1,17 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- class TokenWeighter:
5
- def __init__(self, tokenizer):
6
- self.tokenizer_ = tokenizer
7
- self.proba = self.get_token_proba()
8
-
9
- def get_token_proba(self):
10
- valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
11
- return valid_token_mask
12
-
13
- def _filter_short_partial(self, vocab):
14
- valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
15
- is_valid = np.zeros(len(vocab.keys()))
16
- is_valid[valid_token_ids] = 1
17
- return is_valid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rhyme_with_ai/utils.py DELETED
@@ -1,49 +0,0 @@
1
- import itertools
2
- import string
3
-
4
-
5
- def color_new_words(new: str, old: str, color: str = "#eefa66") -> str:
6
- """Color new words in strings with a span."""
7
-
8
- def find_diff(new_, old_):
9
- return [ii for ii, (n, o) in enumerate(zip(new_, old_)) if n != o]
10
-
11
- new_words = new.split()
12
- old_words = old.split()
13
- forward = find_diff(new_words, old_words)
14
- backward = find_diff(new_words[::-1], old_words[::-1])
15
-
16
- if not forward or not backward:
17
- # No difference
18
- return new
19
-
20
- start, end = forward[0], len(new_words) - backward[0]
21
- return (
22
- " ".join(new_words[:start])
23
- + " "
24
- + f'<span style="background-color: {color}">'
25
- + " ".join(new_words[start:end])
26
- + "</span>"
27
- + " "
28
- + " ".join(new_words[end:])
29
- )
30
-
31
-
32
- def find_last_word(s):
33
- """Find the last word in a string."""
34
- # Note: will break on \n, \r, etc.
35
- alpha_only_sentence = "".join([c for c in s if (c.isalpha() or (c == " "))]).strip()
36
- return alpha_only_sentence.split()[-1]
37
-
38
-
39
- def pairwise(iterable):
40
- """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
41
- # https://stackoverflow.com/questions/5434891/iterate-a-list-as-pair-current-next-in-python
42
- a, b = itertools.tee(iterable)
43
- next(b, None)
44
- return zip(a, b)
45
-
46
-
47
- def sanitize(s):
48
- """Remove punctuation from a string."""
49
- return s.translate(str.maketrans("", "", string.punctuation))