File size: 5,059 Bytes
e79b770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Test script: run the AR S1 (text-to-semantic) model on a dumped test set.
import argparse
import time
from pathlib import Path

import numpy as np
import torch
from AR.data.dataset import Text2SemanticDataset
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from torch.utils.data import DataLoader


def parse_args():
    """Collect the command-line options for this test script.

    Returns:
        argparse.Namespace with config_file, test_semantic_path,
        test_phoneme_path, ckpt_path and output_dir.
    """
    arg_parser = argparse.ArgumentParser(
        description="Run SoundStorm AR S1 model for test set.")

    # Training/inference configuration file.
    arg_parser.add_argument(
        '--config_file', type=str, default='conf/default.yaml',
        help='path of config file')

    # Paths of the dumped test-set features.
    arg_parser.add_argument(
        '--test_semantic_path', type=str,
        default='dump/test/semantic_token.tsv')
    arg_parser.add_argument(
        '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')

    # Model checkpoint to load.
    arg_parser.add_argument(
        '--ckpt_path', type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='Checkpoint file of SoundStorm AR S1 model.')

    # Where outputs (prompt.npy, semantic_token.tsv) are written.
    arg_parser.add_argument("--output_dir", type=str, help="output dir.")

    return arg_parser.parse_args()


def main():
    """Run one T2S inference over the dumped test set and save the outputs.

    Loads the dataset and checkpoint from the CLI args, runs the model on
    the FIRST batch only (the loop breaks after i == 0), and writes the
    prompt (prompt.npy) and predicted semantic tokens (semantic_token.tsv)
    into --output_dir. Requires a CUDA device.
    """
    args = parse_args()

    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # batch_size must stay 1 so that item_names[i] lines up with the
    # dataloader iteration order (see the loop below).
    batch_size = 1
    hz = 50  # semantic tokens per second; caps the generated length below
    max_sec = config['data']['max_sec']

    # get dataset
    test_dataset = Text2SemanticDataset(
        phoneme_path=args.test_phoneme_path,
        semantic_path=args.test_semantic_path,
        # max_sec should match the value used at training time, otherwise
        # quality may degrade (repeated or dropped characters). But a small
        # value here would filter out long samples, so use a large one and
        # truncate at inference time instead.
        max_sec=100,
        max_sample=8,
        pad_val=config['data']['pad_val'])
    # get model
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()

    # Fetch batch_size items per step; the dataset supplies its own
    # collate_fn for padding/stacking.
    dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset.collate)

    item_names = test_dataset.__get_item_names__()

    # Read batch by batch; with bs=1 and shuffle=False the index i maps
    # directly onto __get_item_names__().
    semantic_data = [['item_name', 'semantic_audio']]
    for i, batch in enumerate(dataloader):
        # requires bs = 1 (see note above)
        utt_id = item_names[i]
        if i == 0:
            print("utt_id:", utt_id)
            # with bs > 1 the batch would be zero-padded;
            # kept consistent with validation_step()
            semantic_len = batch['semantic_ids'].size(1)
            # use up to the first 150 tokens of batch['semantic_ids'] as the
            # prompt; across repeated runs the first prompt_len generated
            # tokens are identical to the prompt itself
            prompt_len = min(int(semantic_len * 0.5), 150)
            # what to feed as prompt for pure text input? => see t2s.py
            prompt = batch['semantic_ids'][:, :prompt_len]
            # # zero prompt => still yields semantic tokens with the correct
            # text content, but the timbre is scrambled — evidence that
            # semantic tokens also carry speaker/timbre information
            # prompt = torch.ones(
            #     batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
            # print("prompt:", prompt)
            # print("prompt.shape:", prompt.shape)
            np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())

            st = time.time()
            with torch.no_grad():
                # calculate acc for test
                loss, acc = t2s_model.model.forward(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    batch['semantic_ids'].cuda(),
                    batch['semantic_ids_len'].cuda())
                print("top_3_acc of this batch:", acc)
                pred_semantic = t2s_model.model.infer(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    prompt.cuda(),
                    top_k=config['inference']['top_k'],
                    # hz * max_sec matches the train dataloader limit;
                    # a generated length of 1002 suggests some padding
                    early_stop_num=hz * max_sec)
                # bs = 1, so take the single sequence
                pred_semantic = pred_semantic[0]
            print(f'{time.time() - st} sec used in T2S')
            semantic_token = pred_semantic.detach().cpu().numpy().tolist()
            semantic_token_str = ' '.join(str(x) for x in semantic_token)
            semantic_data.append([utt_id, semantic_token_str])
        else:
            # intentionally process only the first batch — this is a smoke test
            break
    delimiter = '\t'
    filename = output_dir / "semantic_token.tsv"
    with open(filename, 'w', encoding='utf-8') as writer:
        for row in semantic_data:
            line = delimiter.join(row)
            writer.write(line + '\n')


# Script entry point: run the single-batch inference test.
if __name__ == "__main__":
    main()