import numpy as np
import warnings
import torch
import torch.optim as optim

from WAR.Models import NN_phi, NN_h_RELU
from WAR.Experiment_functions import *


def full_training(strategy, num_round, show_losses, show_chosen_each_round,
                  reset_phi, reset_h, weight_decay, lr_h=None, lr_phi=None,
                  reduced=False, eta=1):
    """
    strategy: an object of class WAR
    num_round: total number of query rounds
    show_losses: display graphs showing the losses of h and phi each round
    show_chosen_each_round: display a graph showing the data queried each round
    reset_phi: if True, the phi neural network is reset after each round.
        Can avoid overfitting, but increases the number of epochs required to train the model.
    reset_h: if True, the h neural network is reset after each round.
        Can avoid overfitting, but increases the number of epochs required to train the model.
    weight_decay: weight decay (L2 penalty) passed to the optimizer of h
    lr_h: learning rate of h
    lr_phi: learning rate of phi
    reduced: if True, divide each query criterion by its standard deviation. When the
        criteria do not have the same amplitude, this gives them the same weight in the
        querying process. Irrelevant if there is only one query criterion.
    eta: factor used to rebalance the criteria. If >1, the distribution-matching
        criterion gets more weight than the other(s). Irrelevant if there is only one
        query criterion.
    """

    t1_descend_list = []
    t2_ascend_list = []
    acc = []             # MAE
    acc_percentage = []  # MAPE
    acc_rmse = []        # RMSE

    only_train = False

    for rd in range(1, num_round + 1):
        print('\n================Round {:d}==============='.format(rd))

        # if there is not enough unlabelled data left to query a full batch, query the remaining data
        if len(np.arange(strategy.n_pool)[~strategy.idx_lb]) <= strategy.num_elem_queried:
            only_train = True

        # reset the neural networks
        if reset_phi:
            strategy.phi = NN_phi(dim_input=strategy.X_train.shape[1])
            strategy.opti_phi = optim.Adam(strategy.phi.parameters(), lr=lr_phi, maximize=True)

        if reset_h:
            strategy.h = NN_h_RELU(dim_input=strategy.X_train.shape[1])
            strategy.opti_h = optim.Adam(strategy.h.parameters(), lr=lr_h, weight_decay=weight_decay)

        t1, t2, b_idxs = strategy.train(only_train, reduced, eta)
        t1_descend_list.append(t1)
        t2_ascend_list.append(t2)

        # "simulation" of the oracle labelling the queried samples
        if only_train:
            strategy.idx_lb[:] = True
        else:
            strategy.idx_lb[b_idxs] = True

        with torch.no_grad():
            if show_losses:
                display_loss_t1(t1, rd)
                display_loss_t2(t2, rd)

            if show_chosen_each_round:
                if strategy.X_train.shape[1] == 1:
                    # display_phi(strategy.X_train, strategy.phi, rd)
                    display_chosen_labelled_datas(strategy.X_train.cpu(), strategy.idx_lb, strategy.y_train.cpu(), b_idxs, rd)
                    # display_prediction(strategy.X_test, strategy.h, strategy.y_test, rd)
                else:
                    display_chosen_labelled_datas_PCA(strategy.X_train.cpu(), strategy.idx_lb, strategy.y_train.cpu(), b_idxs, rd)

            acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())
            acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
            acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())

    print('\n================Final training===============')

    t1, t2, _ = strategy.train(only_train, reduced, eta)
    t1_descend_list.append(t1)
    t2_ascend_list.append(t2)

    with torch.no_grad():
        # display_loss_t1(t1, rd)
        # display_prediction(strategy.X_test, strategy.h, strategy.y_test, "final")
        acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
        acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())
        acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())

    return acc, acc_percentage, acc_rmse, t1_descend_list, t2_ascend_list


def check_num_round(num_round, len_dataset, nb_initial_labelled_datas, num_elem_queried):
    # maximum number of rounds before the unlabelled pool is exhausted
    max_round = int(np.ceil((len_dataset - nb_initial_labelled_datas) / num_elem_queried))
    if num_round > max_round:
        warnings.warn(f"when querying {num_elem_queried} data per round, num_round={num_round} exceeds"
                      f" the maximum number of rounds (total data queried greater than the number of"
                      f" initially unlabelled data).\nnum_round set to {max_round}")
        num_round = max_round
    return num_round
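# Illustrative usage sketch (not part of the original module): the WAR strategy
# class is defined elsewhere in this repository and its constructor is not shown
# here, so it is left as a placeholder; the numeric values below are hypothetical.
#
# num_round = check_num_round(num_round=20,
#                             len_dataset=1000,
#                             nb_initial_labelled_datas=10,
#                             num_elem_queried=5)
#
# strategy = ...  # a WAR object wrapping X_train/y_train, X_test/y_test, phi, h, their optimizers, etc.
#
# acc, acc_percentage, acc_rmse, t1_list, t2_list = full_training(
#     strategy, num_round,
#     show_losses=False, show_chosen_each_round=False,
#     reset_phi=True, reset_h=True,
#     weight_decay=1e-4, lr_h=1e-3, lr_phi=1e-3,
#     reduced=False, eta=1,
# )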