zionia commited on
Commit
eaeabb1
1 Parent(s): 7b88e4e

update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
- import csv
4
 
5
  MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
6
  WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"
@@ -22,18 +22,20 @@ categories = {
22
  }
23
 
24
  def prediction(news):
25
- clasifer = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)
26
- preds = clasifer(news)
27
  preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
28
  return preds_dict
29
 
30
  def file_prediction(file):
 
 
31
  if file.name.endswith('.csv'):
32
- file.seek(0)
33
  reader = csv.reader(file.read().decode('utf-8').splitlines())
34
- news_list = [row[0] for row in reader if row]
35
  else:
36
- file.seek(0)
37
  file_content = file.read().decode('utf-8')
38
  news_list = file_content.splitlines()
39
 
@@ -41,9 +43,9 @@ def file_prediction(file):
41
  for news in news_list:
42
  if news.strip():
43
  pred = prediction(news)
44
- results.append([news, pred])
45
 
46
- return results
47
 
48
  gradio_ui = gr.Interface(
49
  fn=prediction,
@@ -67,4 +69,3 @@ gradio_file_ui = gr.Interface(
67
  gradio_combined_ui = gr.TabbedInterface([gradio_ui, gradio_file_ui], ["Text Input", "File Upload"])
68
 
69
  gradio_combined_ui.launch()
70
-
 
1
  import gradio as gr
2
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
+ import csv
4
 
5
  MODEL_URL = "https://huggingface.co/dsfsi/PuoBERTa-News"
6
  WEBSITE_URL = "https://www.kodiks.com/ai_solutions.html"
 
22
  }
23
 
24
  def prediction(news):
25
+ classifier = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)
26
+ preds = classifier(news)
27
  preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
28
  return preds_dict
29
 
30
  def file_prediction(file):
31
+ news_list = []
32
+
33
  if file.name.endswith('.csv'):
34
+ file.seek(0)
35
  reader = csv.reader(file.read().decode('utf-8').splitlines())
36
+ news_list = [row[0] for row in reader if row]
37
  else:
38
+ file.seek(0)
39
  file_content = file.read().decode('utf-8')
40
  news_list = file_content.splitlines()
41
 
 
43
  for news in news_list:
44
  if news.strip():
45
  pred = prediction(news)
46
+ results.append([news, pred]) # Return each news and its prediction as a row
47
 
48
+ return results # Gradio expects a list of lists or dicts for DataFrame
49
 
50
  gradio_ui = gr.Interface(
51
  fn=prediction,
 
69
  gradio_combined_ui = gr.TabbedInterface([gradio_ui, gradio_file_ui], ["Text Input", "File Upload"])
70
 
71
  gradio_combined_ui.launch()