|
import random |
|
|
|
import streamlit as st |
|
from bs4 import BeautifulSoup |
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from transformers import pipeline |
|
from transformers_interpret import SequenceClassificationExplainer |
|
|
|
|
|
# Hugging Face model ids offered in the model selector, mapped to their
# model-card URL on the Hub (the URL is the Hub root followed by the model id).
model_names_to_URLs = {
    name: f'https://huggingface.co/{name}'
    for name in (
        'ml6team/distilbert-base-dutch-cased-toxic-comments',
        'ml6team/robbert-dutch-base-toxic-comments',
    )
}

# Markdown shown under "About" in the app menu (see st.set_page_config).
# Plain string: it has no placeholders, so no f-prefix is needed.
# NOTE(review): the heading emoji (and the emoji literals below) look like
# mojibake -- UTF-8 emoji bytes displayed in a Greek single-byte encoding
# (e.g. 'π€¬' matches the UTF-8 bytes of U+1F92C). Restore the intended
# emoji -- TODO confirm against the original source.
about_page_markdown = """# π€¬ Dutch Toxic Comment Detection Space

Made by [ML6](https://ml6.eu/).

Token attribution is performed using [transformers-interpret](https://github.com/cdpierse/transformers-interpret).
"""

# Emoji pools that classify_comment samples from at random, depending on
# the predicted label.
regular_emojis = [
    'π', 'π', 'πΆ', 'π',
]
undecided_emojis = [
    'π€¨', 'π§', 'π₯Έ', 'π₯΄', 'π€·',
]
potty_mouth_emojis = [
    'π€', 'πΏ', 'π‘', 'π€¬', 'β οΈ', 'β£οΈ', 'β’οΈ',
]
|
|
|
|
|
# Configure the browser tab (title and icon), the page layout and the app
# menu: the "About" entry shows about_page_markdown, the help and bug-report
# entries are set to None.
# NOTE(review): the page_icon literal looks like mis-encoded emoji bytes
# (mojibake) -- TODO confirm the intended emoji.
st.set_page_config(
    layout="centered",
    initial_sidebar_state="auto",
    page_title="Toxic Comment Detection Space",
    page_icon="π€¬",
    menu_items={
        'About': about_page_markdown,
        'Get help': None,
        'Report a bug': None,
    },
)
|
|
|
|
|
@st.cache(allow_output_mutation=True,
          suppress_st_warning=True,
          show_spinner=False)
def load_pipeline(model_name):
    """Build the classification pipeline and its attribution explainer.

    Wrapped in ``st.cache`` so that repeated calls with the same
    ``model_name`` can reuse the loaded objects instead of rebuilding them.

    Args:
        model_name: Hugging Face model id, used for both the model and the
            tokenizer.

    Returns:
        A ``(pipeline, explainer)`` tuple: a ``text-classification`` pipeline
        and a ``SequenceClassificationExplainer`` built from that pipeline's
        model and tokenizer.
    """
    # The cache's own spinner is disabled (show_spinner=False); this one
    # shows a more descriptive message while the model loads.
    with st.spinner('Loading model (this might take a while)...'):
        classifier = pipeline(
            task='text-classification',
            model=model_name,
            tokenizer=model_name,
        )
        explainer = SequenceClassificationExplainer(
            classifier.model,
            classifier.tokenizer,
        )
    return classifier, explainer
