Commit 0647bb4 by hobs (parent: 5c11d69)

load categories from json

Files changed:
- app.py: +108 -74
- categories.json: +1 -0
app.py CHANGED

@@ -2,17 +2,117 @@
 
 import gradio as gr
 
-import os
+import json
 from pathlib import Path
 # import random
 # import time
 import torch
 import torch.nn as nn
 
-import pandas as pd
-from nlpia2.init import SRC_DATA_DIR, maybe_download
 
-
+import string
+import unicodedata
+from unidecode import unidecode
+
+
+ASCII_LETTERS = string.ascii_letters
+ASCII_PRINTABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'
+ASCII_PRINTABLE_COMMON = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r'
+
+ASCII_VERTICAL_TAB = '\x0b'
+ASCII_PAGE_BREAK = '\x0c'
+ASCII_ALL = ''.join(chr(i) for i in range(0, 128))  # ASCII_PRINTABLE
+ASCII_DIGITS = string.digits
+ASCII_IMPORTANT_PUNCTUATION = " .?!,;'-=+)(:"
+ASCII_NAME_PUNCTUATION = " .,;'-"
+ASCII_NAME_CHARS = set(ASCII_LETTERS + ASCII_NAME_PUNCTUATION)
+ASCII_IMPORTANT_CHARS = set(ASCII_LETTERS + ASCII_IMPORTANT_PUNCTUATION)
+
+CURLY_SINGLE_QUOTES = '‘’`´'
+STRAIGHT_SINGLE_QUOTES = "'" * len(CURLY_SINGLE_QUOTES)
+CURLY_DOUBLE_QUOTES = '“”'
+STRAIGHT_DOUBLE_QUOTES = '"' * len(CURLY_DOUBLE_QUOTES)
+
+
+def normalize_newlines(s):
+    s = s.replace(ASCII_VERTICAL_TAB, '\n')
+    s = s.replace(ASCII_PAGE_BREAK, '\n\n')
+    return s
+
+
+class Asciifier:
+    """ Construct a function that filters out all non-ascii unicode characters
+
+    >>> test_str = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'
+    >>> Asciifier(include='a b c 123XYZ')(test_str)
+    '123abcXYZ '
+    """
+
+    def __init__(
+        self,
+        min_ord=1, max_ord=128,
+        exclude=None,
+        include=ASCII_PRINTABLE,
+        exclude_category='Mn',
+        normalize_quotes=True,
+    ):
+        self.include = set(sorted(include or ASCII_PRINTABLE))
+        self._include = ''.join(sorted(self.include))
+        self.exclude = set(sorted(exclude or []))
+        self._exclude = ''.join(self.exclude)
+        self.min_ord, self.max_ord = int(min_ord), int(max_ord or 128)
+        self.normalize_quotes = normalize_quotes
+
+        if self.min_ord:
+            self.include = set(c for c in self.include if ord(c) >= self.min_ord)
+        if self.max_ord:
+            self.include = set(c for c in self.include if ord(c) <= self.max_ord)
+        if exclude_category:
+            self.include = set(
+                c for c in self.include if unicodedata.category(c) != exclude_category)
+
+        self.vocab = sorted(self.include - self.exclude)
+        self._vocab = ''.join(self.vocab)
+        self.char2i = {c: i for (i, c) in enumerate(self._vocab)}
+
+        self._translate_from = self._vocab
+        self._translate_to = self._translate_from
+
+        # FIXME: self.normalize_quotes is accomplished by unidecode.unidecode!!
+        #        ’->'  ‘->'  “->"  ”->"
+        if self.normalize_quotes:
+            trans_table = str.maketrans(
+                CURLY_SINGLE_QUOTES + CURLY_DOUBLE_QUOTES,
+                STRAIGHT_SINGLE_QUOTES + STRAIGHT_DOUBLE_QUOTES)
+            self._translate_to = self._translate_to.translate(trans_table)
+        # print(self._translate_to)
+
+        # eliminate any non-translations (if from == to)
+        self._translate_from_filtered = ''
+        self._translate_to_filtered = ''
+
+        for c1, c2 in zip(self._translate_from, self._translate_to):
+            if c1 == c2:
+                continue
+            else:
+                self._translate_from_filtered += c1
+                self._translate_to_filtered += c2
+
+        self._translate_del = ''
+        for c in ASCII_ALL:
+            if c not in self.vocab:
+                self._translate_del += c
+
+        self._translate_from = self._translate_from_filtered
+        self._translate_to = self._translate_to_filtered
+        self.translation_table = str.maketrans(
+            self._translate_from,
+            self._translate_to,
+            self._translate_del)
+
+    def __call__(self, text):
+        return unidecode(unicodedata.normalize('NFD', text)).translate(self.translation_table)
+
 
 name_char_vocab_size = len(ASCII_NAME_CHARS) + 1  # Plus EOS marker
 
@@ -31,49 +131,10 @@ char2i = {c: i for i, c in enumerate(ASCII_NAME_CHARS)}
 
 print(f'asciify("O’Néàl") => {asciify("O’Néàl")}')
 
-# Build the category_lines dictionary, a list of names per language
-category_lines = {}
-all_categories = []
-labeled_lines = []
-categories = []
-for filepath in find_files(SRC_DATA_DIR / 'names', '*.txt'):
-    filename = Path(filepath).name
-    filepath = maybe_download(filename=Path('names') / filename)
-    with filepath.open() as fin:
-        lines = [asciify(line.rstrip()) for line in fin]
-    category = Path(filename).with_suffix('')
-    categories.append(category)
-    labeled_lines += list(zip(lines, [category] * len(lines)))
 
+categories = json.load(open('categories.json'))
 n_categories = len(categories)
 
-df = pd.DataFrame(labeled_lines, columns=('name', 'category'))
-
-
-def readLines(filename):
-    lines = open(filename, encoding='utf-8').read().strip().split('\n')
-    return [asciify(line) for line in lines]
-
-
-for filename in find_files(path='data/names', pattern='*.txt'):
-    category = os.path.splitext(os.path.basename(filename))[0]
-    all_categories.append(category)
-    lines = readLines(filename)
-    category_lines[category] = lines
-
-n_categories = len(all_categories)
-
-
-######################################################################
-# Now we have ``category_lines``, a dictionary mapping each category
-# (language) to a list of lines (names). We also kept track of
-# ``all_categories`` (just a list of languages) and ``n_categories`` for
-# later reference.
-#
-
-print(category_lines['Italian'][:5])
-
-
 ######################################################################
 # Turning Names into Tensors
 # --------------------------
@@ -117,33 +178,6 @@ def encode_one_hot_seq(line):
     return tensor
 
 
-print(encode_one_hot_vec('A'))
-
-print(encode_one_hot_seq('Abe').size())
-
-
-######################################################################
-# Creating the Network
-# ====================
-#
-# Before autograd, creating a recurrent neural network in Torch involved
-# cloning the parameters of a layer over several timesteps. The layers
-# held hidden state and gradients which are now entirely handled by the
-# graph itself. This means you can implement a RNN in a very "pure" way,
-# as regular feed-forward layers.
-#
-# This RNN module (mostly copied from `the PyTorch for Torch users
-# tutorial <https://pytorch.org/tutorials/beginner/former_torchies/
-# nn_tutorial.html#example-2-recurrent-net>`__)
-# is just 2 linear layers which operate on an input and hidden state, with
-# a LogSoftmax layer after the output.
-#
-# .. figure:: https://i.imgur.com/Z2xbySO.png
-#    :alt:
-#
-#
-
-
 class RNN(nn.Module):
     def __init__(self, input_size, hidden_size, output_size):
         super(RNN, self).__init__()
@@ -178,7 +212,7 @@ output, next_hidden = rnn(input, hidden)
 def categoryFromOutput(output):
     top_n, top_i = output.topk(1)
     category_i = top_i[0].item()
-    return all_categories[category_i], category_i
+    return categories[category_i], category_i
 
 
 def output_from_str(s):
@@ -222,8 +256,8 @@ def predict(input_line, n_predictions=3):
         for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
-            print('(%.2f) %s' % (value, all_categories[category_index]))
-            predictions.append([value, all_categories[category_index]])
+            print('(%.2f) %s' % (value, categories[category_index]))
+            predictions.append([value, categories[category_index]])
 
 
 predict('Dovesky')
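The default Asciifier reproduces the asciify() behavior exercised by the print statement in the diff above. A minimal usage sketch (assuming the class and ASCII_* constants added in this commit, with the unidecode package installed; the example strings are hypothetical):

# Usage sketch only, not part of the commit. Assumes the Asciifier class and
# ASCII_NAME_CHARS constant from app.py above, plus `pip install unidecode`.
asciify = Asciifier()       # default vocabulary: all printable ASCII
print(asciify("O’Néàl"))    # unidecode transliterates accents and curly quotes -> O'Neal

name_filter = Asciifier(include=ASCII_NAME_CHARS)  # letters plus " .,;'-"
print(name_filter('Zoë (b. 1992)'))                # digits and parens deleted -> 'Zoe b. '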
categories.json ADDED

@@ -0,0 +1 @@
+["Arabic", "Irish", "Spanish", "French", "German", "English", "Korean", "Vietnamese", "Scottish", "Japanese", "Polish", "Greek", "Czech", "Italian", "Portuguese", "Russian", "Dutch", "Chinese"]
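app.py consumes this file with categories = json.load(open('categories.json')) and indexes into the list by output-unit position, so the label order here must match the order the model was trained with. A slightly more defensive load, sketched here for illustration (not part of the commit):

# Sketch only: equivalent to `categories = json.load(open('categories.json'))`
# in app.py, but closes the file handle and sanity-checks the payload.
import json
from pathlib import Path

with Path('categories.json').open(encoding='utf-8') as fin:
    categories = json.load(fin)

assert all(isinstance(c, str) for c in categories), 'expected a JSON array of label strings'
n_categories = len(categories)  # 18 language labels in this commit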