nishan-chatterjee
committed on
Commit 352cae4
1 Parent(s): 607f4db
gradio app version for huggingface space
- app.py +95 -0
- requirements.txt +1 -0
app.py
ADDED
@@ -0,0 +1,95 @@
import gradio as gr
import torch
import numpy as np
import networkx as nx
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

def _make_logits_consistent(x, R):
    c_out = x.unsqueeze(1) + 10
    c_out = c_out.expand(len(x), R.shape[1], R.shape[1])
    R_batch = R.expand(len(x), R.shape[1], R.shape[1]).to(x.device)
    final_out, _ = torch.max(R_batch * c_out, dim=2)
    return final_out - 10

def persuasion_labels(text):
    model_dir = "models"
    # Initialize the graph and other necessary components
    G = nx.DiGraph()
    # Add edges to the graph
    edges = [
        ("ROOT", "Logos"),
        ("Logos", "Repetition"), ("Logos", "Obfuscation, Intentional vagueness, Confusion"), ("Logos", "Reasoning"), ("Logos", "Justification"),
        ("Justification", "Slogans"), ("Justification", "Bandwagon"), ("Justification", "Appeal to authority"), ("Justification", "Flag-waving"), ("Justification", "Appeal to fear/prejudice"),
        ("Reasoning", "Simplification"),
        ("Simplification", "Causal Oversimplification"), ("Simplification", "Black-and-white Fallacy/Dictatorship"), ("Simplification", "Thought-terminating cliché"),
        ("Reasoning", "Distraction"),
        ("Distraction", "Misrepresentation of Someone's Position (Straw Man)"), ("Distraction", "Presenting Irrelevant Data (Red Herring)"), ("Distraction", "Whataboutism"),
        ("ROOT", "Ethos"),
        ("Ethos", "Appeal to authority"), ("Ethos", "Glittering generalities (Virtue)"), ("Ethos", "Bandwagon"), ("Ethos", "Ad Hominem"), ("Ethos", "Transfer"),
        ("Ad Hominem", "Doubt"), ("Ad Hominem", "Name calling/Labeling"), ("Ad Hominem", "Smears"), ("Ad Hominem", "Reductio ad hitlerum"), ("Ad Hominem", "Whataboutism"),
        ("ROOT", "Pathos"),
        ("Pathos", "Exaggeration/Minimisation"), ("Pathos", "Loaded Language"), ("Pathos", "Appeal to (Strong) Emotions"), ("Pathos", "Appeal to fear/prejudice"), ("Pathos", "Flag-waving"), ("Pathos", "Transfer")
    ]
    G.add_edges_from(edges)

    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    A = nx.to_numpy_array(G).transpose()
    R = np.zeros(A.shape)
    np.fill_diagonal(R, 1)
    g = nx.DiGraph(A)
    for i in range(len(A)):
        descendants = list(nx.descendants(g, i))
        if descendants:
            R[i, descendants] = 1
    R = torch.tensor(R).transpose(1, 0).unsqueeze(0)

    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=False,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            input_ids=encoding["input_ids"].to(device),
            attention_mask=encoding["attention_mask"].to(device),
        )
    logits = _make_logits_consistent(outputs.logits, R)
    logits[:, 0] = -1.0
    logits = logits > 0.0
    complete_predicted_hierarchy = np.array(G.nodes)[logits[0].cpu().nonzero()].flatten().tolist()

    # if any label doesn't have children, add them to the list
    child_only_labels = []
    for label in complete_predicted_hierarchy:
        if not list(G.successors(label)):
            child_only_labels.append(label)

    return complete_predicted_hierarchy, child_only_labels

def launch_interface():
    iface = gr.Interface(
        fn=persuasion_labels,
        inputs=gr.Textbox(lines=5, placeholder="Enter your text here..."),
        outputs=[
            gr.Textbox(label="Complete Hierarchical Label List"),
            gr.Textbox(label="Child-only Label List")
        ],
        title="Persuasion Labels",
        description="Enter your text and get the persuasion labels.",
    )
    iface.launch(share=True)

if __name__ == "__main__":
    launch_interface()
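
Note (not part of the commit): the hierarchy-consistency step in _make_logits_consistent is easiest to see on a toy example. The sketch below is hypothetical, with a made-up 3-node hierarchy and made-up logit values, but it builds the ancestor matrix R exactly as persuasion_labels() does; it assumes app.py is importable from the working directory. It shows that each label's consistent score becomes the maximum over itself and its descendants, so a confidently predicted leaf also switches on its parents.

# Hypothetical toy example; ROOT -> A -> B and the logit values are invented.
from app import _make_logits_consistent  # assumes app.py is on the import path
import networkx as nx
import numpy as np
import torch

# Tiny hierarchy ROOT -> A -> B, with R built the same way as in persuasion_labels().
G = nx.DiGraph([("ROOT", "A"), ("A", "B")])
Adj = nx.to_numpy_array(G).transpose()
R = np.zeros(Adj.shape)
np.fill_diagonal(R, 1)
g = nx.DiGraph(Adj)
for i in range(len(Adj)):
    descendants = list(nx.descendants(g, i))
    if descendants:
        R[i, descendants] = 1
R = torch.tensor(R).transpose(1, 0).unsqueeze(0)

# Raw scores: only the leaf B is above the 0.0 decision threshold.
raw = torch.tensor([[-2.0, -1.0, 3.0]])  # [ROOT, A, B]
consistent = _make_logits_consistent(raw, R)
print(consistent)        # [[3., 3., 3.]] - B's score propagates up to A and ROOT
print(consistent > 0.0)  # all True; app.py then masks ROOT out with logits[:, 0] = -1.0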
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ transformers
 tqdm
 sentencepiece
 protobuf
+gradio
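
Note (not part of the commit): with gradio added to the requirements, a rough local test of the Space would be pip install -r requirements.txt followed by python app.py, which starts the Gradio interface via launch_interface(). persuasion_labels() additionally expects a fine-tuned sequence-classification checkpoint under models/, which is not included in this commit.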