# vanim's picture
# import relevant packages
# 910ef46
"""Gradio demo for a GPT-2-based medical question-and-answer model.

Source notebook: M6_NB_MiniProject_1_PartB_Deploy_Medical_Q&A_GPT2.ipynb
(the original file is stored in Google Drive).
"""
import functools
import os

import gradio as gr
import torch
import transformers
from transformers import AutoModelWithLMHead, AutoTokenizer
def generate_response(model, tokenizer, prompt, max_length=512):
    """Generate one completion of ``prompt`` with ``model``.

    Args:
        model: A Hugging Face language model exposing ``generate``.
        tokenizer: The tokenizer matching ``model``. Its EOS token id is used
            as the padding token id.
        prompt: Input text to continue.
        max_length: Maximum total length in tokens (prompt plus generated
            tokens), forwarded to ``generate``.

    Returns:
        The first generated sequence decoded with special tokens removed, or
        ``""`` when the prompt encodes to no tokens.
    """
    # 'pt' makes the tokenizer return a PyTorch tensor.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.numel() == 0:
        # generate() cannot continue an empty sequence; nothing to answer.
        return ""

    # Keep inputs on the same device as the model weights so the function
    # still works after the model has been moved (e.g. to a GPU).
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)

    # A single unpadded sequence: attend to every token. The EOS id doubles
    # as the pad id (GPT-2 tokenizers define no dedicated pad token).
    attention_mask = torch.ones_like(input_ids)
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Hugging Face Hub repository holding both the model weights and tokenizer.
MODEL_ID = "vanim/chatgpt2-medical-QnA"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the model and tokenizer from the Hub once and reuse them."""
    # AutoModelForCausalLM replaces the deprecated AutoModelWithLMHead for
    # decoder-only (GPT-2 style) models.
    model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    return model, tokenizer


def generate_query_response(prompt, max_length=200):
    """Answer a user question with the medical Q&A model.

    This is the Gradio callback. The model and tokenizer are loaded lazily on
    the first non-blank request and cached, instead of being re-downloaded
    and re-instantiated for every question.

    Args:
        prompt: Question text entered by the user.
        max_length: Maximum total length in tokens (prompt plus answer).

    Returns:
        The decoded model output, or ``""`` for an empty or blank prompt.
    """
    if not prompt or not prompt.strip():
        # Nothing to answer; avoid loading the model and calling generate()
        # on an empty input.
        return ""
    model, tokenizer = _load_model_and_tokenizer()
    return generate_response(model, tokenizer, prompt, max_length)
# Title and description shown at the top of the Gradio page.
title = "Medical QnA chat bot"
description = "ChatGPT2 based Medical Q and A demo"

# One text box in, one text box out, answered by generate_query_response.
demo = gr.Interface(
    fn=generate_query_response,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
)

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio's reload mode) does not start a web server.
if __name__ == "__main__":
    demo.launch()