File size: 9,333 Bytes
ffd9d26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import itertools

from WAR.Experiment_functions import display_phi
from WAR.dataset_handler import myData
  
        
class WAR:

    def __init__(self,X_train,y_train,X_test,y_test,idx_lb,total_epoch_h,total_epoch_phi,batch_size_train,num_elem_queried
                 ,phi,h,opti_phi,opti_h,second_query_strategy=None):
        
        """
        device: device on which to train the model.
        X_train: trainset.
        Y_train: labels of the trainset
        idx_lb: indices of the trainset that would be considered as labelled.
        n_pool: length of the trainset.
        total_epoch_h: number of epochs to train h.
        total_epoch_phi: number of epochs to train phi.
        batch_size_train: size of the batch in the training process.
        num_elem_queried: number of elem queried each round.
        phi: phi neural network.
        h: h neural network.
        opti_phi: phi optimizer.
        opti_h: h optimizer.
        cost: define the cost function for both neural network. "MSE" or MAE".
        second_query_strategy: second strategy to assist our distribution-matching criterion.
        """
        
        self.device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.X_train = X_train.to(self.device)
        self.y_train = y_train.to(self.device)
        self.X_test=X_test.to(self.device)
        self.y_test=y_test.to(self.device)
        self.idx_lb  = idx_lb
        self.n_pool  = len(y_train)
        self.total_epoch_h=total_epoch_h
        self.total_epoch_phi=total_epoch_phi
        self.batch_size_train=batch_size_train
        self.num_elem_queried=num_elem_queried
        self.phi=phi.to(self.device)
        self.h=h.to(self.device)
        self.opti_phi=opti_phi
        self.opti_h=opti_h
        self.cost="MSE"
        self.second_query_strategy=second_query_strategy
        
        
        
        
    #cost function used to train both phi and h
    def cost_func(self,predicted,true):
        if self.cost=="MSE":
            return (predicted-true)**2
        elif self.cost=="MAE":
            return abs(predicted-true)
        else:
            raise Exception("invalid cost function")
      
   

    
    def train(self,only_train=False,reduced=True,eta=3):# train function for one round

        """
        Run one round: fit h on the labelled points, then (unless only_train)
        alternate between fitting phi and querying one point, num_elem_queried times.

        only_train: activate when there is no more unlabelled data in the trainset. Will only train h and not train phi or query data.
        reduced: will divide each query criterion by its standard deviation. In the case where they don't have the same amplitude, this will give them the same weight in the querying process. Irrelevant parameter if there is only one query criterion (self.second_query_strategy=None).
        eta: factor used to rebalance the criteria. If >1, distribution matching criterion gets more weight than the other(s). Irrelevant parameter if there is only one query criterion.

        Returns a tuple (t1_descend, t2_ascend, b_idxs):
            t1_descend: value of the h objective at every optimisation step (detached, on CPU).
            t2_ascend: value of the phi objective at every optimisation step (detached, on CPU).
            b_idxs: indices of the points queried during this round (empty if only_train).

        Side effect: self.idx_lb is replaced by a copy in which the queried points are marked as labelled.
        """
        #recover loss
        t1_descend=[]
        t2_ascend=[]

        # separating labelled and unlabelled data respectively
        idx_lb_train = np.arange(self.n_pool)[self.idx_lb]
        idx_ulb_train = np.arange(self.n_pool)[~self.idx_lb]

        trainset_labelled=myData(self.X_train[idx_lb_train],self.y_train[idx_lb_train])
        trainloader_labelled= DataLoader(trainset_labelled,shuffle=True,batch_size=self.batch_size_train)

        # T1 (train h)
        for _ in range(self.total_epoch_h):
            for label_x,label_y in trainloader_labelled:
                self.opti_h.zero_grad()
                lb_out = self.h(label_x)
                h_descent=torch.mean(self.cost_func(lb_out,label_y))
                t1_descend.append(h_descent.detach().cpu())
                h_descent.backward()
                self.opti_h.step()

        b_idxs=[]# batch of queried points
        if only_train:
            return t1_descend,t2_ascend,b_idxs

        #T2 (train phi)

        # temporary set of labelled data indices. Used only to retrain phi during the time oracle has not been called.
        # h is not retrained during this time.
        idxs_temp=self.idx_lb.copy()

        # The full trainset never changes during the round, so its dataset and
        # loader are built once instead of once per queried element.
        # NOTE(review): assumes constructing myData has no side effects - confirm.
        trainset_total=myData(self.X_train,self.y_train)
        trainloader_total= DataLoader(trainset_total,shuffle=True,batch_size=len(trainset_total))

        for _ in range(self.num_elem_queried):

            # Nothing left to query: stop early instead of failing inside query().
            if len(idx_ulb_train)==0:
                break

            # The labelled set grows by one point per iteration, so this loader is rebuilt.
            trainset_labelled=myData(self.X_train[idx_lb_train],self.y_train[idx_lb_train])
            trainloader_labelled= DataLoader(trainset_labelled,shuffle=True,batch_size=self.batch_size_train)
            for _ in range(self.total_epoch_phi):
                iterator_total_phi=itertools.cycle(trainloader_total)
                iterator_labelled_phi=itertools.cycle(trainloader_labelled)
                # one pass over the labelled batches; the single full-size batch is replayed by cycle
                for _ in range(len(trainloader_labelled)):
                    label_x,_ = next(iterator_labelled_phi)
                    total_x,_ = next(iterator_total_phi)
                    self.opti_phi.zero_grad()
                    # objective handed to the optimizer: mean of phi over all points minus mean over labelled points
                    phi_ascent = (torch.mean(self.phi(total_x))-torch.mean(self.phi(label_x)))
                    t2_ascend.append(phi_ascent.detach().cpu())
                    phi_ascent.backward()
                    self.opti_phi.step()

            # Query process
            b_queried=self.query(reduced,eta,idx_ulb_train)# query one element
            idxs_temp[b_queried]=True #add it to the temporary set of labeled point indices
            idx_ulb_train = np.arange(self.n_pool)[~idxs_temp] #update the set of unlabeled point indices
            idx_lb_train = np.arange(self.n_pool)[idxs_temp] #update the set of labeled point indices
            b_idxs.append(b_queried)#add the chosen point in the batch
        self.idx_lb=idxs_temp#end of the query process: update the true set of labeled point indices
        return t1_descend,t2_ascend,b_idxs

    
    
    def query(self,reduced,eta,idx_ulb_train):# computing T3: query one point according to the chosen query criteria
         
                          
        """
        reduced:same as for function "train"
        eta: sme as for function "train"
        idx_ulb_train:indices of unlabeled points
        
        """

        if self.second_query_strategy=="loss_approximation":
            second_query_criterion = self.predict_loss(self.X_train[idx_ulb_train])
                          
        with torch.no_grad():
            phi_scores =  self.phi(self.X_train[idx_ulb_train]).view(-1)

            if reduced and self.second_query_strategy!=None:
                phi_scores_reduced=phi_scores/torch.std(phi_scores)
                second_query_criterion_reduced=second_query_criterion/torch.std(second_query_criterion)
                total_scores =-(eta*phi_scores_reduced+second_query_criterion_reduced )  
                          
            elif self.second_query_strategy!=None:
                total_scores =-(eta*phi_scores+second_query_criterion)
                          
            else:
                total_scores =-eta*phi_scores
                
            b=torch.argmin(total_scores)
        
        return idx_ulb_train[b]
    
                          
    def predict_loss(self,X):# Second query criterion which act as loss estimator (uncertainty and diversity-based sampling)
        
        """
        X: set of unlabeled elements of the trainset
        
        """
                          
        idxs_lb=np.arange(self.n_pool)[self.idx_lb]#get labeled data indices
        losses=[]
        with torch.no_grad():
            for i in X:
                idx_nearest_Xk,dist=self.Idx_NearestP(i,idxs_lb) 
                losses.append(self.Max_cost_B(idx_nearest_Xk,dist,i))
        
        return torch.Tensor(losses).to(self.device)
    
    def Idx_NearestP(self,Xu,idxs_lb):# Return the closest labeled point to the unlabeled point 
        
                          
        """
        Xu:unlabeled point
        idxs_lb: indices of labeled points
        
        """
                          
        distances=[]
        for i in idxs_lb:
            distances.append(torch.norm(Xu-self.X_train[i]))
            
        return idxs_lb[distances.index(min(distances))],float(min(distances))
    

    
    def Max_cost_B(self,idx_Xk,distance,i):#return the "maximum loss" of the unlabeled point
       
        """
        idx_Xk: labeled point indice nearest to the unlabeled point
        distance: distance between them
        i:unlabeled point
        
        """
                          
        est_h_unl_X=self.h(i)
        true_value_labelled_X=self.y_train[idx_Xk]
        bound_min= true_value_labelled_X-distance
        bound_max= true_value_labelled_X+distance
        return max(self.cost_func(est_h_unl_X,bound_min),self.cost_func(est_h_unl_X,bound_max))[0]