makitanikaze committed on
Commit
dcb67a0
1 Parent(s): 1b4a1f4

Delete pretrain_model.py

Browse files
Files changed (1) hide show
  1. pretrain_model.py +0 -133
pretrain_model.py DELETED
@@ -1,133 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
-
6
- from modeling_p5 import P5
7
-
8
class P5Pretraining(P5):
    """P5 model with per-batch training, validation and generation steps.

    Each step pulls tensors out of a ``batch`` dict, moves them to the
    device the model parameters live on, and runs the model.  The training
    and validation steps additionally return a dict of loss statistics,
    broken down by the task names listed in ``config.losses``.
    """

    def __init__(self, config):
        """Initialise the base model and parse the task list.

        Args:
            config: model configuration; ``config.losses`` must be a
                comma-separated string of task names.
        """
        super().__init__(config)

        # Task names used as keys for the per-task loss statistics.
        self.losses = self.config.losses.split(',')

    def _masked_sequence_loss(self, output, lm_labels):
        """Return one loss value per example, averaged over unmasked labels.

        Args:
            output: model output; must contain a ``'loss'`` entry that can
                be viewed as ``(B, L)``.
                NOTE(review): this presumes the forward pass returns an
                unreduced per-token loss -- confirm against ``P5.forward``.
            lm_labels: label tensor of shape ``(B, L)``; positions equal to
                ``-100`` are excluded from the average.

        Returns:
            Tensor of shape ``(B,)``.
        """
        assert 'loss' in output

        label_mask = (lm_labels != -100).float()
        batch_size, seq_len = lm_labels.size()

        token_loss = output['loss'].view(batch_size, seq_len) * label_mask

        # clamp(min=1) avoids division by zero for rows with no valid label.
        return token_loss.sum(dim=1) / label_mask.sum(dim=1).clamp(min=1)

    def _build_loss_results(self, loss, loss_weights, tasks):
        """Assemble the result dict shared by ``train_step``/``valid_step``.

        Args:
            loss: per-example loss tensor.
            loss_weights: per-example weights multiplied into ``loss``
                before taking the mean for the ``'loss'`` entry.
            tasks: iterable of task names, one per example.

        Returns:
            Dict with ``'loss'`` (weighted mean), ``'total_loss'``
            (detached sum), ``'total_loss_count'`` and, for every task that
            occurs in the batch, ``'<task>_loss'`` / ``'<task>_loss_count'``.

        Raises:
            KeyError: if ``tasks`` contains a name not in ``self.losses``.
        """
        detached = loss.detach()

        results = {}
        results['loss'] = (loss * loss_weights).mean()
        results['total_loss'] = detached.sum()
        results['total_loss_count'] = len(loss)

        task_counts = {task: 0 for task in self.losses}
        task_loss = {task: 0 for task in self.losses}

        for example_loss, task in zip(detached, tasks):
            task_loss[task] += example_loss
            task_counts[task] += 1

        # Only report tasks that actually appeared in this batch.
        for task in self.losses:
            if task_counts[task] > 0:
                results[f'{task}_loss'] = task_loss[task]
                results[f'{task}_loss_count'] = task_counts[task]

        return results

    def train_step(self, batch):
        """Run a forward pass on ``batch`` and return loss statistics.

        Args:
            batch: dict providing ``'input_ids'``, ``'whole_word_ids'``,
                ``'target_ids'``, ``'loss_weights'`` (tensors) and
                ``'task'`` (one task name per example).

        Returns:
            Dict as described in ``_build_loss_results``; ``'loss'`` is the
            value to back-propagate.
        """
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)
        whole_word_ids = batch['whole_word_ids'].to(device)

        lm_labels = batch["target_ids"].to(device)

        loss_weights = batch["loss_weights"].to(device)

        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True
        )

        loss = self._masked_sequence_loss(output, lm_labels)

        return self._build_loss_results(loss, loss_weights, batch['task'])

    @torch.no_grad()
    def valid_step(self, batch):
        """Evaluate ``batch`` without gradients and return loss statistics.

        Switches the model to eval mode (and does not switch it back).

        Args:
            batch: dict providing ``'input_ids'``, ``'target_ids'``,
                ``'loss_weights'`` (tensors) and ``'task'``.

        Returns:
            Dict as described in ``_build_loss_results``.  When ``'rating'``
            is one of the configured tasks, the dict also has
            ``'rating_pred'``: the decoded generations for the batch.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)

        lm_labels = batch["target_ids"].to(device)

        loss_weights = batch["loss_weights"].to(device)

        # NOTE(review): unlike train_step, whole_word_ids is not passed
        # here -- confirm whether that is intentional.
        output = self(
            input_ids=input_ids,
            labels=lm_labels,
            return_dict=True
        )

        loss = self._masked_sequence_loss(output, lm_labels)

        results = self._build_loss_results(loss, loss_weights, batch['task'])

        if 'rating' in self.losses:
            # self.tokenizer is not assigned in this class; it has to be
            # provided by the base class or attached by the caller.
            generated = self.generate(
                input_ids=input_ids
            )

            results['rating_pred'] = self.tokenizer.batch_decode(
                generated, skip_special_tokens=True)

        return results

    @torch.no_grad()
    def generate_step(self, batch):
        """Generate and decode text for ``batch['input_ids']``.

        Switches the model to eval mode (and does not switch it back).

        Args:
            batch: dict providing an ``'input_ids'`` tensor.

        Returns:
            The result of ``self.tokenizer.batch_decode`` on the generated
            ids, with special tokens skipped.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)

        generated = self.generate(
            input_ids=input_ids,
        )

        generated_sents = self.tokenizer.batch_decode(
            generated, skip_special_tokens=True)

        return generated_sents