zionia commited on
Commit
82c6a1e
1 Parent(s): f0e5035

update upload function

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -22,26 +22,30 @@ categories = {
22
  }
23
 
24
  def prediction(news):
25
- clasifer = pipeline("sentiment-analysis", tokenizer=tokenizer, model=model, return_all_scores=True)
26
  preds = clasifer(news)
27
  preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
28
  return preds_dict
29
 
30
  def file_prediction(file):
31
  if file.name.endswith('.csv'):
32
- df = pd.read_csv(file)
33
- news_list = df.iloc[:, 0].tolist()
34
  else:
35
  file.seek(0)
36
- file_content = file.read().decode('utf-8')
37
  news_list = file_content.splitlines()
38
-
39
  results = []
40
  for news in news_list:
41
  if news.strip():
42
- results.append(prediction(news))
 
43
 
44
- return pd.DataFrame(results, index=news_list)
 
 
 
45
 
46
 
47
 
 
22
  }
23
 
24
  def prediction(news):
25
+ clasifer = pipeline("text-classification", tokenizer=tokenizer, model=model, return_all_scores=True)
26
  preds = clasifer(news)
27
  preds_dict = {categories.get(pred['label'], pred['label']): pred['score'] for pred in preds[0]}
28
  return preds_dict
29
 
30
  def file_prediction(file):
31
  if file.name.endswith('.csv'):
32
+ df = pd.read_csv(file)
33
+ news_list = df.iloc[:, 0].tolist()
34
  else:
35
  file.seek(0)
36
+ file_content = file.read().decode('utf-8')
37
  news_list = file_content.splitlines()
38
+
39
  results = []
40
  for news in news_list:
41
  if news.strip():
42
+ pred = prediction(news)
43
+ results.append(pred)
44
 
45
+ return pd.DataFrame({
46
+ 'News': news_list,
47
+ 'Prediction': results
48
+ })
49
 
50
 
51