NOOTestspace / app.py
mawairon's picture
Update app.py
ace289f verified
raw
history blame
6.53 kB
import base64
import functools
import hmac
import io
import os

import gradio as gr
import huggingface_hub
import matplotlib.pyplot as plt
import pandas as pd
import tangermeme
import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download, login
from tangermeme import one_hot_encode
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel

import model_archs
from model_archs import BertClassifier, LogisticRegressionTorch, SimpleCNN, MLP, Pool2BN
# Load the label -> class-index mapping shipped with the app, and invert it so
# predicted class indices can be turned back into display labels.
# NOTE(review): pd.read_pickle unpickles arbitrary objects; this is acceptable
# only because the file is bundled with the app - never feed it user data.
label_to_int = pd.read_pickle('label_to_int.pkl')


def _shorten_label(label):
    """Map verbose country labels to the short form shown in the UI.

    The first matching substring wins, checked in the order KOREA, KINGDOM,
    RUSSIAN; labels matching none of them are returned unchanged.
    """
    for needle, short_name in (("KOREA", "KOREA"), ("KINGDOM", "UK"), ("RUSSIAN", "RUSSIA")):
        if needle in label:
            return short_name
    return label


# Class index -> (possibly shortened) label.
int_to_label = {index: _shorten_label(label) for label, index in label_to_int.items()}
# Number of output classes both classifier heads are built with; it must agree
# with the size of the label mapping used to decode predictions.
_N_UNIQUE_CLASSES = 38


def _load_gena_bert():
    """Load the GENA-LM BERT encoder plus its trained logistic-regression head.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.

    Raises:
        ValueError: if the ``HUGGINGFACE_TOKEN`` environment variable is unset.
    """
    # Check the token first so a misconfigured deployment fails before the
    # (large) base model is downloaded.
    token = os.getenv('HUGGINGFACE_TOKEN')
    if token is None:
        raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
    login(token=token)

    base_repo = 'AIRI-Institute/gena-lm-bert-base-lastln-t2t'
    base_model = AutoModel.from_pretrained(base_repo, trust_remote_code=True, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(base_repo, trust_remote_code=True)

    # 768 is presumably the encoder's hidden size; no metadata features are
    # appended to the embedding fed to the head.
    metadata_features = 0
    input_size = 768 + metadata_features

    file_path = hf_hub_download(
        repo_id="mawairon/noo_test",
        filename="gena-blastln-bs33-lr4e-05-S168.pth",
        token=token,  # `use_auth_token` is the deprecated spelling of this argument
    )
    # NOTE(review): torch.load unpickles the checkpoint; acceptable only because
    # it comes from the project's own token-gated Hub repo.
    weights = torch.load(file_path, map_location=torch.device('cpu'))
    base_model.load_state_dict(weights['model_state_dict'])

    log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=_N_UNIQUE_CLASSES)
    log_reg.load_state_dict(weights['log_reg_state_dict'])

    model = BertClassifier(base_model, log_reg, num_labels=_N_UNIQUE_CLASSES)
    model.eval()
    return model, tokenizer


def _build_cnn():
    """Build the CNN classifier (SimpleCNN trunk + dropout/MLP head) in eval mode.

    Returns:
        ``(model, None)``; the CNN needs no tokenizer.
    """
    hidden_dim = 2048
    # TODO(review): confirm 18 is the intended number of input channels.
    model_seq = SimpleCNN(18, hidden_dim, additional_layer=False)
    # The head's input width is hidden_dim * 2 - presumably SimpleCNN emits two
    # concatenated pooled feature vectors; TODO confirm against model_archs.
    new_head = torch.nn.Sequential(
        torch.nn.Dropout(0.5),
        # Previously sized from an undefined `y_train`; use the shared class count.
        MLP([hidden_dim * 2, _N_UNIQUE_CLASSES]),
    )
    model = torch.nn.Sequential(model_seq, new_head)
    # TODO(review): no trained weights are loaded for this model, so its
    # outputs come from a randomly initialised network.
    model.eval()  # disable the Dropout(0.5) layer at inference time
    return model, None


@functools.lru_cache(maxsize=None)
def load_model(model_name: str):
    """Build the classifier selected by *model_name*.

    Results are memoized per model name, so each model is constructed (and its
    checkpoint downloaded and deserialized) once per process instead of on
    every request. Only successful loads are cached; a call that raises is
    retried next time, so at most one entry per valid name is ever stored.

    NOTE(review): the cached model object is shared by concurrent requests;
    this assumes inference (eval mode, no_grad) does not mutate it.

    Args:
        model_name: ``'gena-bert'`` or ``'CNN'``.

    Returns:
        ``(model, tokenizer)``; ``tokenizer`` is ``None`` for ``'CNN'``.

    Raises:
        ValueError: for an unknown *model_name*, or when ``'gena-bert'`` is
            requested without the ``HUGGINGFACE_TOKEN`` environment variable.
    """
    if model_name == 'gena-bert':
        return _load_gena_bert()
    if model_name == 'CNN':
        return _build_cnn()
    # Raise instead of returning an error dict: callers unpack the result as a
    # 2-tuple, which turned the dict into an opaque "not enough values" error.
    raise ValueError("Invalid model name")
