MatthiasPi's picture
commit WAR
ffd9d26
raw
history blame
9.33 kB
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import itertools
from WAR.Experiment_functions import display_phi
from WAR.dataset_handler import myData
class WAR:
    """Active-learning loop for regression driven by a distribution-matching criterion.

    Each call to ``train`` fits the regressor ``h`` on the labelled subset and, unless
    ``only_train`` is set, alternates ``num_elem_queried`` times between fitting the
    network ``phi`` and querying one unlabelled point of the trainset.
    """

    def __init__(self, X_train, y_train, X_test, y_test, idx_lb, total_epoch_h,
                 total_epoch_phi, batch_size_train, num_elem_queried,
                 phi, h, opti_phi, opti_h, second_query_strategy=None, cost="MSE"):
        """
        X_train: trainset tensor (moved to the selected device).
        y_train: labels of the trainset (moved to the selected device).
        X_test: testset tensor (moved to the selected device).
        y_test: labels of the testset (moved to the selected device).
        idx_lb: numpy boolean mask over the trainset; True marks points considered labelled.
        total_epoch_h: number of epochs to train h each round.
        total_epoch_phi: number of epochs to train phi before each single query.
        batch_size_train: size of the batch in the training process.
        num_elem_queried: number of elements queried each round.
        phi: phi neural network.
        h: h neural network.
        opti_phi: phi optimizer.
        opti_h: h optimizer.
        second_query_strategy: second strategy to assist the distribution-matching
            criterion: None or "loss_approximation".
        cost: cost function used to train both networks, "MSE" or "MAE".

        The device is chosen automatically: CUDA when available, otherwise CPU.
        Raises ValueError if ``cost`` is not "MSE" or "MAE".
        """
        if cost not in ("MSE", "MAE"):
            raise ValueError("invalid cost function")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.X_train = X_train.to(self.device)
        self.y_train = y_train.to(self.device)
        self.X_test = X_test.to(self.device)
        self.y_test = y_test.to(self.device)
        self.idx_lb = idx_lb
        self.n_pool = len(y_train)
        self.total_epoch_h = total_epoch_h
        self.total_epoch_phi = total_epoch_phi
        self.batch_size_train = batch_size_train
        self.num_elem_queried = num_elem_queried
        self.phi = phi.to(self.device)
        self.h = h.to(self.device)
        self.opti_phi = opti_phi
        self.opti_h = opti_h
        self.cost = cost
        self.second_query_strategy = second_query_strategy

    def cost_func(self, predicted, true):
        """Element-wise cost used to train both phi and h.

        Returns (predicted - true) ** 2 for "MSE" and |predicted - true| for "MAE".
        Raises ValueError (an Exception subclass) for any other ``self.cost``.
        """
        if self.cost == "MSE":
            return (predicted - true) ** 2
        if self.cost == "MAE":
            return abs(predicted - true)
        raise ValueError("invalid cost function")

    def train(self, only_train=False, reduced=True, eta=3):
        """Run one active-learning round.

        only_train: set when there is no more unlabelled data in the trainset; only h
            is trained, phi is neither trained nor used to query.
        reduced: divide each query criterion by its standard deviation so criteria of
            different amplitude get the same weight. Irrelevant when
            ``self.second_query_strategy`` is None.
        eta: factor rebalancing the criteria; > 1 gives the distribution-matching
            criterion more weight. Irrelevant when there is only one criterion.

        Returns (t1_descend, t2_ascend, b_idxs): per-batch h losses, per-batch phi
        objective values, and the trainset indices queried this round. ``self.idx_lb``
        is updated to include the queried points. Querying stops early if no
        unlabelled point is left.
        """
        idx_lb_train = np.arange(self.n_pool)[self.idx_lb]
        idx_ulb_train = np.arange(self.n_pool)[~self.idx_lb]

        t1_descend = self._fit_h(idx_lb_train)
        t2_ascend = []
        b_idxs = []  # batch of queried points
        if only_train:
            return t1_descend, t2_ascend, b_idxs

        # Temporary labelled mask: used only to retrain phi until the oracle is called.
        # h is not retrained during this time, and predict_loss keeps using self.idx_lb.
        idxs_temp = self.idx_lb.copy()
        # The whole trainset is one batch; this loader does not depend on the loop.
        trainset_total = myData(self.X_train, self.y_train)
        trainloader_total = DataLoader(trainset_total, shuffle=True,
                                       batch_size=len(trainset_total))
        for _ in range(self.num_elem_queried):
            if len(idx_ulb_train) == 0:
                break  # pool exhausted: argmin over no candidates would fail
            t2_ascend.extend(self._fit_phi(trainloader_total, idx_lb_train))
            b_queried = self.query(reduced, eta, idx_ulb_train)  # query one element
            idxs_temp[b_queried] = True
            idx_ulb_train = np.arange(self.n_pool)[~idxs_temp]
            idx_lb_train = np.arange(self.n_pool)[idxs_temp]
            b_idxs.append(b_queried)
        # End of the query process: commit the queried points as labelled.
        self.idx_lb = idxs_temp
        return t1_descend, t2_ascend, b_idxs

    def _fit_h(self, idx_lb_train):
        """Train h on the labelled points for ``total_epoch_h`` epochs; return per-batch losses."""
        losses = []
        trainset_labelled = myData(self.X_train[idx_lb_train], self.y_train[idx_lb_train])
        trainloader_labelled = DataLoader(trainset_labelled, shuffle=True,
                                          batch_size=self.batch_size_train)
        for _epoch in range(self.total_epoch_h):
            for label_x, label_y in trainloader_labelled:
                self.opti_h.zero_grad()
                h_descent = torch.mean(self.cost_func(self.h(label_x), label_y))
                losses.append(h_descent.detach().cpu())
                h_descent.backward()
                self.opti_h.step()
        return losses

    def _fit_phi(self, trainloader_total, idx_lb_train):
        """Train phi for ``total_epoch_phi`` epochs; return the per-step objective values.

        Each step back-propagates mean(phi(whole trainset)) - mean(phi(labelled batch))
        and calls ``opti_phi.step()``. One epoch is one pass over the labelled loader.
        """
        objectives = []
        trainset_labelled = myData(self.X_train[idx_lb_train], self.y_train[idx_lb_train])
        trainloader_labelled = DataLoader(trainset_labelled, shuffle=True,
                                          batch_size=self.batch_size_train)
        for _epoch in range(self.total_epoch_phi):
            iterator_total_phi = itertools.cycle(trainloader_total)
            iterator_labelled_phi = itertools.cycle(trainloader_labelled)
            for _step in range(len(trainloader_labelled)):
                # Draw labelled first, then total (keeps the original RNG consumption order).
                label_x, _label_y = next(iterator_labelled_phi)
                total_x, _total_y = next(iterator_total_phi)
                self.opti_phi.zero_grad()
                phi_ascent = torch.mean(self.phi(total_x)) - torch.mean(self.phi(label_x))
                objectives.append(phi_ascent.detach().cpu())
                phi_ascent.backward()
                self.opti_phi.step()
        return objectives

    @staticmethod
    def _standardize(scores):
        """Divide ``scores`` by their standard deviation.

        The scores are returned unchanged when the deviation is zero or not finite
        (e.g. a single score), instead of producing NaN/inf.
        """
        std = torch.std(scores)
        if torch.isfinite(std) and std > 0:
            return scores / std
        return scores

    def query(self, reduced, eta, idx_ulb_train):
        """Query one unlabelled point according to the chosen query criteria (T3).

        reduced: same as for function "train".
        eta: same as for function "train".
        idx_ulb_train: indices of unlabelled points (candidates).

        Returns the element of ``idx_ulb_train`` with the highest combined score
        eta * phi + second criterion.
        Raises ValueError for an unknown ``self.second_query_strategy``.
        """
        if self.second_query_strategy is None:
            second_query_criterion = None
        elif self.second_query_strategy == "loss_approximation":
            second_query_criterion = self.predict_loss(self.X_train[idx_ulb_train])
        else:
            raise ValueError(
                f"unknown second_query_strategy: {self.second_query_strategy!r}")
        with torch.no_grad():
            phi_scores = self.phi(self.X_train[idx_ulb_train]).view(-1)
        if second_query_criterion is None:
            total_scores = -eta * phi_scores
        elif reduced:
            total_scores = -(eta * self._standardize(phi_scores)
                             + self._standardize(second_query_criterion))
        else:
            total_scores = -(eta * phi_scores + second_query_criterion)
        return idx_ulb_train[int(torch.argmin(total_scores))]

    def predict_loss(self, X):
        """Second query criterion acting as a loss estimator for each point of X.

        X: unlabelled elements of the trainset (iterated row by row).

        Each point is compared with its nearest point in ``self.idx_lb``. The
        temporary query set is deliberately excluded because its labels are not yet
        known. Returns a 1-D tensor of default dtype on ``self.device``.
        """
        idxs_lb = np.arange(self.n_pool)[self.idx_lb]
        with torch.no_grad():
            losses = []
            for x in X:
                idx_nearest_Xk, dist = self.Idx_NearestP(x, idxs_lb)
                losses.append(self.Max_cost_B(idx_nearest_Xk, dist, x))
        if not losses:
            return torch.empty(0, device=self.device)
        return torch.stack(losses).to(device=self.device, dtype=torch.get_default_dtype())

    def Idx_NearestP(self, Xu, idxs_lb):
        """Return (index, distance) of the labelled point closest to ``Xu``.

        Xu: unlabelled point (one row of X_train).
        idxs_lb: indices of labelled points; on ties the first one in this order wins.

        The distance is the Euclidean norm over all elements of the difference.
        Raises ValueError if ``idxs_lb`` is empty.
        """
        if len(idxs_lb) == 0:
            raise ValueError("no labelled point to compare against")
        diffs = (Xu.unsqueeze(0) - self.X_train[idxs_lb]).reshape(len(idxs_lb), -1)
        distances = torch.norm(diffs, dim=1)
        nearest = int(torch.argmin(distances))  # argmin returns the first minimal index
        return idxs_lb[nearest], float(distances[nearest])

    def Max_cost_B(self, idx_Xk, distance, i):
        """Return the "maximum loss" of the unlabelled point ``i``.

        idx_Xk: index of the labelled point nearest to ``i``.
        distance: distance between them.
        i: unlabelled point.

        This is the larger of cost(h(i), y_k - distance) and cost(h(i), y_k + distance),
        where y_k = y_train[idx_Xk]; the first element of the result is returned.
        """
        est_h_unl_X = self.h(i)
        true_value_labelled_X = self.y_train[idx_Xk]
        cost_low = self.cost_func(est_h_unl_X, true_value_labelled_X - distance)
        cost_high = self.cost_func(est_h_unl_X, true_value_labelled_X + distance)
        return torch.maximum(cost_low, cost_high)[0]