|
|
|
|
|
|
|
def format_explainer_html(html_string):
    """Extract tokens with attribution-based background color.

    Takes the ``<mark>`` elements of the last ``<td>`` in the explainer's
    visualisation HTML, skips the first and last one, and regroups WordPiece
    sub-tokens (prefixed with ``##``) into one ``<span>`` per word.

    Args:
        html_string: HTML of the explainer visualisation.

    Returns:
        A BeautifulSoup ``<p>`` tag holding one ``<span>`` per word,
        separated by single spaces. The paragraph is empty when there are no
        tokens between the first and last ``<mark>``. Previously that case
        raised, because ``None`` was appended to the paragraph.
    """
    inside_token_prefix = '##'
    soup = BeautifulSoup(html_string, 'html.parser')
    p = soup.new_tag('p',
                     attrs={'style': 'color: black; background-color: white;'})

    # One <span> per word; each holds that word's token <mark> elements.
    words = []
    # NOTE(review): presumably the last <td> holds the per-token highlights
    # and the first/last <mark> are the model's special tokens -- confirm.
    for token in soup.find_all('td')[-1].find_all('mark')[1:-1]:
        text = token.font.text.strip()
        is_continuation = text.startswith(inside_token_prefix)
        if is_continuation:
            text = text[len(inside_token_prefix):]
        if not (is_continuation and words):
            # Start a new word. A continuation token with no preceding word
            # also starts one, instead of failing on a missing span.
            words.append(soup.new_tag('span'))
        # Replace the <font> child with the bare token text; the enclosing
        # <p> sets the text colour.
        token.string = text
        token['style'] = f"{token['style']}; padding: 0.2em 0em;"
        words[-1].append(token)

    for position, word in enumerate(words):
        if position:
            p.append(' ')
        p.append(word)
        # Extra horizontal padding on the outer edges of each word.
        marks = word.find_all('mark')
        marks[0]['style'] = f"{marks[0]['style']}; padding-left: 0.2em;"
        marks[-1]['style'] = f"{marks[-1]['style']}; padding-right: 0.2em;"

    return p
|
|
|
|
|
def classify_comment(comment, selected_model):
    """Classify the given comment and augment with additional information.

    Args:
        comment: Comment text to classify.
        selected_model: Model id passed to ``load_pipeline``.

    Side effects:
        Appends the result dict to ``st.session_state.results``. The dict
        holds the pipeline's ``label`` and ``score`` plus ``model_name``,
        ``word_attribution``, ``visualitsation_html``,
        ``tokens_with_background``, ``text`` and ``emoji``.
    """
    toxicity_pipeline, cls_explainer = load_pipeline(selected_model)
    # The pipeline returns a list with one {'label', 'score'} dict per input.
    result = toxicity_pipeline(comment)[0]
    result['model_name'] = selected_model

    # Attributions are computed with respect to the non-toxic class. The
    # label check below accepts both 'non-toxic' and 'non_toxic', so pick
    # whichever spelling this model's config defines instead of hard-coding
    # one. Fall back to the previous hard-coded name if neither exists.
    label2id = toxicity_pipeline.model.config.label2id
    non_toxic_class = next(
        (name for name in ('non-toxic', 'non_toxic') if name in label2id),
        'non-toxic')
    result['word_attribution'] = cls_explainer(comment,
                                               class_name=non_toxic_class)
    # Key spelling kept unchanged: it is part of the stored result dict.
    result['visualitsation_html'] = cls_explainer.visualize()._repr_html_()
    result['tokens_with_background'] = format_explainer_html(
        result['visualitsation_html'])

    # NOTE(review): if both models are binary softmax classifiers, the top
    # score is always >= 0.5, so the 0.1 threshold presumably never selects
    # the undecided emojis -- confirm the intended cut-off.
    label, score = result['label'], result['score']
    if label == 'toxic' and score > 0.1:
        emoji = random.choice(potty_mouth_emojis)
    elif label in ['non_toxic', 'non-toxic'] and score > 0.1:
        emoji = random.choice(regular_emojis)
    else:
        emoji = random.choice(undecided_emojis)
    result.update({'text': comment, 'emoji': emoji})

    st.session_state.results.append(result)
|
|
|
|
|
|
|
# Keep the list of classification results in session state so that earlier
# results are still available on later reruns of the script.
if 'results' not in st.session_state:
    st.session_state['results'] = []