# Characters accepted in a cleaned sequence (the four bases plus N).
_VALID_NUCLEOTIDES = frozenset('ACGTN')
# Shortest sequence the app will classify.
_MIN_SEQUENCE_LENGTH = 300
# Fixed input length for the CNN: longer inputs are truncated, shorter padded.
_CNN_SEQUENCE_LENGTH = 8000


def _credentials_ok(username, password):
    """Check *username*/*password* against the USERNAME and PASSWORD env vars.

    USERNAME is a comma-separated allow-list. Whitespace around entries is
    ignored, and empty entries (e.g. from a trailing comma) never match. A
    missing USERNAME or PASSWORD denies every login instead of crashing.
    """
    allowed_users = {name.strip() for name in os.getenv('USERNAME', '').split(',')}
    allowed_users.discard('')
    expected_password = os.getenv('PASSWORD')
    if expected_password is None or username not in allowed_users:
        return False
    # Constant-time comparison so response timing does not leak the password;
    # compare bytes because compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest((password or '').encode('utf-8'), expected_password.encode('utf-8'))


def _clean_sequence(sequence):
    """Remove every whitespace character from *sequence* and upper-case it."""
    return "".join(sequence.split()).upper()


def _get_logits(model, tokenizer, seq, model_name):
    """Run *model* on the cleaned sequence *seq* and return its raw output."""
    if model_name == 'gena-bert':
        inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512,
                           return_tensors="pt", return_token_type_ids=False)
        with torch.no_grad():
            return model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    if model_name == 'CNN':
        # Truncate, then right-pad with 'N' up to the fixed CNN input length.
        seq = seq[:_CNN_SEQUENCE_LENGTH].ljust(_CNN_SEQUENCE_LENGTH, 'N')
        # Per tangermeme's docs, one_hot_encode returns an integer tensor of
        # shape (alphabet, length) and encodes 'N' as all zeros. Add a batch
        # dim and cast to float for the conv layers.
        # TODO(review): confirm the channel count matches SimpleCNN's input.
        encoded = one_hot_encode(seq).unsqueeze(0).float()
        with torch.no_grad():
            return model(encoded)
    raise ValueError("Invalid model name")


def _top5_chart_html(labels, probs):
    """Render a horizontal bar chart of *probs* as an inline base64 <img> tag."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(labels, probs, color='skyblue')
        ax.set_xlabel('Probability')
        ax.set_title('Assuming this sequence was genetically engineered,\n the 5 most likely countries in which it was engineered are:')
        # Use the explicit axes/figure rather than pyplot's global "current"
        # ones, which another concurrent request could have replaced.
        ax.invert_yaxis()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)  # otherwise pyplot keeps every figure alive (leak)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{image_base64}" />'


def analyze_dna(username, password, sequence, model_name):
    """Authenticate, validate a DNA sequence and return the top-5 predictions.

    Args:
        username: must appear in the comma-separated USERNAME env var.
        password: must equal the PASSWORD env var.
        sequence: DNA text; all whitespace is stripped and case is ignored.
            Only A/C/G/T/N are accepted, and at least 300 bases are required.
        model_name: ``'gena-bert'`` or ``'CNN'``.

    Returns:
        ``(result, html)``. On success, ``result`` is a list of up to five
        ``(label, probability)`` pairs in descending probability order and
        ``html`` is an inline PNG bar chart. On failure, ``result`` is
        ``{"error": message}`` and ``html`` is ``""``.
    """
    if not _credentials_ok(username, password):
        return {"error": "Invalid username or password"}, ""
    try:
        sequence = _clean_sequence(sequence)
        if not set(sequence) <= _VALID_NUCLEOTIDES:
            return {"error": "Sequence contains invalid characters"}, ""
        if len(sequence) < _MIN_SEQUENCE_LENGTH:
            return {"error": "Sequence needs to be at least 300 nucleotides long"}, ""

        model, tokenizer = load_model(model_name)
        # TODO: a disabled experiment averaged logits over rotated copies of
        # long sequences; restore it from version history if needed.
        logits = _get_logits(model, tokenizer, sequence, model_name)

        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
        top_5_indices = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)[:5]
        # NOTE(review): several class indices can map to the same shortened
        # label (e.g. "KOREA"), so labels in the top 5 may repeat.
        result = [(int_to_label[i], probabilities[i]) for i in top_5_indices]
        top_5_labels = [label for label, _ in result]
        top_5_probs = [prob for _, prob in result]
        return result, _top5_chart_html(top_5_labels, top_5_probs)
    except Exception as e:
        # UI boundary: surface any failure to the user as an error payload.
        return {"error": str(e)}, ""
# Web UI: credentials, a raw DNA sequence and a model choice go in; a JSON list
# of (label, probability) pairs and an HTML bar chart come out.
_model_choices = [
    "gena-bert",
    "CNN",
]
_input_components = [
    gr.Textbox(label="Username"),
    gr.Textbox(label="Password", type="password"),
    gr.Textbox(label="DNA Sequence"),
    gr.Dropdown(label="Model", choices=_model_choices),
]
_output_components = ["json", "HTML"]

demo = gr.Interface(fn=analyze_dna, inputs=_input_components, outputs=_output_components)

# Start the Gradio server.
demo.launch()