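# Spaces status: Runtime error
"""
M6_NB_MiniProject_1_PartB_Deploy_Medical_Q&A_GPT2.ipynb

Gradio app serving the ChatGPT2-based medical Q&A model
``vanim/chatgpt2-medical-QnA`` from the Hugging Face Hub.
Original notebook stored in Google Drive.
"""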
import functools
import os

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer
def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate one continuation of ``prompt`` with ``model`` and decode it.

    Args:
        model: A Hugging Face model exposing ``generate`` (a causal LM is
            assumed here).
        tokenizer: The tokenizer matching ``model``. Its ``eos_token_id`` is
            reused as the padding id.
        prompt: The question / prompt text to continue.
        max_length: Upper bound on the total sequence length in tokens. It is
            passed straight to ``model.generate``, where it counts the prompt
            tokens as well as the newly generated ones.

    Returns:
        The first (and only) generated sequence, decoded with special tokens
        removed. For a decoder-only model this text starts with the prompt
        itself. An empty (or ``None``) prompt returns ``""`` without calling
        the model.
    """
    if not prompt:
        # An empty prompt encodes to a zero-length tensor, which generate()
        # cannot start from (it raises). Answer with an empty string instead.
        return ""

    # 'pt' makes the tokenizer return a PyTorch tensor of shape (1, n_tokens).
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # A single unpadded sequence: every position is a real token, so the mask
    # is all ones. EOS is reused as the pad id, presumably because the GPT-2
    # tokenizer defines no pad token of its own.
    attention_mask = torch.ones_like(input_ids)
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Hugging Face Hub repository holding both the model weights and its tokenizer.
_MODEL_ID = "vanim/chatgpt2-medical-QnA"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load and return the ``(model, tokenizer)`` pair for ``_MODEL_ID``.

    The result is cached, so the checkpoint is loaded once per process instead
    of once per request. lru_cache does not hold a lock while loading, so two
    simultaneous first requests may both load; later calls reuse the result.
    """
    # AutoModelForCausalLM replaces the deprecated AutoModelWithLMHead. Both map
    # a GPT-2 config to GPT2LMHeadModel (the app describes the model as
    # ChatGPT2-based; confirm if the checkpoint ever changes architecture).
    model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    return model, tokenizer


def generate_query_response(prompt, max_length=200):
    """Answer ``prompt`` with the medical Q&A model; used as the Gradio handler.

    Args:
        prompt: The user's question text.
        max_length: Upper bound on the total token length (prompt included),
            forwarded to ``generate_response``.

    Returns:
        The decoded model output produced by ``generate_response``.
    """
    model, tokenizer = _load_model_and_tokenizer()
    return generate_response(model, tokenizer, prompt, max_length)
# Page heading and subtitle shown by the Gradio interface.
title = "Medical QnA chat bot"
description = "ChatGPT2 based Medical Q and A demo"

# Text in, text out: each submission passes the textbox contents to
# generate_query_response as its prompt and displays the returned string.
demo = gr.Interface(
    generate_query_response,
    inputs="text",
    outputs="text",
    description=description,
    title=title,
)

# Start the Gradio web app.
demo.launch()