|
|
|
|
|
# Page heading and introduction. The last text embeds raw HTML (coloured
# example tokens), hence unsafe_allow_html=True.
# NOTE(review): the emoji in the title, before "Model Hub" and after
# "Try it yourself!" look like mis-encoded emoji bytes (mojibake) -- TODO
# confirm the intended characters.
st.title('π€¬ Dutch Toxic Comment Detection')
st.markdown("""This demo showcases two Dutch toxic comment detection models.""")

# Plain string: no placeholders, so the previous f-prefix was unnecessary.
st.markdown("""Both models were trained using a sequence classification task on a translated [Jigsaw Toxicity dataset](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge) which contains toxic online comments.
The first model is a fine-tuned multilingual [DistilBERT](https://huggingface.co/distilbert-base-multilingual-cased) model whereas the second is a fine-tuned Dutch RoBERTa-based model called [RobBERT](https://huggingface.co/pdelobelle/robbert-v2-dutch-base).""")
st.markdown(f"""For a more comprehensive overview of the models check out their model card on π€ Model Hub: [distilbert-base-dutch-toxic-comments]({model_names_to_URLs['ml6team/distilbert-base-dutch-cased-toxic-comments']}) and [RobBERT-dutch-base-toxic-comments]({model_names_to_URLs['ml6team/robbert-dutch-base-toxic-comments']}).
""")
st.markdown("""Enter a comment that you want to classify below. The model will determine the probability that it is toxic and highlights how much each token contributes to its decision:
<font color="black">
<span style="background-color: rgb(250, 219, 219); opacity: 1;">r</span><span style="background-color: rgb(244, 179, 179); opacity: 1;">e</span><span style="background-color: rgb(238, 135, 135); opacity: 1;">d</span>
</font>
tokens indicate toxicity whereas
<font color="black">
<span style="background-color: rgb(224, 251, 224); opacity: 1;">g</span><span style="background-color: rgb(197, 247, 197); opacity: 1;">re</span><span style="background-color: rgb(121, 236, 121); opacity: 1;">en</span>
</font> tokens indicate the opposite.

Try it yourself! π""",
            unsafe_allow_html=True)
|
|
|
|
|
|
|
# Input form: the model selector, the comment box and a submit button placed
# in the narrow right-hand column of a 6:1 column split.
with st.form(key="dutch-toxic-comment-detection-input", clear_on_submit=False):
    selected_model = st.selectbox(label='Select a model:',
                                  options=model_names_to_URLs.keys())
    text = st.text_area(
        label='Enter the comment you want to classify below (in Dutch):')
    _, rightmost_col = st.columns(spec=[6, 1])
    submitted = rightmost_col.form_submit_button(label="Classify",
                                                 help="Classify comment")
|
|
|
|
|
# On submit, classify the entered comment; an empty comment gets an error
# message instead.
if submitted and not text:
    st.error('**Error**: No comment to classify. Please provide a comment.')
elif submitted:
    with st.spinner('Analysing comment...'):
        classify_comment(text, selected_model)
|
|
|
|
|
# Render all stored results, newest first, separated by horizontal rules.
if 'results' in st.session_state and st.session_state.results:
    for index, result in enumerate(reversed(st.session_state.results)):
        if index:
            st.markdown("---")
        # Prefix every line of the comment: previously only the first line
        # was quoted, so a blank line in the comment ended the Markdown
        # blockquote and the rest rendered outside it.
        quoted_text = "\n".join(
            f"> {line}" for line in result['text'].splitlines())
        st.markdown(f"Text:\n{quoted_text}")
        col_1, col_2, col_3 = st.columns([1, 2, 2])
        col_1.metric(label='', value=f"{result['emoji']}")
        col_2.metric(label='Label', value=f"{result['label']}")
        col_3.metric(label='Score', value=f"{result['score']:.3f}")
        # tokens_with_background is the tag returned by
        # format_explainer_html; its string form is inserted as raw HTML.
        st.markdown(f"Token Attribution:\n{result['tokens_with_background']}",
                    unsafe_allow_html=True)
        st.caption(f"Model: {result['model_name']}")
|
|