nishan-chatterjee commited on
Commit
352cae4
1 Parent(s): 607f4db

gradio app version for huggingface space

Browse files
Files changed (2) hide show
  1. app.py +95 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import networkx as nx
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
6
+ from tqdm import tqdm
7
+
8
+ def _make_logits_consistent(x, R):
9
+ c_out = x.unsqueeze(1) + 10
10
+ c_out = c_out.expand(len(x), R.shape[1], R.shape[1])
11
+ R_batch = R.expand(len(x), R.shape[1], R.shape[1]).to(x.device)
12
+ final_out, _ = torch.max(R_batch * c_out, dim=2)
13
+ return final_out - 10
14
+
15
+ def persuasion_labels(text):
16
+ model_dir = "models"
17
+ # Initialize the graph and other necessary components
18
+ G = nx.DiGraph()
19
+ # Add edges to the graph
20
+ edges = [
21
+ ("ROOT", "Logos"),
22
+ ("Logos", "Repetition"), ("Logos", "Obfuscation, Intentional vagueness, Confusion"), ("Logos", "Reasoning"), ("Logos", "Justification"),
23
+ ("Justification", "Slogans"), ("Justification", "Bandwagon"), ("Justification", "Appeal to authority"), ("Justification", "Flag-waving"), ("Justification", "Appeal to fear/prejudice"),
24
+ ("Reasoning", "Simplification"),
25
+ ("Simplification", "Causal Oversimplification"), ("Simplification", "Black-and-white Fallacy/Dictatorship"), ("Simplification", "Thought-terminating cliché"),
26
+ ("Reasoning", "Distraction"),
27
+ ("Distraction", "Misrepresentation of Someone's Position (Straw Man)"), ("Distraction", "Presenting Irrelevant Data (Red Herring)"), ("Distraction", "Whataboutism"),
28
+ ("ROOT", "Ethos"),
29
+ ("Ethos", "Appeal to authority"), ("Ethos", "Glittering generalities (Virtue)"), ("Ethos", "Bandwagon"), ("Ethos", "Ad Hominem"), ("Ethos", "Transfer"),
30
+ ("Ad Hominem", "Doubt"), ("Ad Hominem", "Name calling/Labeling"), ("Ad Hominem", "Smears"), ("Ad Hominem", "Reductio ad hitlerum"), ("Ad Hominem", "Whataboutism"),
31
+ ("ROOT", "Pathos"),
32
+ ("Pathos", "Exaggeration/Minimisation"), ("Pathos", "Loaded Language"), ("Pathos", "Appeal to (Strong) Emotions"), ("Pathos", "Appeal to fear/prejudice"), ("Pathos", "Flag-waving"), ("Pathos", "Transfer")
33
+ ]
34
+ G.add_edges_from(edges)
35
+
36
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
37
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ model.to(device)
41
+
42
+ A = nx.to_numpy_array(G).transpose()
43
+ R = np.zeros(A.shape)
44
+ np.fill_diagonal(R, 1)
45
+ g = nx.DiGraph(A)
46
+ for i in range(len(A)):
47
+ descendants = list(nx.descendants(g, i))
48
+ if descendants:
49
+ R[i, descendants] = 1
50
+ R = torch.tensor(R).transpose(1, 0).unsqueeze(0)
51
+
52
+ encoding = tokenizer.encode_plus(
53
+ text,
54
+ add_special_tokens=True,
55
+ max_length=128,
56
+ return_token_type_ids=False,
57
+ padding="max_length",
58
+ truncation=True,
59
+ return_attention_mask=True,
60
+ return_tensors="pt",
61
+ )
62
+
63
+ with torch.no_grad():
64
+ outputs = model(
65
+ input_ids=encoding["input_ids"].to(device),
66
+ attention_mask=encoding["attention_mask"].to(device),
67
+ )
68
+ logits = _make_logits_consistent(outputs.logits, R)
69
+ logits[:, 0] = -1.0
70
+ logits = logits > 0.0
71
+ complete_predicted_hierarchy = np.array(G.nodes)[logits[0].cpu().nonzero()].flatten().tolist()
72
+
73
+ # if any label doesn't have children, add them to the list
74
+ child_only_labels = []
75
+ for label in complete_predicted_hierarchy:
76
+ if not list(G.successors(label)):
77
+ child_only_labels.append(label)
78
+
79
+ return complete_predicted_hierarchy, child_only_labels
80
+
81
+ def launch_interface():
82
+ iface = gr.Interface(
83
+ fn=persuasion_labels,
84
+ inputs=gr.Textbox(lines=5, placeholder="Enter your text here..."),
85
+ outputs=[
86
+ gr.Textbox(label="Complete Hierarchical Label List"),
87
+ gr.Textbox(label="Child-only Label List")
88
+ ],
89
+ title="Persuasion Labels",
90
+ description="Enter your text and get the persuasion labels.",
91
+ )
92
+ iface.launch(share=True)
93
+
94
+ if __name__ == "__main__":
95
+ launch_interface()
requirements.txt CHANGED
@@ -5,3 +5,4 @@ transformers
5
  tqdm
6
  sentencepiece
7
  protobuf
 
 
5
  tqdm
6
  sentencepiece
7
  protobuf
8
+ gradio