Spaces:
Sleeping
Sleeping
lukestanley
commited on
Commit
•
759e510
1
Parent(s):
1e622b4
WIP spicy Jigsaw - Wikipedia talk page dataset review and scoring
Browse files
learn.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
import pandas as pd
|
3 |
+
from datasets import load_dataset
|
4 |
+
from detoxify import Detoxify
|
5 |
+
predict_model = Detoxify('original-small')
|
6 |
+
dataset = load_dataset("tasksource/jigsaw")
|
7 |
+
|
8 |
+
train_data = dataset['train']
|
9 |
+
print('length',len(train_data)) # length 159571
|
10 |
+
print(train_data[0]) # {'id': '0000997932d777bf', 'comment_text': "Explanation\nWhy the edits made under my username Hardcore Metallica Fan were reverted? They weren't vandalisms, just closure on some GAs after I voted at New York Dolls FAC. And please don't remove the template from the talk page since I'm retired now.89.205.38.27", 'toxic': 0, 'severe_toxic': 0, 'obscene': 0, 'threat': 0, 'insult': 0, 'identity_hate': 0}
|
11 |
+
|
12 |
+
small_subset = train_data[:2000]
|
13 |
+
|
14 |
+
predict_model.predict("You suck, that is not Markdown!") # Also accepts an array of strings, returning an single dict of arrays of predictions.
|
15 |
+
# Returns:
|
16 |
+
{'toxicity': 0.98870254,
|
17 |
+
'severe_toxicity': 0.087154716,
|
18 |
+
'obscene': 0.93440753,
|
19 |
+
'threat': 0.0032278204,
|
20 |
+
'insult': 0.7787105,
|
21 |
+
'identity_attack': 0.007936229}
|
22 |
+
|
23 |
+
# %%
|
24 |
+
import asyncio
|
25 |
+
import json
|
26 |
+
import time
|
27 |
+
import os
|
28 |
+
import hashlib
|
29 |
+
from functools import wraps
|
30 |
+
|
31 |
+
|
32 |
+
_in_memory_cache = {}
|
33 |
+
|
34 |
+
def handle_cache(prefix, func, *args, _result=None, **kwargs):
|
35 |
+
# Generate a key based on function name and arguments
|
36 |
+
key = f"{func.__name__}_{args}_{kwargs}"
|
37 |
+
hashed_key = hashlib.sha1(key.encode()).hexdigest()
|
38 |
+
cache_filename = f"{prefix}_{hashed_key}.json"
|
39 |
+
|
40 |
+
# Check the in-memory cache first
|
41 |
+
if key in _in_memory_cache:
|
42 |
+
return _in_memory_cache[key]
|
43 |
+
|
44 |
+
# Check if cache file exists and read data
|
45 |
+
if os.path.exists(cache_filename):
|
46 |
+
with open(cache_filename, 'r') as file:
|
47 |
+
#print("Reading from cache file with prefix", prefix)
|
48 |
+
_in_memory_cache[key] = json.load(file)
|
49 |
+
return _in_memory_cache[key]
|
50 |
+
|
51 |
+
# If result is not provided (for sync functions), compute it
|
52 |
+
if _result is None:
|
53 |
+
_result = func(*args, **kwargs)
|
54 |
+
|
55 |
+
# Update the in-memory cache and write it to the file
|
56 |
+
_in_memory_cache[key] = _result
|
57 |
+
with open(cache_filename, 'w') as file:
|
58 |
+
json.dump(_result, file)
|
59 |
+
|
60 |
+
return _result
|
61 |
+
|
62 |
+
|
63 |
+
def acache(prefix):
|
64 |
+
def decorator(func):
|
65 |
+
@wraps(func)
|
66 |
+
async def wrapper(*args, **kwargs):
|
67 |
+
# Generate a key based on function name and arguments
|
68 |
+
key = f"{func.__name__}_{args}_{kwargs}"
|
69 |
+
hashed_key = hashlib.sha1(key.encode()).hexdigest()
|
70 |
+
cache_filename = f"{prefix}_{hashed_key}.json"
|
71 |
+
|
72 |
+
# Check the in-memory cache first
|
73 |
+
if key in _in_memory_cache:
|
74 |
+
return _in_memory_cache[key]
|
75 |
+
|
76 |
+
# Check if cache file exists and read data
|
77 |
+
if os.path.exists(cache_filename):
|
78 |
+
with open(cache_filename, 'r') as file:
|
79 |
+
_in_memory_cache[key] = json.load(file)
|
80 |
+
return _in_memory_cache[key]
|
81 |
+
|
82 |
+
# Await the function call and get the result
|
83 |
+
print("Computing result for async function")
|
84 |
+
result = await func(*args, **kwargs)
|
85 |
+
|
86 |
+
# Update the in-memory cache and write it to the file
|
87 |
+
_in_memory_cache[key] = result
|
88 |
+
with open(cache_filename, 'w') as file:
|
89 |
+
json.dump(result, file)
|
90 |
+
|
91 |
+
return result
|
92 |
+
|
93 |
+
return wrapper
|
94 |
+
return decorator
|
95 |
+
|
96 |
+
|
97 |
+
def cache(prefix):
|
98 |
+
def decorator(func):
|
99 |
+
@wraps(func)
|
100 |
+
def wrapper(*args, **kwargs):
|
101 |
+
# Direct call to the shared cache handling function
|
102 |
+
return handle_cache(prefix, func, *args, **kwargs)
|
103 |
+
return wrapper
|
104 |
+
return decorator
|
105 |
+
|
106 |
+
def timeit(func):
|
107 |
+
@wraps(func)
|
108 |
+
async def async_wrapper(*args, **kwargs):
|
109 |
+
start_time = time.time()
|
110 |
+
result = await func(*args, **kwargs) # Awaiting the async function
|
111 |
+
end_time = time.time()
|
112 |
+
print(f"{func.__name__} took {end_time - start_time:.1f} seconds to run.")
|
113 |
+
return result
|
114 |
+
|
115 |
+
@wraps(func)
|
116 |
+
def sync_wrapper(*args, **kwargs):
|
117 |
+
start_time = time.time()
|
118 |
+
result = func(*args, **kwargs) # Calling the sync function
|
119 |
+
end_time = time.time()
|
120 |
+
print(f"{func.__name__} took {end_time - start_time:.1f} seconds to run.")
|
121 |
+
return result
|
122 |
+
|
123 |
+
if asyncio.iscoroutinefunction(func):
|
124 |
+
return async_wrapper
|
125 |
+
else:
|
126 |
+
return sync_wrapper
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
# %%
|
131 |
+
|
132 |
+
@cache("toxicity")
|
133 |
+
def cached_toxicity_prediction(comments):
|
134 |
+
data = predict_model.predict(comments)
|
135 |
+
return data
|
136 |
+
|
137 |
+
def predict_toxicity(comments, batch_size=4):
|
138 |
+
"""
|
139 |
+
Predicts toxicity scores for a list of comments.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
- comments: List of comment texts.
|
143 |
+
- batch_size: Size of batches for prediction to manage memory usage.
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
A DataFrame with the original comments and their predicted toxicity scores.
|
147 |
+
"""
|
148 |
+
results = {'comment_text': [], 'toxicity': [], 'severe_toxicity': [], 'obscene': [], 'threat': [], 'insult': [], 'identity_attack': []}
|
149 |
+
for i in range(0, len(comments), batch_size):
|
150 |
+
batch_comments = comments[i:i+batch_size]
|
151 |
+
predictions = cached_toxicity_prediction(batch_comments)
|
152 |
+
# We convert the JSON serializable data back to a DataFrame:
|
153 |
+
results['comment_text'].extend(batch_comments)
|
154 |
+
for key in predictions.keys():
|
155 |
+
results[key].extend(predictions[key])
|
156 |
+
return pd.DataFrame(results)
|
157 |
+
|
158 |
+
# Predict toxicity scores for the small subset of comments:
|
159 |
+
#small_subset_predictions = predict_toxicity(small_subset['comment_text'][4])
|
160 |
+
# Let's just try out 4 comments with cached_toxicity_prediction:
|
161 |
+
small_subset['comment_text'][0:1]
|
162 |
+
|
163 |
+
# %%
|
164 |
+
small_subset_predictions=predict_toxicity(small_subset['comment_text'][0:200])
|
165 |
+
|
166 |
+
# %%
|
167 |
+
small_subset_predictions
|
168 |
+
|
169 |
+
# %%
|
170 |
+
def filter_comments(dataframe, toxicity_threshold=0.2, severe_toxicity_threshold=0.4):
|
171 |
+
"""
|
172 |
+
Filters comments based on specified thresholds for toxicity, severe toxicity.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
- dataframe: DataFrame containing comments and their toxicity scores.
|
176 |
+
- toxicity_threshold: Toxicity score threshold.
|
177 |
+
- severe_toxicity_threshold: Severe toxicity score threshold.
|
178 |
+
- identity_attack_threshold: Identity attack score threshold.
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
DataFrame filtered based on the specified thresholds.
|
182 |
+
"""
|
183 |
+
identity_attack_threshold = 0.5
|
184 |
+
insult_threshold = 0.3
|
185 |
+
obscene_threshold = 0.6
|
186 |
+
threat_threshold = 0.3
|
187 |
+
filtered_df = dataframe[
|
188 |
+
(dataframe['toxicity'] >= toxicity_threshold) &
|
189 |
+
#(dataframe['toxicity'] < 1.0) & # Ensure comments are spicy but not 100% toxic
|
190 |
+
(dataframe['severe_toxicity'] < severe_toxicity_threshold) &
|
191 |
+
(dataframe['identity_attack'] < identity_attack_threshold) &
|
192 |
+
(dataframe['insult'] < insult_threshold) &
|
193 |
+
(dataframe['obscene'] < obscene_threshold) &
|
194 |
+
(dataframe['threat'] < threat_threshold)
|
195 |
+
|
196 |
+
]
|
197 |
+
return filtered_df
|
198 |
+
|
199 |
+
spicy_comments = filter_comments(small_subset_predictions)
|
200 |
+
|
201 |
+
|
202 |
+
# Lets sort spicy comments by combined toxicity score:
|
203 |
+
spicy_comments.sort_values(by=['toxicity', 'severe_toxicity'], ascending=True, inplace=True)
|
204 |
+
|
205 |
+
# Print the spicy comments comment_text and their toxicity scores as a formatted string:
|
206 |
+
for index, row in spicy_comments.iterrows():
|
207 |
+
print(f"Comment: `{row['comment_text']}` \n Toxiciy: {(row['toxicity'] + row['severe_toxicity']) / 2 * 100:.0f}% \n")
|