Spaces:
Runtime error
Runtime error
pages
Browse files- main.py +26 -1
- media/Toxic_labeled.csv +0 -0
- media/oritoxic.jpg +0 -0
- media/page1.py +0 -1
- media/page2.py +0 -1
- pages/page1.py +0 -1
- pages/page2.py +0 -1
- pages/task1.py +109 -0
- pages/task3.py +89 -0
- pages/toxicapp.py +156 -0
- requirements.txt +109 -0
- srcs/model_modify.pth +3 -0
main.py
CHANGED
@@ -1 +1,26 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
st.markdown("""
|
5 |
+
<style>
|
6 |
+
section[data-testid="stSidebar"][aria-expanded="true"]{
|
7 |
+
display: none;
|
8 |
+
}
|
9 |
+
</style>
|
10 |
+
""", unsafe_allow_html=True)
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
st.title('π & β‘π¨ππͺ«π‘')
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
st.write('choose your option')
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
st.page_link("pages/task1.py", label="Kinopoisk", icon='π₯')
|
23 |
+
st.page_link("pages/toxicapp.py", label="Personality", icon='β οΈ')
|
24 |
+
st.page_link("pages/task3.py", label="GPT", icon='π²')
|
25 |
+
|
26 |
+
st.header(f'''made by: Alexey Kamaev & Marina Kochetova''')
|
media/Toxic_labeled.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
media/oritoxic.jpg
ADDED
media/page1.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
print()
|
|
|
|
media/page2.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
print()
|
|
|
|
pages/page1.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
print()
|
|
|
|
pages/page2.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
print()
|
|
|
|
pages/task1.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import joblib
|
4 |
+
|
5 |
+
from transformers import AutoTokenizer, AutoModel
|
6 |
+
import torch
|
7 |
+
import pickle
|
8 |
+
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
with open('srcs/vocab_to_int.json', encoding='utf-8') as f:
|
12 |
+
vocab_to_int = json.load(f)
|
13 |
+
|
14 |
+
with open('srcs/int_to_vocab.json', encoding='utf-8') as f:
|
15 |
+
int_to_vocab = json.load(f)
|
16 |
+
|
17 |
+
VOCAB_SIZE = len(vocab_to_int) + 1
|
18 |
+
EMBEDDING_DIM = 64 # embedding_dim
|
19 |
+
SEQ_LEN = 350
|
20 |
+
HIDDEN_SIZE = 64
|
21 |
+
|
22 |
+
with open('srcs/embedding_matrix.npy', 'rb') as f:
|
23 |
+
embedding_matrix = np.load(f)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
from srcs.srcs import LSTMConcatAttentionB
|
28 |
+
from srcs.srcs import Text_ex, clean
|
29 |
+
|
30 |
+
|
31 |
+
log_reg_vec = joblib.load('srcs/log_reg_vec.sav')
|
32 |
+
|
33 |
+
log_reg_bert = joblib.load('srcs/log_reg_bert.sav')
|
34 |
+
|
35 |
+
texter = Text_ex(clean, vocab_to_int, SEQ_LEN)
|
36 |
+
lstm = LSTMConcatAttentionB()
|
37 |
+
|
38 |
+
lstm.load_state_dict(torch.load('srcs/lstm.pt'))
|
39 |
+
|
40 |
+
|
41 |
+
vectorizer = pickle.load(open("srcs/vectorizer.pickle", "rb"))
|
42 |
+
tokenizer = AutoTokenizer.from_pretrained('srcs/tokenzier', local_files_only=True)
|
43 |
+
bert = torch.load('srcs/bert.pt')
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
from srcs.srcs import PredMaker
|
48 |
+
predM = PredMaker(model1=log_reg_vec, model2=lstm, rubert=bert, model3=log_reg_bert, vectorizer=vectorizer, texter=texter, clean_func=clean, tokenizer=tokenizer, itc=int_to_vocab)
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
import streamlit as st
|
54 |
+
st.markdown("""
|
55 |
+
<style>
|
56 |
+
section[data-testid="stSidebar"][aria-expanded="true"]{
|
57 |
+
display: none;
|
58 |
+
}
|
59 |
+
</style>
|
60 |
+
""", unsafe_allow_html=True)
|
61 |
+
|
62 |
+
|
63 |
+
st.title('ΠΠΈΠ½ΠΎΠΏΠΎΠΈΡΠΊ')
|
64 |
+
st.page_link("main.py", label="Home", icon='π ')
|
65 |
+
|
66 |
+
import streamlit as st
|
67 |
+
|
68 |
+
txt = st.text_area(
|
69 |
+
"ΠΠ²Π΅Π΄ΠΈΡΠ΅ ΡΡΠ΄Π° ΠΎΡΠ·ΡΠ² Π½Π° ΡΠΈΠ»ΡΠΌ:",
|
70 |
+
"",
|
71 |
+
)
|
72 |
+
|
73 |
+
# st.write(f'ΠΠ²Π΅Π΄Π΅Π½ΠΎ {len(txt)} ΡΠΈΠΌΠ²ΠΎΠ»ΠΎΠ².')
|
74 |
+
|
75 |
+
if txt == '' or len(txt) < 12:
|
76 |
+
if len(txt) >= 1:
|
77 |
+
st.write('ΠΠ²Π΅Π΄ΠΈ ΡΡΠΎ-Π½ΠΈΠ±ΡΠ΄Ρ Π½ΠΎΡΠΌΠ°Π»ΡΠ½ΠΎΠ΅')
|
78 |
+
else:
|
79 |
+
text = txt
|
80 |
+
res1, res2, res3, t, att, *times = predM(text)
|
81 |
+
t_ = t[0].numpy()[0]
|
82 |
+
k = len(t[1].split()) + 1
|
83 |
+
|
84 |
+
labels = [int_to_vocab[str(x)] for x in t_ if int_to_vocab.get(str(x))]
|
85 |
+
|
86 |
+
if list(set(labels[-k:])) == ["<pad>"]:
|
87 |
+
st.write('ΠΠ°Π²Π°ΠΉ ΠΏΠΎ Π½ΠΎΠ²ΠΎΠΉ ΠΌΠΈΡΠ°, Π²ΡΡ @**##')
|
88 |
+
st.write(set(labels[-k:]))
|
89 |
+
else:
|
90 |
+
st.toast('!', icon='π')
|
91 |
+
di = {0:'ΠΠ»ΠΎΡ
ΠΎ',1:'ΠΠ΅ΠΉΡΡΠ°Π»ΡΠ½ΠΎ',2:'Π₯ΠΎΡΠΎΡΠΎ'}
|
92 |
+
d = {0: st.error, 1: st.warning, 2: st.success}
|
93 |
+
|
94 |
+
d[res1](f'ΠΡΠ΅Π΄ΡΠΊΠ°Π·Π°Π½ΠΈΠ΅ 1-ΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ: {di[res1]}')
|
95 |
+
st.write(f'Π²ΡΠ΅ΠΌΡ = {round(times[0],3)}c, f1-score = 0.64')
|
96 |
+
|
97 |
+
d[res2](f'ΠΡΠ΅Π΄ΡΠΊΠ°Π·Π°Π½ΠΈΠ΅ 2-ΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ: {di[res2]}')
|
98 |
+
st.write(f'Π²ΡΠ΅ΠΌΡ = {round(times[0],3)}c, f1-score = 0.70')
|
99 |
+
|
100 |
+
d[res3](f'ΠΡΠ΅Π΄ΡΠΊΠ°Π·Π°Π½ΠΈΠ΅ 3-ΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ: {di[res3]}')
|
101 |
+
st.write(f'Π²ΡΠ΅ΠΌΡ = {round(times[0],3)}c, f1-score = 0.66')
|
102 |
+
|
103 |
+
|
104 |
+
plt.figure(figsize=(8, 8))
|
105 |
+
plt.barh(np.arange(len(t_))[-k:], att[-k:])
|
106 |
+
plt.yticks(ticks = np.arange(len(t_))[-k:], labels = labels[-k:])
|
107 |
+
plt.title(f'f1-score = 0.7\npred = {di[res2]}\ntime = {round(times[1],3)}c');
|
108 |
+
st.set_option('deprecation.showPyplotGlobalUse', False)
|
109 |
+
st.pyplot()
|
pages/task3.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
3 |
+
import time
|
4 |
+
|
5 |
+
def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature):
|
6 |
+
generated_texts = []
|
7 |
+
|
8 |
+
for _ in range(num_generations):
|
9 |
+
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
10 |
+
output = model.generate(
|
11 |
+
input_ids,
|
12 |
+
max_length=max_length,
|
13 |
+
temperature=temperature,
|
14 |
+
num_return_sequences=1
|
15 |
+
)
|
16 |
+
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
17 |
+
generated_texts.append(generated_text)
|
18 |
+
|
19 |
+
return generated_texts
|
20 |
+
|
21 |
+
button_style = """
|
22 |
+
<style>
|
23 |
+
.center-align {
|
24 |
+
display: flex;
|
25 |
+
justify-content: center;
|
26 |
+
|
27 |
+
</style>
|
28 |
+
"""
|
29 |
+
|
30 |
+
DEVICE = 'cpu'
|
31 |
+
|
32 |
+
# ΠΠ°Π³ΡΡΠ·ΠΊΠ° ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»ΡΡΠΊΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ ΠΈ ΡΠΎΠΊΠ΅Π½ΠΈΠ·Π°ΡΠΎΡΠ° (Π·Π°ΠΌΠ΅Π½ΠΈΡΠ΅ Π½Π° ΡΠ²ΠΎΠΈ ΠΏΡΡΠΈ ΠΈ ΠΌΠΎΠ΄Π΅Π»Ρ)
|
33 |
+
# model_path = "sberbank-ai/rugpt3small_based_on_gpt2"
|
34 |
+
# tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"
|
35 |
+
|
36 |
+
# model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
|
37 |
+
# tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
|
38 |
+
|
39 |
+
st.markdown("""
|
40 |
+
<style>
|
41 |
+
section[data-testid="stSidebar"][aria-expanded="true"]{
|
42 |
+
display: none;
|
43 |
+
}
|
44 |
+
</style>
|
45 |
+
""", unsafe_allow_html=True)
|
46 |
+
|
47 |
+
st.write("## Text generator")
|
48 |
+
st.page_link("main.py", label="Home", icon='π ')
|
49 |
+
st.markdown(
|
50 |
+
"""
|
51 |
+
This streamlit-app can generate text using your prompt
|
52 |
+
"""
|
53 |
+
)
|
54 |
+
# ΠΠ²ΠΎΠ΄ ΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΠ΅Π»ΡΡΠΊΠΎΠ³ΠΎ prompt
|
55 |
+
prompt = st.text_area("Enter your prompt:")
|
56 |
+
|
57 |
+
# ΠΠ°ΡΠ°ΠΌΠ΅ΡΡΡ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠΈ
|
58 |
+
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10)
|
59 |
+
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1)
|
60 |
+
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
|
61 |
+
try:
|
62 |
+
if st.button("Generate text"):
|
63 |
+
start_time = time.time()
|
64 |
+
generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature)
|
65 |
+
end_time = time.time()
|
66 |
+
|
67 |
+
st.subheader("Π‘Π³Π΅Π½Π΅ΡΠΈΡΠΎΠ²Π°Π½Π½ΡΠΉ ΡΠ΅ΠΊΡΡ:")
|
68 |
+
for i, text in enumerate(generated_texts, start=1):
|
69 |
+
st.write(f"ΠΠ΅Π½Π΅ΡΠ°ΡΠΈΡ {i}:\n{text}")
|
70 |
+
|
71 |
+
generation_time = end_time - start_time
|
72 |
+
st.write(f"\nΠΡΠ΅ΠΌΡ Π³Π΅Π½Π΅ΡΠ°ΡΠΈΠΈ: {generation_time:.2f} ΡΠ΅ΠΊΡΠ½Π΄")
|
73 |
+
|
74 |
+
st.markdown(button_style, unsafe_allow_html=True) # ΠΡΠΈΠΌΠ΅Π½ΡΠ΅ΠΌ ΡΡΠΈΠ»Ρ ΠΊ ΠΊΠ½ΠΎΠΏΠΊΠ΅
|
75 |
+
st.markdown(
|
76 |
+
"""
|
77 |
+
<style>
|
78 |
+
div[data-baseweb="textarea"] {
|
79 |
+
border: 2px solid #3498db; /* Π¦Π²Π΅Ρ Π³ΡΠ°Π½ΠΈΡΡ */
|
80 |
+
border-radius: 5px; /* ΠΠ°ΠΊΡΡΠ³Π»Π΅Π½Π½ΡΠ΅ ΡΠ³Π»Ρ */
|
81 |
+
background-color: #ecf0f1; /* Π¦Π²Π΅Ρ ΡΠΎΠ½Π° */
|
82 |
+
padding: 10px; /* ΠΠΎΠ»Ρ Π²ΠΎΠΊΡΡΠ³ ΡΠ΅ΠΊΡΡΠΎΠ²ΠΎΠ³ΠΎ ΠΏΠΎΠ»Ρ */
|
83 |
+
}
|
84 |
+
</style>
|
85 |
+
""",
|
86 |
+
unsafe_allow_html=True,
|
87 |
+
)
|
88 |
+
except:
|
89 |
+
st.write('ΠΠΎΠ΄Π΅Π»Ρ Π² ΡΠ°Π·ΡΠ°Π±ΠΎΡΠΊΠ΅ ( οΎ οΎο½°οΎ)οΎ( οΎ οΎο½°οΎ)οΎ( οΎ οΎο½°οΎ)οΎ( οΎ οΎο½°οΎ)οΎ')
|
pages/toxicapp.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
import pandas as pd
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from PIL import Image
|
7 |
+
import os
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
model_path = 'srcs/model_modify.pth'
|
11 |
+
|
12 |
+
# ΡΠΎΠΊΠ΅Π½ΠΈΠ·Π°ΡΠΎΡ
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained('cointegrated/rubert-tiny-toxicity')
|
14 |
+
|
15 |
+
model = AutoModelForSequenceClassification.from_pretrained('cointegrated/rubert-tiny-toxicity', num_labels=1, ignore_mismatched_sizes=True)
|
16 |
+
# Π²Π΅ΡΠΎΠ² ΠΌΠΎΠ΄ΠΈΡΠΈΡΠΈΡΠΎΠ²Π°Π½Π½ΠΎΠΉ ΠΌΠΎΠ΄Π΅Π»ΠΈ
|
17 |
+
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
|
18 |
+
image = Image.open("media/oritoxic.jpg")
|
19 |
+
|
20 |
+
# df = pd.read_csv("media/Toxic_labeled.csv")
|
21 |
+
loss_values = [0.4063596375772262, 0.402279906166038, 0.3998144585561736, 0.39567733055365567,
|
22 |
+
0.3921396666608141, 0.38956182373070186, 0.3866641920902114, 0.3879134839351564,
|
23 |
+
0.38288725781591604, 0.38198364493999004]
|
24 |
+
|
25 |
+
#ΠΠΎΠΊΠΎΠ²Π°Ρ ΠΏΠ°Π½Π΅Π»Ρ
|
26 |
+
selected_option = st.sidebar.selectbox("ΠΡΠ±Π΅ΡΠΈΡΠ΅ ΠΈΠ· ΡΠΏΠΈΡΠΊΠ°", ["ΠΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΠ΅ ΡΠΎΠΊΡΠΈΡΠ½ΠΎΡΡΡ ΡΠ΅ΠΊΡΡΠ°", "ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅", "ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ ΠΌΠΎΠ΄Π΅Π»ΠΈ"])
|
27 |
+
|
28 |
+
#st.title("ΠΠ»Π°Π²Π½Π°Ρ ΡΡΡΠ°Π½ΠΈΡΠ°")
|
29 |
+
|
30 |
+
|
31 |
+
if selected_option == "ΠΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΠ΅ ΡΠΎΠΊΡΠΈΡΠ½ΠΎΡΡΡ ΡΠ΅ΠΊΡΡΠ°":
|
32 |
+
|
33 |
+
|
34 |
+
st.markdown("<h1 style='text-align: center;'>ΠΡΠΈΠ»ΠΎΠΆΠ΅Π½ΠΈΠ΅ Π΄Π»Ρ ΠΎΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΡ ΡΠΎΠΊΡΠΈΡΠ½ΠΎΡΡΠΈ ΡΠ΅ΠΊΡΡΠ°</h1>",
|
35 |
+
unsafe_allow_html=True)
|
36 |
+
st.image(image, use_column_width=True)
|
37 |
+
user_input = st.text_area("")
|
38 |
+
|
39 |
+
|
40 |
+
# Π€ΡΠ½ΠΊΡΠΈΡ ΠΏΡΠ΅Π΄ΡΠΊΠ°Π·Π°Π½ΠΈΡ ΡΠΎΠΊΡΠΈΡΠ½ΠΎΡΡΠΈ
|
41 |
+
|
42 |
+
def predict_toxicity(text):
|
43 |
+
inputs = tokenizer(text, return_tensors="pt")
|
44 |
+
outputs = model(**inputs)
|
45 |
+
logits = outputs.logits
|
46 |
+
probability = torch.sigmoid(logits).item()
|
47 |
+
prediction = "ΡΠΎΠΊΡΠΈΡΠ½ΡΠΉ" if probability >= 0.5 else "Π½Π΅ ΡΠΎΠΊΡΠΈΡΠ½ΡΠΉ"
|
48 |
+
return prediction, probability
|
49 |
+
# Π’ΡΠΊ Π½Π° ΠΊΠ½ΠΎΠΏΡ
|
50 |
+
if st.button("ΠΡΠ΅Π½ΠΈΡΡ ΡΠΎΠΊΡΠΈΡΠ½ΠΎΡΡΡ"):
|
51 |
+
if user_input:
|
52 |
+
prediction, toxicity_probability = predict_toxicity(user_input)
|
53 |
+
st.write(f'ΠΠ΅ΡΠΎΡΡΠ½ΠΎΡΡΡ ΡΠΎΠΊΡΠΈΡΠ½ΠΎΡΡΠΈ: {toxicity_probability:.4f}')
|
54 |
+
|
55 |
+
# ΠΡΠΎΠ³ΡΠ΅ΡΡ Π±Π°Ρ
|
56 |
+
if 'toxicity_probability' in locals():
|
57 |
+
progress_percentage = int(toxicity_probability * 100)
|
58 |
+
progress_bar_color = f'linear-gradient(to right, rgba(0, 0, 255, 0.5) {progress_percentage}%, rgba(255, 0, 0, 0.5) {progress_percentage}%)'
|
59 |
+
st.markdown(f'<div style="background: {progress_bar_color}; height: 20px; border-radius: 5px;"></div>',
|
60 |
+
unsafe_allow_html=True)
|
61 |
+
|
62 |
+
elif selected_option == "ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅":
|
63 |
+
st.header("ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅:")
|
64 |
+
st.dataframe(df.head())
|
65 |
+
st.write(f"ΠΠ±ΡΠ΅ΠΌ Π²ΡΠ±ΠΎΡΠΊΠΈ: 14412")
|
66 |
+
st.subheader("ΠΠ°Π»Π°Π½Ρ ΠΊΠ»Π°ΡΡΠΎΠ² Π² Π΄Π°ΡΠ°ΡΠ΅ΡΠ΅:")
|
67 |
+
st.write(f"ΠΠΎΠ»ΠΈΡΠ΅ΡΡΠ²ΠΎ Π·Π°ΠΏΠΈΡΠ΅ΠΉ Π² ΠΊΠ»Π°ΡΡΠ΅ 0.0: {len(df[df['toxic'] == 0.0])}")
|
68 |
+
st.write(f"ΠΠΎΠ»ΠΈΡΠ΅ΡΡΠ²ΠΎ Π·Π°ΠΏΠΈΡΠ΅ΠΉ Π² ΠΊΠ»Π°ΡΡΠ΅ 1.0: {len(df[df['toxic'] == 1.0])}")
|
69 |
+
fig, ax = plt.subplots()
|
70 |
+
df['toxic'].value_counts().plot(kind='bar', ax=ax, color=['skyblue', 'orange'])
|
71 |
+
ax.set_xticklabels(['ΠΠ΅ ΡΠΎΠΊΡΠΈΡΠ½ΡΠΉ', 'Π’ΠΎΠΊΡΠΈΡΠ½ΡΠΉ'], rotation=0)
|
72 |
+
ax.set_xlabel('ΠΠ»Π°ΡΡ')
|
73 |
+
ax.set_ylabel('ΠΠΎΠ»ΠΈΡΠ΅ΡΡΠ²ΠΎ Π·Π°ΠΏΠΈΡΠ΅ΠΉ')
|
74 |
+
ax.set_title('Π Π°ΡΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΠ΅ ΠΏΠΎ ΠΊΠ»Π°ΡΡΠ°ΠΌ')
|
75 |
+
st.pyplot(fig)
|
76 |
+
|
77 |
+
elif selected_option == "ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ ΠΌΠΎΠ΄Π΅Π»ΠΈ":
|
78 |
+
st.subheader("ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ ΠΌΠΎΠ΄Π΅Π»ΠΈ:")
|
79 |
+
st.write(f"ΠΠΎΠ΄Π΅Π»Ρ: Rubert tiny toxicity")
|
80 |
+
st.subheader("ΠΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΠΎ ΠΏΡΠΎΡΠ΅ΡΡΠ΅ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ")
|
81 |
+
|
82 |
+
# Π³ΡΠ°ΡΠΈΠΊ Π»ΠΎΡΡΠ°
|
83 |
+
#st.subheader("ΠΡΠ°ΡΠΈΠΊ ΠΏΠΎΡΠ΅ΡΡ Π² ΠΏΡΠΎΡΠ΅ΡΡΠ΅ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ")
|
84 |
+
#st.line_chart([0.5181976270121774, 0.4342067330899996, 0.41386983832460666]) # ΠΠ°ΠΌΠ΅Π½ΠΈΡΠ΅ Π΄Π°Π½Π½ΡΠΌΠΈ ΠΈΠ· Π²Π°ΡΠΈΡ
ΡΠΏΠΎΡ
|
85 |
+
for epoch, loss in enumerate(loss_values, start=1):
|
86 |
+
st.write(f"<b>Epoch {epoch}/{len(loss_values)}, Loss:</b> {loss}<br>", unsafe_allow_html=True)
|
87 |
+
st.markdown(
|
88 |
+
"""
|
89 |
+
<b>ΠΠΎΠ»ΠΈΡΠ΅ΡΡΠ²ΠΎ ΡΠΏΠΎΡ
:</b> 10
|
90 |
+
<b>Π Π°Π·ΠΌΠ΅Ρ Π±Π°ΡΡΠ°:</b> 8
|
91 |
+
<b>ΠΠΏΡΠΈΠΌΠΈΠ·Π°ΡΠΎΡ:</b> Adam
|
92 |
+
<b>Π€ΡΠ½ΠΊΡΠΈΡ ΠΏΠΎΡΠ΅ΡΡ:</b> BCEWithLogitsLoss
|
93 |
+
<b>learning rate:</b> 0.00001
|
94 |
+
""",
|
95 |
+
unsafe_allow_html=True
|
96 |
+
)
|
97 |
+
|
98 |
+
st.subheader("ΠΠ΅ΡΡΠΈΠΊΠΈ ΠΌΠΎΠ΄Π΅Π»ΠΈ:")
|
99 |
+
st.write(f"Accuracy: {0.8366:.4f}")
|
100 |
+
st.write(f"Precision: {0.8034:.4f}")
|
101 |
+
st.write(f"Recall: {0.6777:.4f}")
|
102 |
+
st.write(f"F1 Score: {0.7352:.4f}")
|
103 |
+
|
104 |
+
|
105 |
+
st.subheader("ΠΠΎΠ΄")
|
106 |
+
|
107 |
+
|
108 |
+
bert_model_code = """
|
109 |
+
|
110 |
+
model = BertModel(
|
111 |
+
embeddings=BertEmbeddings(
|
112 |
+
word_embeddings=Embedding(29564, 312, padding_idx=0),
|
113 |
+
position_embeddings=Embedding(512, 312),
|
114 |
+
token_type_embeddings=Embedding(2, 312),
|
115 |
+
LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True),
|
116 |
+
dropout=Dropout(p=0.1, inplace=False),
|
117 |
+
),
|
118 |
+
encoder=BertEncoder(
|
119 |
+
layer=ModuleList(
|
120 |
+
BertLayer(
|
121 |
+
attention=BertAttention(
|
122 |
+
self=BertSelfAttention(
|
123 |
+
query=Linear(in_features=312, out_features=312, bias=True),
|
124 |
+
key=Linear(in_features=312, out_features=312, bias=True),
|
125 |
+
value=Linear(in_features=312, out_features=312, bias=True),
|
126 |
+
dropout=Dropout(p=0.1, inplace=False),
|
127 |
+
),
|
128 |
+
output=BertSelfOutput(
|
129 |
+
dense=Linear(in_features=312, out_features=312, bias=True),
|
130 |
+
LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True),
|
131 |
+
dropout=Dropout(p=0.1, inplace=False),
|
132 |
+
),
|
133 |
+
),
|
134 |
+
intermediate=BertIntermediate(
|
135 |
+
dense=Linear(in_features=312, out_features=600, bias=True),
|
136 |
+
intermediate_act_fn=GELUActivation(),
|
137 |
+
),
|
138 |
+
output=BertOutput(
|
139 |
+
dense=Linear(in_features=600, out_features=312, bias=True),
|
140 |
+
LayerNorm=LayerNorm((312,), eps=1e-12, elementwise_affine=True),
|
141 |
+
dropout=Dropout(p=0.1, inplace=False),
|
142 |
+
),
|
143 |
+
)
|
144 |
+
)
|
145 |
+
),
|
146 |
+
pooler=BertPooler(
|
147 |
+
dense=Linear(in_features=312, out_features=312, bias=True),
|
148 |
+
activation=Tanh(),
|
149 |
+
),
|
150 |
+
dropout=Dropout(p=0.1, inplace=False),
|
151 |
+
classifier=Linear(in_features=312, out_features=1, bias=True),
|
152 |
+
)
|
153 |
+
"""
|
154 |
+
|
155 |
+
# ΠΡΠΎΠ±ΡΠ°ΠΆΠ΅Π½ΠΈΠ΅ ΠΊΠΎΠ΄Π° Π² Streamlit
|
156 |
+
st.code(bert_model_code, language="python")
|
requirements.txt
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
altair==5.2.0
|
2 |
+
asttokens==2.4.1
|
3 |
+
attrs==23.2.0
|
4 |
+
blinker==1.7.0
|
5 |
+
cachetools==5.3.3
|
6 |
+
certifi==2024.2.2
|
7 |
+
charset-normalizer==3.3.2
|
8 |
+
click==8.1.7
|
9 |
+
colorama==0.4.6
|
10 |
+
comm==0.2.1
|
11 |
+
contourpy==1.2.0
|
12 |
+
cycler==0.12.1
|
13 |
+
debugpy==1.8.1
|
14 |
+
decorator==5.1.1
|
15 |
+
executing==2.0.1
|
16 |
+
filelock==3.9.0
|
17 |
+
fonttools==4.49.0
|
18 |
+
fsspec==2024.2.0
|
19 |
+
gensim==4.3.2
|
20 |
+
gitdb==4.0.11
|
21 |
+
GitPython==3.1.42
|
22 |
+
huggingface-hub==0.21.4
|
23 |
+
idna==3.6
|
24 |
+
importlib-metadata==7.0.1
|
25 |
+
ipykernel==6.29.3
|
26 |
+
ipython==8.22.1
|
27 |
+
jedi==0.19.1
|
28 |
+
Jinja2==3.1.2
|
29 |
+
joblib==1.3.2
|
30 |
+
jsonschema==4.21.1
|
31 |
+
jsonschema-specifications==2023.12.1
|
32 |
+
jupyter_client==8.6.0
|
33 |
+
jupyter_core==5.7.1
|
34 |
+
kiwisolver==1.4.5
|
35 |
+
lightning-utilities==0.10.1
|
36 |
+
markdown-it-py==3.0.0
|
37 |
+
MarkupSafe==2.1.3
|
38 |
+
matplotlib==3.8.3
|
39 |
+
matplotlib-inline==0.1.6
|
40 |
+
mdurl==0.1.2
|
41 |
+
mpmath==1.3.0
|
42 |
+
nest-asyncio==1.6.0
|
43 |
+
networkx==3.2.1
|
44 |
+
nltk==3.8.1
|
45 |
+
numpy==1.26.3
|
46 |
+
opencv-python==4.9.0.80
|
47 |
+
packaging==23.2
|
48 |
+
pandas==2.2.1
|
49 |
+
parso==0.8.3
|
50 |
+
pillow==10.2.0
|
51 |
+
platformdirs==4.2.0
|
52 |
+
prompt-toolkit==3.0.43
|
53 |
+
protobuf==4.25.3
|
54 |
+
psutil==5.9.8
|
55 |
+
pure-eval==0.2.2
|
56 |
+
py-cpuinfo==9.0.0
|
57 |
+
pyarrow==15.0.0
|
58 |
+
pydeck==0.8.1b0
|
59 |
+
Pygments==2.17.2
|
60 |
+
pymystem3==0.2.0
|
61 |
+
pyparsing==3.1.1
|
62 |
+
pyproject-toml==0.0.10
|
63 |
+
python-dateutil==2.8.2
|
64 |
+
pytz==2024.1
|
65 |
+
pywin32==306
|
66 |
+
PyYAML==6.0.1
|
67 |
+
pyzmq==25.1.2
|
68 |
+
referencing==0.33.0
|
69 |
+
regex==2023.12.25
|
70 |
+
requests==2.31.0
|
71 |
+
rich==13.7.1
|
72 |
+
rpds-py==0.18.0
|
73 |
+
safetensors==0.4.2
|
74 |
+
scikit-learn==1.4.1.post1
|
75 |
+
scipy==1.12.0
|
76 |
+
seaborn==0.13.2
|
77 |
+
setuptools==69.1.1
|
78 |
+
six==1.16.0
|
79 |
+
smart-open==7.0.1
|
80 |
+
smmap==5.0.1
|
81 |
+
stack-data==0.6.3
|
82 |
+
stqdm==0.0.5
|
83 |
+
streamlit==1.31.1
|
84 |
+
sympy==1.12
|
85 |
+
tenacity==8.2.3
|
86 |
+
thop==0.1.1.post2209072238
|
87 |
+
threadpoolctl==3.3.0
|
88 |
+
tokenizers==0.15.2
|
89 |
+
toml==0.10.2
|
90 |
+
toolz==0.12.1
|
91 |
+
torch==2.2.1+cu121
|
92 |
+
torchaudio==2.2.1+cu121
|
93 |
+
torchmetrics==1.3.1
|
94 |
+
torchvision==0.17.1+cu121
|
95 |
+
tornado==6.4
|
96 |
+
tqdm==4.66.2
|
97 |
+
traitlets==5.14.1
|
98 |
+
transformers==4.38.2
|
99 |
+
typing_extensions==4.8.0
|
100 |
+
tzdata==2024.1
|
101 |
+
tzlocal==5.2
|
102 |
+
ultralytics==8.1.20
|
103 |
+
urllib3==2.2.1
|
104 |
+
validators==0.22.0
|
105 |
+
watchdog==4.0.0
|
106 |
+
wcwidth==0.2.13
|
107 |
+
wheel==0.42.0
|
108 |
+
wrapt==1.16.0
|
109 |
+
zipp==3.17.0
|
srcs/model_modify.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b97eff933dde60a624143ec9416613b9ffc74b8c4877a00e05ebecefeb1e485
|
3 |
+
size 47160284
|