# test / app.py
# Uploaded by adamtayzzz (revision 1076673, "Upload 41 files", 2.39 kB)
import gradio as gr
import requests
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
# from badnet_m import BadNet
import timm
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
# model.train()
# model = BadNet(3, 10)
# pipeline = pipeline.to('cuda:0')
import os
import logging
from transformers import WEIGHTS_NAME,AdamW,AlbertConfig,AlbertTokenizer,BertConfig,BertTokenizer
from pabee.modeling_albert import AlbertForSequenceClassification
from pabee.modeling_bert import BertForSequenceClassification
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
import datasets
from whitebox_utils.classifier import MyClassifier
import ssl
# NOTE(review): this swaps the process-wide default HTTPS context factory for
# one that performs no certificate or hostname verification, so stdlib HTTPS
# clients that rely on the default context (e.g. urllib.request, http.client)
# will accept any certificate. Presumably a workaround for a local certificate
# problem — confirm it is still needed and remove it if not.
ssl._create_default_https_context = ssl._create_unverified_context
import random
import numpy as np
import torch
import argparse
def random_seed(seed):
    """Seed every RNG this app touches with the same *seed*.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator and the generators of all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
logger = logging.getLogger(__name__)

# TODO: dataset model tokenizer etc.
# Fine-tuned checkpoint to load, keyed by f"{model_type}_{dataset}".
# NOTE(review): the 'albert_STS-B' entry points at an SST-2 checkpoint and,
# unlike output_dir/data_dir below, is not under ./PABEE — confirm which
# checkpoint is really intended for STS-B.
best_model_path = {
    'albert_STS-B': './outputs/train/albert/SST-2/checkpoint-7500',
}
# model_type -> (config class, model class, tokenizer class)
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

# The architecture name lives in its own variable because `model` is rebound
# to the loaded network further down. Previously the string and the network
# shared one name, so the network was passed where MyClassifier presumably
# expects the model-type string (its last argument).
model_type = 'albert'
dataset = 'STS-B'
task_name = dataset.lower()  # lower-cased before looking it up in the GLUE tables
if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()  # transformers GLUE data processor for the task
output_mode = output_modes[task_name]  # output type for the task, from transformers' GLUE table
label_list = processor.get_labels()
num_labels = len(label_list)
output_dir = f'./PABEE/outputs/train/{model_type}/{dataset}'
data_dir = f'./PABEE/glue_data/{dataset}'

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
checkpoint_key = f'{model_type}_{dataset}'
if checkpoint_key not in best_model_path:
    # Fail with a clear message instead of a bare KeyError.
    raise ValueError("No checkpoint configured for: %s" % (checkpoint_key))
# NOTE(review): the tokenizer is read from output_dir while the weights are
# read from best_model_path — confirm both belong to the same training run.
tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=True)
model = model_class.from_pretrained(best_model_path[checkpoint_key])

# Early-exit settings forwarded to MyClassifier.
exit_type = 'patience'
exit_value = 3
# NOTE(review): the 7th argument is assumed to be the model-type string
# ('albert'/'bert'); verify against MyClassifier's signature.
classifier = MyClassifier(model, tokenizer, label_list, output_mode, exit_type, exit_value, model_type)
def greet(text, text2, exit_pos):
    """Run the module-level classifier on one sentence pair for the Gradio UI.

    Args:
        text: first sentence of the pair.
        text2: second sentence of the pair.
        exit_pos: value forwarded to ``classifier.get_prob_time`` as
            ``exit_position``. Gradio delivers text inputs as ``str`` and
            number inputs as ``float``, so it is converted with ``int()``
            (assumes an integer position is expected — TODO confirm).

    Returns:
        Whatever ``classifier.get_prob_time`` returns. The previous version
        discarded this value, so the Gradio output never received anything.
    """
    text_input = [(text, text2)]  # a batch holding a single (text, text2) pair
    return classifier.get_prob_time(text_input, exit_position=int(exit_pos))
# One input component per greet() parameter: (text, text2, exit_pos). The
# previous single 'text' input could not feed a three-argument function.
iface = gr.Interface(
    fn=greet,
    inputs=["text", "text", "number"],
    # NOTE(review): confirm that greet's return value can be rendered by an
    # image output component.
    outputs="image",
)
iface.launch()