import itertools

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from WAR.Experiment_functions import display_phi
from WAR.dataset_handler import myData


class WAR:
    """Active-learning round driver that trains a predictor ``h`` and a critic ``phi``.

    Each call to :meth:`train` runs one round:

    * T1 - fit ``h`` on the currently labelled points;
    * T2 - fit ``phi`` on the full pool versus the (temporarily extended)
      labelled set;
    * T3 - query one unlabelled point from the ``phi`` scores, optionally
      combined with a second criterion (see :meth:`query`).

    T2/T3 are repeated ``num_elem_queried`` times per round.
    """

    def __init__(self, X_train, y_train, X_test, y_test, idx_lb, total_epoch_h,
                 total_epoch_phi, batch_size_train, num_elem_queried,
                 phi, h, opti_phi, opti_h, second_query_strategy=None,
                 cost="MSE"):
        """
        X_train: trainset.
        y_train: labels of the trainset.
        X_test: testset (only moved to the device and stored here).
        y_test: labels of the testset (only moved to the device and stored here).
        idx_lb: mask over the trainset marking the points considered as labelled.
            It is used with ``~`` and ``.copy()``, so it is expected to be a
            NumPy boolean array of length ``len(y_train)``.
        total_epoch_h: number of epochs to train h.
        total_epoch_phi: number of epochs to train phi.
        batch_size_train: size of the batch in the training process.
        num_elem_queried: number of elements queried each round.
        phi: phi neural network.
        h: h neural network.
        opti_phi: phi optimizer.
        opti_h: h optimizer.
        second_query_strategy: second strategy to assist the
            distribution-matching criterion. ``None`` or "loss_approximation".
        cost: cost function used by ``cost_func``: "MSE" (default) or "MAE".

        Derived attributes:
        device: "cuda" when available, otherwise "cpu"; the data and both
            networks are moved to it.
        n_pool: length of the trainset.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.X_train = X_train.to(self.device)
        self.y_train = y_train.to(self.device)
        self.X_test = X_test.to(self.device)
        self.y_test = y_test.to(self.device)
        self.idx_lb = idx_lb
        self.n_pool = len(y_train)
        self.total_epoch_h = total_epoch_h
        self.total_epoch_phi = total_epoch_phi
        self.batch_size_train = batch_size_train
        self.num_elem_queried = num_elem_queried
        self.phi = phi.to(self.device)
        self.h = h.to(self.device)
        self.opti_phi = opti_phi
        self.opti_h = opti_h
        self.cost = cost
        self.second_query_strategy = second_query_strategy

    def cost_func(self, predicted, true):
        """Element-wise cost between ``predicted`` and ``true``.

        Returns the squared error when ``self.cost`` is "MSE" and the absolute
        error when it is "MAE". No reduction is applied; callers take the mean
        themselves.

        Raises:
            ValueError: if ``self.cost`` is neither "MSE" nor "MAE".
        """
        if self.cost == "MSE":
            return (predicted - true) ** 2
        if self.cost == "MAE":
            return abs(predicted - true)
        raise ValueError("invalid cost function")

    def _labelled_loader(self, idx_lb_train):
        """Build a shuffled mini-batch loader over the given labelled indices."""
        trainset_labelled = myData(self.X_train[idx_lb_train], self.y_train[idx_lb_train])
        return DataLoader(trainset_labelled, shuffle=True, batch_size=self.batch_size_train)

    def train(self, only_train=False, reduced=True, eta=3):
        """Run one round: train h, then (unless ``only_train``) train phi and query.

        only_train: activate when there is no more unlabelled data in the
            trainset. Will only train h and not train phi or query data.
        reduced: will divide each query criterion by its standard deviation.
            In the case where they don't have the same amplitude, this gives
            them the same weight in the querying process. Irrelevant if there
            is only one query criterion (``second_query_strategy`` is None).
        eta: factor used to rebalance the criteria. If >1, the
            distribution-matching criterion gets more weight than the other(s).
            Irrelevant if there is only one query criterion.

        Returns ``(t1_descend, t2_ascend, b_idxs)``: the per-step losses of h,
        the per-step objectives of phi, and the indices queried this round.
        ``b_idxs`` may hold fewer than ``num_elem_queried`` entries when the
        unlabelled pool runs out.
        """
        # Recorded objectives (detached, on CPU).
        t1_descend = []
        t2_ascend = []

        # Separate labelled and unlabelled indices.
        pool_indices = np.arange(self.n_pool)
        idx_lb_train = pool_indices[self.idx_lb]
        idx_ulb_train = pool_indices[~self.idx_lb]

        # T1 (train h) on the labelled points only.
        trainloader_labelled = self._labelled_loader(idx_lb_train)
        for _epoch in range(self.total_epoch_h):
            for label_x, label_y in trainloader_labelled:
                self.opti_h.zero_grad()
                lb_out = self.h(label_x)
                h_descent = torch.mean(self.cost_func(lb_out, label_y))
                t1_descend.append(h_descent.detach().cpu())
                h_descent.backward()
                self.opti_h.step()

        b_idxs = []  # batch of queried points
        if only_train:
            return t1_descend, t2_ascend, b_idxs

        # Temporary labelled mask, used only to retrain phi while the oracle
        # has not been called. h is not retrained during this time, and
        # self.idx_lb is only replaced once the whole batch has been chosen.
        idxs_temp = self.idx_lb.copy()

        # The full pool never changes inside the query loop, so its loader is
        # built once. It yields the whole pool as a single batch.
        trainset_total = myData(self.X_train, self.y_train)
        trainloader_total = DataLoader(trainset_total, shuffle=True,
                                       batch_size=len(trainset_total))

        for _query in range(self.num_elem_queried):
            if len(idx_ulb_train) == 0:
                # Nothing left to query: stop instead of failing in argmin on
                # an empty score tensor.
                break

            # T2 (train phi)
            trainloader_labelled = self._labelled_loader(idx_lb_train)
            for _epoch in range(self.total_epoch_phi):
                iterator_total_phi = itertools.cycle(trainloader_total)
                iterator_labelled_phi = itertools.cycle(trainloader_labelled)
                for _step in range(len(trainloader_labelled)):
                    label_x, _ = next(iterator_labelled_phi)
                    total_x, _ = next(iterator_total_phi)
                    self.opti_phi.zero_grad()
                    # Gap between the mean phi score on the whole pool and on
                    # the labelled batch; opti_phi steps on this quantity.
                    phi_ascent = (torch.mean(self.phi(total_x))
                                  - torch.mean(self.phi(label_x)))
                    t2_ascend.append(phi_ascent.detach().cpu())
                    phi_ascent.backward()
                    self.opti_phi.step()

            # T3 (query one element) and bookkeeping.
            b_queried = self.query(reduced, eta, idx_ulb_train)
            idxs_temp[b_queried] = True  # add it to the temporary labelled mask
            idx_ulb_train = pool_indices[~idxs_temp]  # update unlabelled indices
            idx_lb_train = pool_indices[idxs_temp]  # update labelled indices
            b_idxs.append(b_queried)  # add the chosen point to the batch

        # End of the query process: update the true labelled mask.
        self.idx_lb = idxs_temp
        return t1_descend, t2_ascend, b_idxs

    def query(self, reduced, eta, idx_ulb_train):
        """Compute T3: pick one unlabelled point according to the query criteria.

        reduced: same as for ``train``.
        eta: same as for ``train``.
        idx_ulb_train: indices of unlabelled points.

        Returns the element of ``idx_ulb_train`` with the highest combined
        score (``eta * phi`` plus the second criterion when one is configured).

        Raises:
            ValueError: if ``second_query_strategy`` is neither None nor
                "loss_approximation".
        """
        strategy = self.second_query_strategy
        if strategy is None:
            second_query_criterion = None
        elif strategy == "loss_approximation":
            second_query_criterion = self.predict_loss(self.X_train[idx_ulb_train])
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(f"invalid second query strategy: {strategy!r}")

        with torch.no_grad():
            phi_scores = self.phi(self.X_train[idx_ulb_train]).view(-1)

            # Scores are negated so that argmin selects the largest score.
            if second_query_criterion is None:
                total_scores = -eta * phi_scores
            elif reduced:
                phi_scores_reduced = phi_scores / torch.std(phi_scores)
                second_query_criterion_reduced = (
                    second_query_criterion / torch.std(second_query_criterion)
                )
                total_scores = -(eta * phi_scores_reduced
                                 + second_query_criterion_reduced)
            else:
                total_scores = -(eta * phi_scores + second_query_criterion)

        # Convert to a plain int before indexing the NumPy array so this does
        # not depend on the device the scores live on.
        b = int(torch.argmin(total_scores))
        return idx_ulb_train[b]

    def predict_loss(self, X):
        """Second query criterion, acting as a loss estimator.

        X: set of unlabelled elements of the trainset.

        For every element of ``X`` the nearest labelled point is looked up
        (``Idx_NearestP``) and the value of ``Max_cost_B`` is recorded.
        Returns the values as a tensor on ``self.device``.
        """
        idxs_lb = np.arange(self.n_pool)[self.idx_lb]  # labelled data indices
        losses = []
        with torch.no_grad():
            for unlabelled_x in X:
                idx_nearest_Xk, dist = self.Idx_NearestP(unlabelled_x, idxs_lb)
                losses.append(self.Max_cost_B(idx_nearest_Xk, dist, unlabelled_x))
        return torch.Tensor(losses).to(self.device)

    def Idx_NearestP(self, Xu, idxs_lb):
        """Return the closest labelled point to an unlabelled point.

        Xu: unlabelled point.
        idxs_lb: indices of labelled points.

        Returns ``(index, distance)``: the entry of ``idxs_lb`` whose point in
        ``X_train`` minimises ``torch.norm(Xu - X_train[i])`` (first one on
        ties) and that distance as a Python float.

        Raises:
            ValueError: if ``idxs_lb`` is empty.
        """
        if len(idxs_lb) == 0:
            raise ValueError("cannot find a nearest labelled point: no labelled points")
        distances = [torch.norm(Xu - self.X_train[i]) for i in idxs_lb]
        nearest = min(distances)
        return idxs_lb[distances.index(nearest)], float(nearest)

    def Max_cost_B(self, idx_Xk, distance, i):
        """Return the "maximum loss" of the unlabelled point.

        idx_Xk: index of the labelled point nearest to the unlabelled point.
        distance: distance between them.
        i: unlabelled point.

        The prediction ``h(i)`` is compared, through ``cost_func``, with the
        nearest label shifted by ``-distance`` and by ``+distance``; the larger
        of the two costs is returned (its first element).
        """
        est_h_unl_X = self.h(i)
        true_value_labelled_X = self.y_train[idx_Xk]
        bound_min = true_value_labelled_X - distance
        bound_max = true_value_labelled_X + distance
        cost_at_min = self.cost_func(est_h_unl_X, bound_min)
        cost_at_max = self.cost_func(est_h_unl_X, bound_max)
        # Built-in max needs an unambiguous comparison, i.e. single-element
        # cost tensors — NOTE(review): confirm h has a single output.
        return max(cost_at_min, cost_at_max)[0]