import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from modeling_p5 import P5


class P5Pretraining(P5):
    def __init__(self, config):
        super().__init__(config)
        # Comma-separated task names (e.g. 'rating,sequential'), used to
        # bucket per-task losses in the returned results.
        self.losses = self.config.losses.split(',')
    def train_step(self, batch):
        device = next(self.parameters()).device

        input_ids = batch['input_ids'].to(device)
        whole_word_ids = batch['whole_word_ids'].to(device)
        lm_labels = batch['target_ids'].to(device)
        loss_weights = batch['loss_weights'].to(device)

        output = self(
            input_ids=input_ids,
            whole_word_ids=whole_word_ids,
            labels=lm_labels,
            return_dict=True
        )
        assert 'loss' in output

        # Reduce the token-level LM loss to one loss per example, averaging
        # over target positions that are not ignored (label -100).
        lm_mask = (lm_labels != -100).float()
        B, L = lm_labels.size()

        loss = output['loss'].view(B, L) * lm_mask
        loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)
        results = {}
        results['loss'] = (loss * loss_weights).mean()
        results['total_loss'] = loss.detach().sum()
        results['total_loss_count'] = len(loss)

        # Accumulate loss sums and example counts per task for logging.
        task_counts = {task: 0 for task in self.losses}
        task_loss = {task: 0 for task in self.losses}
        for _loss, task in zip(loss.detach(), batch['task']):
            task_loss[task] += _loss
            task_counts[task] += 1

        for task in self.losses:
            if task_counts[task] > 0:
                results[f'{task}_loss'] = task_loss[task]
                results[f'{task}_loss_count'] = task_counts[task]

        return results
    @torch.no_grad()
    def valid_step(self, batch):
        self.eval()
        device = next(self.parameters()).device

        input_ids = batch['input_ids'].to(device)
        lm_labels = batch['target_ids'].to(device)
        loss_weights = batch['loss_weights'].to(device)

        output = self(
            input_ids=input_ids,
            labels=lm_labels,
            return_dict=True
        )
        assert 'loss' in output

        # Same per-example loss reduction as in train_step.
        lm_mask = (lm_labels != -100).float()
        B, L = lm_labels.size()

        loss = output['loss'].view(B, L) * lm_mask
        loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)
        results = {}
        results['loss'] = (loss * loss_weights).mean()
        results['total_loss'] = loss.detach().sum()
        results['total_loss_count'] = len(loss)

        task_counts = {task: 0 for task in self.losses}
        task_loss = {task: 0 for task in self.losses}
        for _loss, task in zip(loss.detach(), batch['task']):
            task_loss[task] += _loss
            task_counts[task] += 1

        for task in self.losses:
            if task_counts[task] > 0:
                results[f'{task}_loss'] = task_loss[task]
                results[f'{task}_loss_count'] = task_counts[task]
        # For the rating task, also decode generations so the caller can
        # compute rating metrics from the predicted strings.
        if 'rating' in self.losses:
            output = self.generate(input_ids=input_ids)
            generated_score = self.tokenizer.batch_decode(output, skip_special_tokens=True)
            results['rating_pred'] = generated_score

        return results
    @torch.no_grad()
    def generate_step(self, batch):
        self.eval()
        device = next(self.parameters()).device
        input_ids = batch['input_ids'].to(device)

        output = self.generate(input_ids=input_ids)
        generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return generated_sents
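
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# Shows how a training loop might drive train_step. `config` (which must
# carry a comma-separated `losses` string), `train_loader`, and the AdamW
# optimizer choice are assumptions, not the project's actual training script.
# ---------------------------------------------------------------------------
#
# config.losses = 'rating,sequential'
# model = P5Pretraining(config)
# model.to('cuda')
# model.train()
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#
# for batch in train_loader:
#     results = model.train_step(batch)
#     results['loss'].backward()
#     optimizer.step()
#     optimizer.zero_grad()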