philipp-zettl committed
Commit
d7eff13
1 Parent(s): 082fc10

Upload 2 files

Files changed (2)
  1. optimization.py +66 -0
  2. text.py +130 -0
optimization.py ADDED
@@ -0,0 +1,66 @@
+ from collections import Counter
+ from itertools import chain
+ import math
+ import torch
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+
+ def ngrams(sequence, n):
+     return [tuple(sequence[i:i+n]) for i in range(len(sequence)-n+1)]
+
+ def count_ngrams(sequence, max_n):
+     counts = Counter()
+     for n in range(1, max_n + 1):
+         counts.update(ngrams(sequence, n))
+     return counts
+
+ def self_bleu(outputs):
+     smoothing_function = SmoothingFunction().method1
+     scores = []
+     for i in range(len(outputs)):
+         references = outputs[:i] + outputs[i+1:]
+         # Avoid calculating BLEU score for empty references
+         if references:
+             scores.append(sentence_bleu(references, outputs[i], smoothing_function=smoothing_function))
+     # If all references are empty, return a default value
+     if not scores:
+         return 0
+     return sum(scores) / len(scores)
+
+ def dist_n(outputs, n):
+     all_ngrams = list(chain(*[ngrams(output, n) for output in outputs]))
+     unique_ngrams = set(all_ngrams)
+     return len(unique_ngrams) / len(all_ngrams) if all_ngrams else 0
+
+ def perplexity(model, tokenizer, texts):
+     encodings = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
+     max_length = model.config.n_positions
+     stride = 512
+     lls = []
+     for i in range(0, encodings.input_ids.size(1), stride):
+         begin_loc = max(i + stride - max_length, 0)
+         end_loc = i + stride
+         trg_len = end_loc - i
+         input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
+         target_ids = input_ids.clone()
+         target_ids[:, :-trg_len] = -100
+
+         with torch.no_grad():
+             outputs = model(input_ids, labels=target_ids)
+             log_likelihood = outputs.loss * trg_len
+         lls.append(log_likelihood)
+
+     ppl = torch.exp(torch.stack(lls).sum() / end_loc)
+     return ppl.item()
+
+ def js_divergence(p, q):
+     def kl_divergence(p, q):
+         return sum(p[i] * math.log(p[i] / q[i]) for i in range(len(p)) if p[i] != 0 and q[i] != 0)
+
+     p_norm = [float(i)/sum(p) for i in p]
+     q_norm = [float(i)/sum(q) for i in q]
+
+     m = [(p_norm[i] + q_norm[i]) / 2 for i in range(len(p_norm))]
+
+     return (kl_divergence(p_norm, m) + kl_divergence(q_norm, m)) / 2
+
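The metrics above operate on plain Python data: self_bleu and dist_n expect a list of tokenized generations (lists of tokens), perplexity expects a Hugging Face causal LM whose config exposes n_positions (e.g. GPT-2) together with its tokenizer, and js_divergence expects two equal-length count or probability vectors. A minimal usage sketch follows; the model name and the sample generations are illustrative assumptions, not part of this commit.

# Hypothetical usage of optimization.py; model choice and sample data are assumptions.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from optimization import self_bleu, dist_n, perplexity, js_divergence

# Tokenized generations from some sampling run (illustrative only).
generations = [
    "the cat sat on the mat".split(),
    "the cat sat on a rug".split(),
    "a dog slept on the mat".split(),
]

print("Self-BLEU:", self_bleu(generations))  # higher -> generations are more similar to each other
print("Dist-2:", dist_n(generations, 2))     # share of unique bigrams -> lexical diversity

# Perplexity of the raw texts under GPT-2; the tokenizer needs a pad token for batching.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")
texts = [" ".join(tokens) for tokens in generations]
print("Perplexity:", perplexity(model, tokenizer, texts))

# JS divergence between two unigram count vectors over a shared vocabulary.
print("JS divergence:", js_divergence([3, 1, 0, 2], [2, 2, 1, 1]))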
text.py ADDED
@@ -0,0 +1,130 @@
+ from markdownify import markdownify as md
+ from bs4 import BeautifulSoup as BS
+ from IPython.display import display, Markdown
+ from urllib.parse import urljoin
+ from newspaper import Article
+ import re
+ import markdown
+
+
+ def clean(s):
+     s = s.replace("\t", "\\t")
+     s = s.replace("\n", "\\n")
+     return s
+
+ class DocTree:
+     def __init__(self, content):
+         self.content = content
+         self.max_depth = 6
+
+     def get_sections(self, *location_ids):
+         out = self.content
+         for id_ in location_ids:
+             out = out[id_]
+         return out
+
+     def merge_sections(self, elems):
+         if not isinstance(elems[0], list):
+             return '\n\n '.join(elems)
+         out = []
+         for e in elems:
+             out.append(self.merge_sections(e))
+         return '\n\n '.join(map(clean, out))
+
+     def get_merged_sections(self, *location_ids):
+         return [self.merge_sections(s) for s in self.get_sections(*location_ids)]
+
+     def as_markdown(self, content):
+         return md(content)
+
+     def get_sections_by_depth(self, depth):
+         return self._get_sections_by_depth(self.content, depth)
+
+     @staticmethod
+     def _get_sections_by_depth(content, depth):
+         """Returns a list of merged sections at a specific depth"""
+         if depth == 0:
+             return content
+         out = []
+         for elem in content:
+             out += DocTree._get_sections_by_depth(elem, depth - 1)
+         return out
+
+
+ def fix_relative_links(url, article_content):
+     if 'http' in url:
+         base_url = '/'.join(url.split('/')[:3])
+     else:
+         base_url = url.split('/')[0]
+     pat = re.compile(r'\[(.*?)\]\((.*?)\)', flags=re.IGNORECASE)
+     res = pat.findall(article_content)
+     if res:
+         for g in res:
+             url = urljoin(base_url, g[1]) if g[1].startswith('/') else g[1]
+             article_content = article_content.replace(f'[{g[0]}]({g[1]})', f'[{g[0]}]({url})')
+     else: print('not found')
+     return article_content
+
+
+ def extract_article(url):
+     article = Article(url)
+     article.download()
+     article.parse()
+     return article
+
+
+ def select_content(html_code, elem_class, class_name):
+     print(f'Calling select_content with {elem_class}, {class_name}')
+     if class_name.startswith('.'):
+         class_name = class_name[1:]
+         elem_id = None
+     elif class_name.startswith('#'):
+         elem_id = class_name[1:]
+         class_name = None
+     else:
+         elem_id = None
+         class_name = None
+     return md(str(BS(html_code, features="lxml").find(elem_class, class_=class_name, id=elem_id)))
+
+
+ def split_by_heading(html_content, _i):
+     if _i >= 7:
+         return html_content
+     elems = []
+     for idx, elem in enumerate([i for i in html_content.split(f'<h{_i}') if i]):
+         if idx > 0 or elem.startswith('>'):
+             elem = f'<h{_i}{elem}'
+         elems.append(split_by_heading(elem, _i+1))
+     return elems
+
+ def doctree_from_url(url, elem_class='div', class_name='article-body'):
+     article = extract_article(url)
+     # convert to MD to handle splitting better
+     article_content = select_content(article.html, elem_class, class_name)
+     article_content = (f"# {article.title}\n\n" + article_content).replace('\n\n', '\n').replace('#', '%%@@%%')
+     # fix relative website links
+     article_content = fix_relative_links(url, article_content)
+     # convert back to HTML
+     html_content = markdown.markdown(article_content).replace('%%@@%%', '#')
+     doc_tree = DocTree(split_by_heading(html_content, 1))
+
+     #assert doc_tree.merge_sections(doc_tree.get_sections(0)).replace('\n', '').replace(html_content.replace('\n', ''), '') == '', 'Document inconsistent. Manual adjustments required.'
+     return doc_tree
+
+
+ def get_selectors_for_class(url, elem_class):
+     article = extract_article(url)
+
+     html_content = article.html
+     soup = BS(html_content, features="lxml")
+     classes = set()
+     ids = set()
+     for elem in soup.find_all(elem_class):
+         if elem.get('class'):
+             for c in elem.get('class'):
+                 classes |= {f".{c}"}
+         if elem.get('id'):
+             elem_id = elem.get('id')
+             ids |= {f"#{elem_id}"}
+
+     return ids | classes
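text.py turns an article URL into a nested tree of markdown sections, splitting the page recursively at headings <h1> through <h6>. A minimal usage sketch follows; the URL and the ".post-content" class are placeholder assumptions, and get_selectors_for_class can be used to discover which selectors a given page actually offers.

# Hypothetical usage of text.py; the URL and selector are placeholders.
from text import doctree_from_url, get_selectors_for_class

url = "https://example.com/blog/some-article"

# List the .class and #id selectors present on <div> elements of the page,
# then pick the one that wraps the article body.
print(get_selectors_for_class(url, "div"))

# Build the section tree: the article body is converted to markdown, relative
# links are made absolute, and the HTML is split recursively at headings.
doc = doctree_from_url(url, elem_class="div", class_name=".post-content")

# Access nested sections by index and merge them back into a markdown string.
first_branch = doc.get_sections(0)
print(doc.merge_sections(first_branch)[:500])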