Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
# ๋ชจ๋ธ ์ค๋นํ๊ธฐ | |
from transformers import RobertaForSequenceClassification, AutoTokenizer | |
import numpy as np | |
import pandas as pd | |
import torch | |
import os | |
# Optional theme settings. These lines are presumably a snippet meant for
# .streamlit/config.toml rather than code executed by this script:
# [theme]
# base="dark"
# primaryColor="purple"

# Page title shown at the top of the app.
st.header(body='ํ๊ตญํ์ค์ฐ์ ๋ถ๋ฅ ์๋์ฝ๋ฉ ์๋น์ค')
# Load everything the classifier needs exactly once. Streamlit re-executes
# this whole script on every widget interaction; st.cache_resource keeps the
# returned objects across reruns so the model is not rebuilt and the
# checkpoint is not re-read each time.
@st.cache_resource
def md_loading():
    """Load and return ``(tokenizer, model, label_tbl, loc_tbl)``.

    - tokenizer: the ``klue/roberta-base`` tokenizer.
    - model: ``klue/roberta-base`` sequence classifier with 495 output
      labels, whose weights are replaced by the ``./upsampling_20.bin``
      state dict, loaded onto the CPU and put in eval mode.
    - label_tbl: array loaded from ``./label_table.npy``; presumably maps a
      model class index to a label code (that is how the prediction step
      indexes it) -- confirm against the training code.
    - loc_tbl: DataFrame read from ``./kisc_table.csv`` (UTF-8), used to look
      up an item name by code.
    """
    # Everything runs on the CPU.
    tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
    model = RobertaForSequenceClassification.from_pretrained(
        'klue/roberta-base', num_labels=495
    )

    model_checkpoint = 'upsampling_20.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    # NOTE: torch.load unpickles the file; only ever point this at a trusted,
    # locally shipped checkpoint.
    state_dict = torch.load(output_model_file, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference only: make sure dropout etc. are disabled.
    model.eval()

    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
    print('ready')
    return tokenizer, model, label_tbl, loc_tbl
# Load the tokenizer, classifier and lookup tables.
tokenizer, model, label_tbl, loc_tbl = md_loading()

# One text box per field, created in display order. Commas are stripped from
# every value because md_input uses ', ' as the field separator and the
# prediction step splits it back apart on ','.
_field_labels = (
    '์ฌ์ ์ฒด๋ช ',
    '์ฌ์ ์ฒด ํ๋์ผ',
    '๊ทผ๋ฌด๋ถ์',
    '์ง์ฑ ',
    '๋ด๊ฐ ํ๋ ์ผ',
)
business, business_work, work_department, work_position, what_do_i = [
    st.text_input(label).replace(',', '') for label in _field_labels
]

# md_input: the five fields joined into the single string fed to the model.
md_input = ', '.join(
    (business, business_work, work_department, work_position, what_do_i)
)

# Uncomment to echo the assembled model input while debugging:
# st.write(md_input)
# Confirm button: when clicked, classify the entered fields and show the
# top-k candidate codes with their item names.
if st.button('ํ์ธ'):
    max_len = 64  # fixed sequence length fed to the model
    top_k = 10    # number of candidate codes to display

    # Build the model text: "[CLS] " followed by each of (at most) five
    # comma-separated fields, each followed by a space. Commas were stripped
    # from the individual fields, so splitting on ',' recovers the fields.
    query_tokens = md_input.split(',')
    seq = '[CLS] ' + ''.join(field + ' ' for field in query_tokens[:5])

    # Tokenize, truncate to max_len and zero-pad; the attention mask is 1 on
    # real token positions and 0 on padding.
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(seq))[:max_len]
    n_real = len(ids)
    input_ids = torch.zeros((1, max_len), dtype=torch.long)
    attention_mask = torch.zeros((1, max_len), dtype=torch.long)
    input_ids[0, :n_real] = torch.tensor(ids, dtype=torch.long)
    attention_mask[0, :n_real] = 1

    # Inference only: disable autograd so no computation graph is built and
    # kept alive on every click.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Top-k class indices -> label codes (via label_tbl) -> item names
    # (via loc_tbl). Convert indices to NumPy explicitly before indexing the
    # NumPy label table.
    topk_idx = torch.topk(logits.flatten(), top_k).indices.cpu().numpy()
    num_ans_topk = label_tbl[topk_idx]

    code_col = loc_tbl['์ฝ๋']
    name_col = loc_tbl['ํญ๋ชฉ๋ช ']
    str_ans_topk_list = []
    for code in num_ans_topk:
        matches = name_col[code_col == code]
        # If a predicted code has no row in loc_tbl, show an empty name
        # instead of crashing the page with IndexError.
        str_ans_topk_list.append(matches.iloc[0] if len(matches) else '')

    # Render the ranked results, numbered from 1.
    ans_topk_df = pd.DataFrame({
        'NO': range(1, top_k + 1),
        '์ธ๋ถ๋ฅ ์ฝ๋': num_ans_topk,
        '์ธ๋ถ๋ฅ ๋ช ์นญ': str_ans_topk_list,
    }).set_index('NO')
    st.dataframe(ans_topk_df)