File size: 2,441 Bytes
bcb1848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
"""
A
"""
# ๋ฏธ๋ฆฌ ์ค์ ๋ ์ธ์๋ค
from argparse import ArgumentParser
# ์ฌ์ฉ์ ์ ์ ๋ณ์๋ค
ROOT_DIR = "" # ํ๋ก์ ํธ ๋ฃจํธ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
BERT_PRETRAINED_DIR = "klue/roberta-large" # BERT ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
DATA_PREFIX = "data" # ๋ฐ์ดํฐ ํ์ผ๋ค์ ์์ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
CHECKPOINT_DIR = 'model' # ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ ์ ์ฅ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
LOG_FATH = 'logs' # ํ๋ จ ๋ก๊ทธ ์ ์ฅ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
def get_train_args():
"""
ํ๋ จ ์ธ์ ์ค์
"""
parser = ArgumentParser(description='I_S', allow_abbrev=False)
# ์ธ์ ํ์ฑ
parser.add_argument('--model_name', type=str, default='KCSN')
# ๋ชจ๋ธ ์ค์
parser.add_argument('--pooling_type', type=str, default='max_pooling')
parser.add_argument('--classifier_intermediate_dim', type=int, default=100)
parser.add_argument('--nonlinear_type', type=str, default='tanh')
# BERT ์ค์
parser.add_argument('--bert_pretrained_dir', type=str, default=BERT_PRETRAINED_DIR)
# ํ๋ จ ์ค์
parser.add_argument('--margin', type=float, default=1.0)
parser.add_argument('--lr', type=float, default=2e-5)
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr_decay', type=float, default=0.95)
parser.add_argument('--patience', type=int, default=10)
# ํ๋ จ, ๊ฐ๋ฐ ๋ฐ ํ
์คํธ ๋ฐ์ดํฐ ํ์ผ ๊ฒฝ๋ก
parser.add_argument('--train_file', type=str, default=f'{DATA_PREFIX}/train_unsplit.txt')
parser.add_argument('--dev_file', type=str, default=f'{DATA_PREFIX}/dev_unsplit.txt')
parser.add_argument('--test_file', type=str, default=f'{DATA_PREFIX}/test_unsplit.txt')
parser.add_argument('--name_list_path', type=str, default=f'{DATA_PREFIX}/name_list.txt')
parser.add_argument('--ws', type=int, default=10) # ์๋์ฐ ํฌ๊ธฐ
parser.add_argument('--length_limit', type=int, default=510) # ์ํ์ค ๊ธธ์ด ์ ํ
# ์ฒดํฌํฌ์ธํธ ๋ฐ ๋ก๊ทธ ์ ์ฅ ๋๋ ํ ๋ฆฌ
parser.add_argument('--checkpoint_dir', type=str, default=CHECKPOINT_DIR)
parser.add_argument('--training_logs', type=str, default=LOG_FATH)
args, _ = parser.parse_known_args()
return args
|