# Hugging Face Spaces status when this file was captured: Runtime error
if __name__ == '__main__':
    # Gradio demo: given a context paragraph and a question/feature text,
    # tag the characters of the context the model predicts as the answer and
    # return the extracted text.
    # All imports and model loading stay under the __main__ guard so that
    # importing this module has no side effects.
    import os

    import gradio as gr
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from configuration import CFG
    from dataset import SingleInputDataset
    from model import CustomModel
    from utils import get_char_probs, get_results, get_text, inference_fn

    # Inference is pinned to CPU.
    device = torch.device('cpu')
    config_path = os.path.join('models_file', 'config.pth')
    model_path = os.path.join(
        'models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth'
    )

    tokenizer = AutoTokenizer.from_pretrained(os.path.join('models_file', 'tokenizer'))
    model = CustomModel(CFG, config_path=config_path, pretrained=False)
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True. If this
    # checkpoint stores non-tensor objects alongside 'model', loading will fail
    # and weights_only=False must be passed (acceptable only because this is a
    # trusted local file) — TODO confirm against the pinned torch version.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])
    # Switch to evaluation mode (disables dropout). Presumably harmless if
    # inference_fn also does this — TODO confirm.
    model.eval()

    def get_answer(context, feature):
        """Return the part of *context* the model tags as answering *feature*.

        Args:
            context: Paragraph to search; encoded as the first tokenizer segment
                and used to map token predictions back to characters.
            feature: Question / feature text; encoded as the second segment.

        Returns:
            The value ``get_text`` builds from ``context`` and the first entry
            of ``get_results`` at a 0.5 threshold.
        """
        # Encode the pair to exactly CFG.max_len tokens so the model output can
        # be reshaped to (1, CFG.max_len) below. Without truncation, long user
        # input would produce more than max_len tokens and break that reshape.
        encoded = tokenizer(
            context,
            feature,
            add_special_tokens=True,
            max_length=CFG.max_len,
            padding="max_length",
            truncation=True,
            return_offsets_mapping=False,
        )
        # Convert each encoding field (lists of ints) to a long tensor in place,
        # keeping the tokenizer's own mapping type for SingleInputDataset.
        for key, value in encoded.items():
            encoded[key] = torch.tensor(value, dtype=torch.long)

        # Single-sample loader. num_workers=0 keeps loading in-process, so no
        # worker processes are spawned on every request (they are pure overhead
        # for one sample and can fail on limited shared memory).
        loader = DataLoader(
            SingleInputDataset(encoded),
            batch_size=1,
            shuffle=False,
            num_workers=0,
        )

        raw_output = inference_fn(loader, model, device)
        prediction = raw_output.reshape((1, CFG.max_len))
        # Map token-level predictions back onto characters of the context.
        char_probs = get_char_probs([context], prediction, tokenizer)
        # Mean over a one-element list leaves char_probs unchanged. This is
        # presumably a leftover from multi-model ensembling.
        predictions = np.mean([char_probs], axis=0)
        results = get_results(predictions, th=0.5)
        print(results)
        return get_text(context, results[0])

    # gr.inputs / gr.outputs and Interface(enable_queue=...) were deprecated in
    # Gradio 3 and removed in Gradio 4. gr.Textbox and .queue() work in both.
    ui_inputs = [
        gr.Textbox(label="Context Para", lines=10),
        gr.Textbox(label="Question", lines=1),
    ]
    ui_output = gr.Textbox(label="Answer")
    article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"
    app = gr.Interface(
        fn=get_answer,
        inputs=ui_inputs,
        outputs=ui_output,
        allow_flagging='never',
        title="Phrase Extraction",
        article=article,
        cache_examples=False,
        css="footer {visibility: hidden}",
    )
    app.queue()
    app.launch()