import pickle

import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset, Subset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from peft import PeftModel, get_peft_config, PeftConfig, get_peft_model, LoraConfig, TaskType
from accelerate import Accelerator
from tqdm import tqdm

# Initialize the Accelerator
accelerator = Accelerator()


class ProteinDataset(TorchDataset):
    """Token-classification dataset backed by two pickle files.

    ``sequences_path`` and ``labels_path`` are each unpickled in full at
    construction time; entry ``idx`` of one is paired with entry ``idx`` of the
    other. Each sequence is tokenized on access, padded/truncated to
    ``max_length``, and its labels are truncated/padded to the same length with
    ``-100`` (the value masked out by ``compute_metrics``).
    """

    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)
        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        tokenized = self.tokenizer(
            sequence,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            # presumably disabled so token i lines up with label i -- confirm
            add_special_tokens=False,
        )
        # Remove the extra batch dimension added by return_tensors="pt".
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)
        # Truncate labels to max_length, then pad with -100 (ignore index) so
        # they match the padded input length.
        truncated = list(label[:self.max_length])
        label_padded = truncated + [-100] * (self.max_length - len(truncated))
        tokenized["labels"] = torch.tensor(label_padded, dtype=torch.long)
        return tokenized


def compute_metrics(p):
    """Compute token-level binary classification metrics for ``Trainer``.

    Positions whose label is ``-100`` are excluded. Hard predictions are the
    argmax over the last logits axis; ROC AUC is computed from the softmax
    probability of class 1 (a threshold-free score), and is NaN when the
    evaluated labels contain only one class (AUC is undefined there).
    Assumes binary labels {0, 1}, as implied by ``average='binary'``.

    Returns a dict with keys accuracy, precision, recall, f1, auc, mcc.
    """
    logits, labels = p.predictions, p.label_ids
    # Some models return extra outputs; the logits are then the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    mask = labels != -100

    true_labels = labels[mask].flatten()
    masked_logits = logits[mask].astype(np.float64)
    predicted_labels = np.argmax(masked_logits, axis=-1)

    # Numerically stable softmax; column 1 is the positive-class probability.
    shifted = masked_logits - masked_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    positive_scores = probs[:, 1]

    accuracy = accuracy_score(true_labels, predicted_labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average='binary', zero_division=0
    )
    if np.unique(true_labels).size > 1:
        auc = roc_auc_score(true_labels, positive_scores)
    else:
        auc = float("nan")
    mcc = matthews_corrcoef(true_labels, predicted_labels)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}


def evaluate_in_chunks(dataset, trainer, chunk_percentage=0.2, file_prefix="results_chunk"):
    """Run ``trainer.evaluate`` over consecutive slices of ``dataset``.

    Each slice holds ``chunk_percentage`` of the dataset (at least one item).
    Each chunk's metrics dict gets an extra ``eval_samples`` entry (the number
    of items in the chunk, used by ``aggregate_results``), is printed, pickled
    to ``{file_prefix}_{start_index}.pkl``, and collected.

    Raises ValueError if ``chunk_percentage`` is not in (0, 1].
    Returns the list of per-chunk metric dicts.
    """
    if not 0 < chunk_percentage <= 1:
        raise ValueError(f"chunk_percentage must be in (0, 1], got {chunk_percentage}")
    total = len(dataset)
    # max(1, ...) avoids a zero range step for small datasets.
    chunk_size = max(1, int(total * chunk_percentage))
    all_results = []
    # Wrap the loop with tqdm for progress bar
    for start in tqdm(range(0, total, chunk_size), desc="Evaluating chunks"):
        stop = min(start + chunk_size, total)
        # Subset is lazy: items are tokenized only as the DataLoader requests them.
        chunk = Subset(dataset, range(start, stop))
        chunk_results = trainer.evaluate(chunk)
        # Trainer.evaluate does not report the sample count; record it for weighting.
        chunk_results["eval_samples"] = stop - start
        print(f"Results for chunk starting at index {start}: {chunk_results}")
        # Save the chunk results to disk
        with open(f"{file_prefix}_{start}.pkl", "wb") as f:
            pickle.dump(chunk_results, f)
        all_results.append(chunk_results)
    return all_results


def aggregate_results(results_list):
    """Combine per-chunk metric dicts into a sample-weighted average.

    Each dict must contain ``eval_samples``. For every other numeric key of the
    first dict, returns the average of that key across chunks weighted by
    ``eval_samples``; chunks where the value is missing or NaN are skipped, and
    a key with no usable values maps to NaN. Non-numeric values are ignored.
    Note this is an approximation for non-additive metrics (precision, MCC,
    AUC, ...) -- exact values require computing them over all tokens at once.

    Returns an empty dict for an empty ``results_list``.
    """
    if not results_list:
        return {}
    aggregated_results = {}
    for key, first_value in results_list[0].items():
        if key == "eval_samples" or not isinstance(first_value, (int, float, np.number)):
            continue
        weighted_sum = 0.0
        weight_total = 0
        for res in results_list:
            value = res.get(key)
            if value is None or np.isnan(value):
                continue
            weighted_sum += value * res["eval_samples"]
            weight_total += res["eval_samples"]
        aggregated_results[key] = weighted_sum / weight_total if weight_total else float("nan")
    return aggregated_results


# Initialize tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, 512)
test_dataset = ProteinDataset("data/12M_data/512_test_sequences_chunked_by_family.pkl", "data/12M_data/512_test_labels_chunked_by_family.pkl", tokenizer, 512)

# Load the pre-trained LoRA model
base_model_path = "facebook/esm2_t33_650M_UR50D"
lora_model_path = "qlora_binding_sites/best_model_esm2_t33_650M_qlora_2023-10-18_02-14-48"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, lora_model_path)
# NOTE(review): Trainer manages device placement itself; preparing the model
# here as well may be redundant (and could double-wrap under multi-GPU) -- confirm.
model = accelerator.prepare(model)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    compute_metrics=compute_metrics,
)

# Evaluate the model on chunks of the training dataset
# (distinct file prefix so the test pass does not overwrite these chunk files)
train_results = evaluate_in_chunks(train_dataset, trainer, file_prefix="train_results_chunk")
aggregated_train_results = aggregate_results(train_results)
print(f"Aggregated Training Results: {aggregated_train_results}")

# Evaluate the model on chunks of the test dataset
test_results = evaluate_in_chunks(test_dataset, trainer, file_prefix="test_results_chunk")
aggregated_test_results = aggregate_results(test_results)
print(f"Aggregated Test Results: {aggregated_test_results}")