Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -85,7 +85,7 @@ def load_model(model_name: str):
|
|
85 |
return model, None
|
86 |
|
87 |
else:
|
88 |
-
|
89 |
|
90 |
|
91 |
|
@@ -112,24 +112,29 @@ def analyze_dna(username, password, sequence, model_name):
|
|
112 |
model, tokenizer = load_model(model_name)
|
113 |
|
114 |
def get_logits(seq, model_name):
|
|
|
115 |
if model_name == 'gena-bert':
|
|
|
116 |
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
117 |
with torch.no_grad():
|
118 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
119 |
return logits
|
120 |
|
121 |
elif model_name == 'CNN':
|
122 |
-
|
123 |
SEQUENCE_LENGTH = 8000
|
|
|
|
|
|
|
124 |
seq = seq[:SEQUENCE_LENGTH]
|
125 |
|
126 |
# Pad sequences to the desired length
|
127 |
-
seq = seq.ljust(
|
128 |
|
129 |
-
# Apply one-hot encoding to the
|
130 |
-
|
131 |
with torch.no_grad():
|
132 |
-
logits = model(
|
133 |
return logits
|
134 |
|
135 |
|
|
|
85 |
return model, None
|
86 |
|
87 |
else:
|
88 |
+
raise ValueError("Invalid model name")
|
89 |
|
90 |
|
91 |
|
|
|
112 |
model, tokenizer = load_model(model_name)
|
113 |
|
114 |
def get_logits(seq, model_name):
|
115 |
+
|
116 |
if model_name == 'gena-bert':
|
117 |
+
|
118 |
inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=512, return_tensors="pt", return_token_type_ids=False)
|
119 |
with torch.no_grad():
|
120 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
|
121 |
return logits
|
122 |
|
123 |
elif model_name == 'CNN':
|
124 |
+
|
125 |
SEQUENCE_LENGTH = 8000
|
126 |
+
pad_char = 'N'
|
127 |
+
|
128 |
+
# Truncate sequence
|
129 |
seq = seq[:SEQUENCE_LENGTH]
|
130 |
|
131 |
# Pad sequences to the desired length
|
132 |
+
seq = seq.ljust(SEQUENCE_LENGTH, pad_char)[:SEQUENCE_LENGTH]
|
133 |
|
134 |
+
# Apply one-hot encoding to the sequence
|
135 |
+
input_tensor = one_hot_encode(seq).unsqueeze(0)
|
136 |
with torch.no_grad():
|
137 |
+
logits = model(input_tensor)
|
138 |
return logits
|
139 |
|
140 |
|