lukestanley commited on
Commit
759e510
1 Parent(s): 1e622b4

WIP spicy Jigsaw - Wikipedia talk page dataset review and scoring

Browse files
Files changed (1) hide show
  1. learn.py +207 -0
learn.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import pandas as pd
3
+ from datasets import load_dataset
4
+ from detoxify import Detoxify
5
+ predict_model = Detoxify('original-small')
6
+ dataset = load_dataset("tasksource/jigsaw")
7
+
8
+ train_data = dataset['train']
9
+ print('length',len(train_data)) # length 159571
10
+ print(train_data[0]) # {'id': '0000997932d777bf', 'comment_text': "Explanation\nWhy the edits made under my username Hardcore Metallica Fan were reverted? They weren't vandalisms, just closure on some GAs after I voted at New York Dolls FAC. And please don't remove the template from the talk page since I'm retired now.89.205.38.27", 'toxic': 0, 'severe_toxic': 0, 'obscene': 0, 'threat': 0, 'insult': 0, 'identity_hate': 0}
11
+
12
+ small_subset = train_data[:2000]
13
+
14
+ predict_model.predict("You suck, that is not Markdown!") # Also accepts an array of strings, returning an single dict of arrays of predictions.
15
+ # Returns:
16
+ {'toxicity': 0.98870254,
17
+ 'severe_toxicity': 0.087154716,
18
+ 'obscene': 0.93440753,
19
+ 'threat': 0.0032278204,
20
+ 'insult': 0.7787105,
21
+ 'identity_attack': 0.007936229}
22
+
23
+ # %%
24
+ import asyncio
25
+ import json
26
+ import time
27
+ import os
28
+ import hashlib
29
+ from functools import wraps
30
+
31
+
32
+ _in_memory_cache = {}
33
+
34
+ def handle_cache(prefix, func, *args, _result=None, **kwargs):
35
+ # Generate a key based on function name and arguments
36
+ key = f"{func.__name__}_{args}_{kwargs}"
37
+ hashed_key = hashlib.sha1(key.encode()).hexdigest()
38
+ cache_filename = f"{prefix}_{hashed_key}.json"
39
+
40
+ # Check the in-memory cache first
41
+ if key in _in_memory_cache:
42
+ return _in_memory_cache[key]
43
+
44
+ # Check if cache file exists and read data
45
+ if os.path.exists(cache_filename):
46
+ with open(cache_filename, 'r') as file:
47
+ #print("Reading from cache file with prefix", prefix)
48
+ _in_memory_cache[key] = json.load(file)
49
+ return _in_memory_cache[key]
50
+
51
+ # If result is not provided (for sync functions), compute it
52
+ if _result is None:
53
+ _result = func(*args, **kwargs)
54
+
55
+ # Update the in-memory cache and write it to the file
56
+ _in_memory_cache[key] = _result
57
+ with open(cache_filename, 'w') as file:
58
+ json.dump(_result, file)
59
+
60
+ return _result
61
+
62
+
63
+ def acache(prefix):
64
+ def decorator(func):
65
+ @wraps(func)
66
+ async def wrapper(*args, **kwargs):
67
+ # Generate a key based on function name and arguments
68
+ key = f"{func.__name__}_{args}_{kwargs}"
69
+ hashed_key = hashlib.sha1(key.encode()).hexdigest()
70
+ cache_filename = f"{prefix}_{hashed_key}.json"
71
+
72
+ # Check the in-memory cache first
73
+ if key in _in_memory_cache:
74
+ return _in_memory_cache[key]
75
+
76
+ # Check if cache file exists and read data
77
+ if os.path.exists(cache_filename):
78
+ with open(cache_filename, 'r') as file:
79
+ _in_memory_cache[key] = json.load(file)
80
+ return _in_memory_cache[key]
81
+
82
+ # Await the function call and get the result
83
+ print("Computing result for async function")
84
+ result = await func(*args, **kwargs)
85
+
86
+ # Update the in-memory cache and write it to the file
87
+ _in_memory_cache[key] = result
88
+ with open(cache_filename, 'w') as file:
89
+ json.dump(result, file)
90
+
91
+ return result
92
+
93
+ return wrapper
94
+ return decorator
95
+
96
+
97
+ def cache(prefix):
98
+ def decorator(func):
99
+ @wraps(func)
100
+ def wrapper(*args, **kwargs):
101
+ # Direct call to the shared cache handling function
102
+ return handle_cache(prefix, func, *args, **kwargs)
103
+ return wrapper
104
+ return decorator
105
+
106
+ def timeit(func):
107
+ @wraps(func)
108
+ async def async_wrapper(*args, **kwargs):
109
+ start_time = time.time()
110
+ result = await func(*args, **kwargs) # Awaiting the async function
111
+ end_time = time.time()
112
+ print(f"{func.__name__} took {end_time - start_time:.1f} seconds to run.")
113
+ return result
114
+
115
+ @wraps(func)
116
+ def sync_wrapper(*args, **kwargs):
117
+ start_time = time.time()
118
+ result = func(*args, **kwargs) # Calling the sync function
119
+ end_time = time.time()
120
+ print(f"{func.__name__} took {end_time - start_time:.1f} seconds to run.")
121
+ return result
122
+
123
+ if asyncio.iscoroutinefunction(func):
124
+ return async_wrapper
125
+ else:
126
+ return sync_wrapper
127
+
128
+
129
+
130
+ # %%
131
+
132
+ @cache("toxicity")
133
+ def cached_toxicity_prediction(comments):
134
+ data = predict_model.predict(comments)
135
+ return data
136
+
137
+ def predict_toxicity(comments, batch_size=4):
138
+ """
139
+ Predicts toxicity scores for a list of comments.
140
+
141
+ Args:
142
+ - comments: List of comment texts.
143
+ - batch_size: Size of batches for prediction to manage memory usage.
144
+
145
+ Returns:
146
+ A DataFrame with the original comments and their predicted toxicity scores.
147
+ """
148
+ results = {'comment_text': [], 'toxicity': [], 'severe_toxicity': [], 'obscene': [], 'threat': [], 'insult': [], 'identity_attack': []}
149
+ for i in range(0, len(comments), batch_size):
150
+ batch_comments = comments[i:i+batch_size]
151
+ predictions = cached_toxicity_prediction(batch_comments)
152
+ # We convert the JSON serializable data back to a DataFrame:
153
+ results['comment_text'].extend(batch_comments)
154
+ for key in predictions.keys():
155
+ results[key].extend(predictions[key])
156
+ return pd.DataFrame(results)
157
+
158
+ # Predict toxicity scores for the small subset of comments:
159
+ #small_subset_predictions = predict_toxicity(small_subset['comment_text'][4])
160
+ # Let's just try out 4 comments with cached_toxicity_prediction:
161
+ small_subset['comment_text'][0:1]
162
+
163
+ # %%
164
+ small_subset_predictions=predict_toxicity(small_subset['comment_text'][0:200])
165
+
166
+ # %%
167
+ small_subset_predictions
168
+
169
+ # %%
170
+ def filter_comments(dataframe, toxicity_threshold=0.2, severe_toxicity_threshold=0.4):
171
+ """
172
+ Filters comments based on specified thresholds for toxicity, severe toxicity.
173
+
174
+ Args:
175
+ - dataframe: DataFrame containing comments and their toxicity scores.
176
+ - toxicity_threshold: Toxicity score threshold.
177
+ - severe_toxicity_threshold: Severe toxicity score threshold.
178
+ - identity_attack_threshold: Identity attack score threshold.
179
+
180
+ Returns:
181
+ DataFrame filtered based on the specified thresholds.
182
+ """
183
+ identity_attack_threshold = 0.5
184
+ insult_threshold = 0.3
185
+ obscene_threshold = 0.6
186
+ threat_threshold = 0.3
187
+ filtered_df = dataframe[
188
+ (dataframe['toxicity'] >= toxicity_threshold) &
189
+ #(dataframe['toxicity'] < 1.0) & # Ensure comments are spicy but not 100% toxic
190
+ (dataframe['severe_toxicity'] < severe_toxicity_threshold) &
191
+ (dataframe['identity_attack'] < identity_attack_threshold) &
192
+ (dataframe['insult'] < insult_threshold) &
193
+ (dataframe['obscene'] < obscene_threshold) &
194
+ (dataframe['threat'] < threat_threshold)
195
+
196
+ ]
197
+ return filtered_df
198
+
199
+ spicy_comments = filter_comments(small_subset_predictions)
200
+
201
+
202
+ # Lets sort spicy comments by combined toxicity score:
203
+ spicy_comments.sort_values(by=['toxicity', 'severe_toxicity'], ascending=True, inplace=True)
204
+
205
+ # Print the spicy comments comment_text and their toxicity scores as a formatted string:
206
+ for index, row in spicy_comments.iterrows():
207
+ print(f"Comment: `{row['comment_text']}` \n Toxiciy: {(row['toxicity'] + row['severe_toxicity']) / 2 * 100:.0f}% \n")