File size: 4,593 Bytes
e79b770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os,sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from typing import Dict

import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam


class Text2SemanticLightningModule(LightningModule):
    """Lightning wrapper that trains a ``Text2SemanticDecoder``.

    When training, optimization is manual (``automatic_optimization = False``):
    gradients are accumulated over ``grad_accum_steps`` batches before each
    ScaledAdam step, and the learning rate follows ``WarmupCosineLRSchedule``.
    """

    # Number of micro-batches whose gradients are accumulated before each
    # optimizer/scheduler step in ``training_step``.
    grad_accum_steps = 4

    def __init__(self, config, output_dir, is_train=True):
        """Build the decoder and, when training, prepare manual optimization.

        Args:
            config: Dict-like config. It is passed to ``Text2SemanticDecoder``;
                ``config.get("pretrained_s1")`` optionally names a checkpoint
                to initialize from, and ``config['optimizer']`` is read by
                ``configure_optimizers``.
            output_dir: Directory object supporting the ``/`` operator (e.g. a
                ``pathlib.Path``); an ``eval`` subdirectory is created in it
                when ``is_train`` is true.
            is_train: When true, load ``pretrained_s1`` (if set), switch to
                manual optimization, save hyperparameters and create the eval
                directory.
        """
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # The checkpoint holds this module's state dict under "weight".
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(pretrained_s1, map_location="cpu")["weight"]
            # Print the load result as a confirmation in the training log.
            print(self.load_state_dict(state_dict))
        if is_train:
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / 'eval'
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        """Run one micro-batch: backward pass, periodic optimizer step, logging.

        Args:
            batch: Mapping with 'phoneme_ids', 'phoneme_ids_len',
                'semantic_ids', 'semantic_ids_len' and 'bert_feature'.
            batch_idx: Index of the batch within the current epoch.
        """
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch['phoneme_ids'], batch['phoneme_ids_len'],
            batch['semantic_ids'], batch['semantic_ids_len'],
            batch['bert_feature'])
        self.manual_backward(loss)
        # Step once every `grad_accum_steps` micro-batches. Counting with
        # (batch_idx + 1) gives every accumulation window exactly
        # `grad_accum_steps` batches, including the first window of an epoch.
        if (batch_idx + 1) % self.grad_accum_steps == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)

    def validation_step(self, batch: Dict, batch_idx: int):
        """Validation is currently disabled; this step is a no-op.

        The commented-out code below computed validation loss/accuracy and
        saved inferred semantic tokens into ``self.eval_dir``.
        """
        return
        # # get loss
        # loss, acc = self.model.forward(
        #     batch['phoneme_ids'], batch['phoneme_ids_len'],
        #     batch['semantic_ids'], batch['semantic_ids_len'],
        #     batch['bert_feature']
        # )
        #
        # self.log(
        #     "val_total_loss",
        #     loss,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        # self.log(
        #     f"val_top_{self.top_k}_acc",
        #     acc,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        #
        # # get infer output
        # semantic_len = batch['semantic_ids'].size(1)
        # prompt_len = min(int(semantic_len * 0.5), 150)
        # prompt = batch['semantic_ids'][:, :prompt_len]
        # pred_semantic = self.model.infer(batch['phoneme_ids'],
        #                                  batch['phoneme_ids_len'], prompt,
        #                                  batch['bert_feature']
        #                                  )
        # save_name = f'semantic_toks_{batch_idx}.pt'
        # save_path = os.path.join(self.eval_dir, save_name)
        # torch.save(pred_semantic.detach().cpu(), save_path)

    def configure_optimizers(self):
        """Create the ScaledAdam optimizer and its warmup-cosine LR schedule.

        Reads 'lr_init', 'lr', 'lr_end', 'warmup_steps' and 'decay_steps'
        from ``self.config['optimizer']``.

        Returns:
            Lightning optimizer dict with keys "optimizer" and "lr_scheduler".
        """
        # ScaledAdam takes one list of parameter names per parameter group;
        # all model parameters form a single group here.
        parameters_names = [
            [name for name, _ in self.model.named_parameters()]
        ]
        lm_opt = ScaledAdam(
            self.model.parameters(),
            # Placeholder lr: WarmupCosineLRSchedule below is given its own
            # init/peak/end learning rates.
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )

        optim_cfg = self.config['optimizer']
        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler":
                WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=optim_cfg['lr_init'],
                    peak_lr=optim_cfg['lr'],
                    end_lr=optim_cfg['lr_end'],
                    warmup_steps=optim_cfg['warmup_steps'],
                    total_steps=optim_cfg['decay_steps'])
            }
        }