Model training anatomy
To understand the performance optimization techniques you can apply to improve the efficiency of model training, it is helpful to get familiar with how the GPU is utilized during training and how compute intensity varies depending on the operation performed.
Let's start by exploring a motivating example of GPU utilization and a training run of a model. For the demonstration, we'll need to install a few libraries:
pip install transformers datasets accelerate nvidia-ml-py3
The nvidia-ml-py3 library allows us to monitor the memory usage of models from within Python. You might be familiar with the nvidia-smi command in the terminal; this library gives access to the same information directly in Python.
Then we create some dummy data: random token IDs between 100 and 30000 and binary labels for a classifier. In total, we get 512 sequences, each of length 512, and store them in a Dataset with PyTorch format.
>>> import numpy as np
>>> from datasets import Dataset
>>> seq_len, dataset_size = 512, 512
>>> dummy_data = {
...     "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
...     "labels": np.random.randint(0, 2, (dataset_size)),
... }
>>> ds = Dataset.from_dict(dummy_data)
>>> ds.set_format("pt")
To print summary statistics for GPU utilization and the training run with the Trainer, we define two helper functions:
>>> from pynvml import *
>>> def print_gpu_utilization():
... nvmlInit()
... handle = nvmlDeviceGetHandleByIndex(0)
... info = nvmlDeviceGetMemoryInfo(handle)
... print(f"GPU memory occupied: {info.used//1024**2} MB.")
>>> def print_summary(result):
... print(f"Time: {result.metrics['train_runtime']:.2f}")
... print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
... print_gpu_utilization()
Let's verify that we start with free GPU memory:
>>> print_gpu_utilization()
GPU memory occupied: 0 MB.
That looks good: the GPU memory is not occupied, as we would expect before loading any models. If that's not the case on your machine, make sure to stop all processes that are using GPU memory. However, not all free GPU memory can be used by the user. When a model is loaded to the GPU, the kernels are also loaded, which can take up 1-2GB of memory. To see how much, we load a tiny tensor into the GPU, which triggers the kernels to be loaded as well.
>>> import torch
>>> torch.ones((1, 1)).to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 1343 MB.
We can see that the kernels alone take up 1.3GB of GPU memory. Now let's see how much space the model uses.
Load Model
First, we load the google-bert/bert-large-uncased model. We load the model weights directly to the GPU so that we can check how much space just the weights use.
>>> from transformers import AutoModelForSequenceClassification
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-large-uncased").to("cuda")
>>> print_gpu_utilization()
GPU memory occupied: 2631 MB.
We can see that the model weights alone take up 1.3 GB of GPU memory. The exact number depends on the specific GPU you are using; on newer GPUs, a model can sometimes take up more space because the weights are loaded in an optimized fashion that speeds up its usage.
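As a quick sanity check (not part of the original walkthrough), we can relate this number to the parameter count, assuming 4 bytes per parameter for fp32 weights:

>>> num_params = sum(p.numel() for p in model.parameters())
>>> print(f"Parameters: {num_params / 1e6:.0f}M")
>>> print(f"Expected fp32 weight memory: {num_params * 4 / 1024**2:.0f} MB")

For a BERT-large-sized model (roughly 336M parameters) this comes out to about 1.3 GB, consistent with the jump from 1343 MB to 2631 MB observed above. Now we can also quickly check if we get the same result as with the nvidia-smi CLI: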
nvidia-smi
Tue Jan 11 08:58:05 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    39W / 300W |   2631MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      3721      C   ...nvs/codeparrot/bin/python    2629MiB  |
+-----------------------------------------------------------------------------+
We get the same number as before, and we can also see that we are using a V100 GPU with 16GB of memory. Now we can start training the model and see how the GPU memory consumption changes. First, we set up a few standard training arguments:
default_args = {
    "output_dir": "tmp",
    "evaluation_strategy": "steps",
    "num_train_epochs": 1,
    "log_level": "error",
    "report_to": "none",
}
If you plan to run multiple experiments, restart the Python kernel between experiments in order to properly clear the memory.
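If restarting the kernel is inconvenient, the following sketch can reclaim most of the memory after a training run; this is an assumption-laden cleanup rather than a guaranteed reset, since how much is actually freed depends on what still holds references:

>>> import gc

>>> del model, trainer  # drop the Python references to the large objects from the previous run
>>> gc.collect()  # collect any remaining reference cycles
>>> torch.cuda.empty_cache()  # release cached blocks back to the driver
>>> print_gpu_utilization()  # verify how much memory was actually freed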
Memory utilization at vanilla training
Let's use the Trainer and train the model without any GPU performance optimization techniques and a batch size of 4:
>>> from transformers import TrainingArguments, Trainer, logging
>>> logging.set_verbosity_error()
>>> training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
>>> trainer = Trainer(model=model, args=training_args, train_dataset=ds)
>>> result = trainer.train()
>>> print_summary(result)
Time: 57.82
Samples/second: 8.86
GPU memory occupied: 14949 MB.
We can see that even a relatively small batch size almost fills up our GPU's entire memory. However, a larger batch size can often result in faster model convergence or better end performance. So ideally we want to tune the batch size to our model's needs, not to the GPU's limitations. Interestingly, we use much more memory than the size of the model. To understand a bit better why this is the case, let's have a look at a model's operations and memory needs.
Anatomy of Model's Operations
The Transformer architecture includes 3 main groups of operations, listed below by compute intensity.
Tensor Contractions

Linear layers and components of Multi-Head Attention all do batched matrix-matrix multiplications. These operations are the most compute-intensive part of training a transformer.
Statistical Normalizations

Softmax and layer normalization are less compute-intensive than tensor contractions. They involve one or more reduction operations, the result of which is then applied via a map.
Element-wise Operators

These are the remaining operators: biases, dropout, activations, and residual connections. They are the least compute-intensive operations.
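To get a feel for the gap in compute intensity between these groups, here is a back-of-the-envelope comparison; the factor of 5 for layer normalization is a rough assumption for its few reduction and element-wise passes:

>>> n, d = 512, 1024  # sequence length and hidden size (BERT-large-like)
>>> matmul_flops = 2 * n * d * d  # one d-by-d linear layer over n tokens: ~1.1 GFLOPs
>>> layernorm_flops = 5 * n * d  # a few passes over the activations: ~2.6 MFLOPs
>>> matmul_flops / layernorm_flops
409.6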
This knowledge can be helpful when analyzing performance bottlenecks.
This summary is derived from the case study Data Movement Is All You Need: A Case Study on Optimizing Transformers (2020), https://arxiv.org/abs/2007.00072.
Anatomy of Model's Memory
We've seen that training the model uses much more memory than just putting the model on the GPU. This is because there are many components during training that use GPU memory. The components in GPU memory are the following:
- model weights
- optimizer states
- gradients
- forward activations saved for gradient computation
- temporary buffers
- functionality-specific memory
A typical model trained in mixed precision with AdamW requires 18 bytes per model parameter plus activation memory. For inference there are no optimizer states or gradients, so we can subtract those; mixed precision inference thus requires 6 bytes per model parameter, plus activation memory.
Let's look at the details.
Model Weights:

- 4 bytes * number of parameters for fp32 training
- 6 bytes * number of parameters for mixed precision training (maintains a model in fp32 and one in fp16 in memory)
Optimizer States:

- 8 bytes * number of parameters for normal AdamW (maintains 2 states)
- 2 bytes * number of parameters for 8-bit AdamW optimizers like bitsandbytes (see the sketch after this list)
- 4 bytes * number of parameters for optimizers like SGD with momentum (maintains only 1 state)
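For example, switching the Trainer to the 8-bit AdamW from bitsandbytes can be as simple as the sketch below; it assumes bitsandbytes is installed and a transformers version whose TrainingArguments supports the optim flag:

>>> training_args = TrainingArguments(per_device_train_batch_size=4, optim="adamw_bnb_8bit", **default_args)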
Gradients:

- 4 bytes * number of parameters for either fp32 or mixed precision training (gradients are always kept in fp32)
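Putting the static components together for mixed-precision AdamW training gives 6 + 8 + 4 = 18 bytes per parameter. A minimal estimator, using the byte counts above as its assumptions:

>>> def static_training_memory_gb(num_params, weights=6, optimizer=8, gradients=4):
...     # weights + optimizer states + gradients; excludes activations and temporary buffers
...     return num_params * (weights + optimizer + gradients) / 1024**3

>>> print(f"{static_training_memory_gb(336e6):.1f} GB")  # a BERT-large-sized model
5.6 GB

For a run like the one above (which used full fp32, so roughly 4 + 8 + 4 = 16 bytes per parameter, about 5 GB), the static components account for well under half of the observed 14949 MB; the rest goes to the kernels (~1.3 GB) and to the forward activations and temporary buffers covered by the next items.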
Forward Activations:

- size depends on many factors, the key ones being sequence length, hidden size, and batch size.

These are the inputs and outputs passed and returned by the forward and backward functions, plus the forward activations saved for gradient computation.
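For transformer layers there is a published rule of thumb: Korthikanti et al., Reducing Activation Recomputation in Large Transformer Models (2022), estimate roughly s * b * h * (34 + 5 * a * s / h) bytes of activations per layer for fp16 training without recomputation, where s is the sequence length, b the batch size, h the hidden size, and a the number of attention heads. A sketch applying that external estimate:

>>> def activation_bytes_per_layer(s, b, h, a):
...     # rule-of-thumb estimate from Korthikanti et al. (2022): fp16, no recomputation
...     return s * b * h * (34 + 5 * a * s / h)

>>> total = 24 * activation_bytes_per_layer(512, 4, 1024, 16)  # BERT-large-like shapes
>>> print(f"{total / 1024**3:.1f} GB")  # an order-of-magnitude figure only
3.5 GB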
Temporary Memory:

Additionally, there are all kinds of temporary variables which are released once the calculation is done, but which in the moment can require additional memory and push you to OOM. Therefore, when coding it is crucial to think strategically about such temporary variables and sometimes to free them explicitly as soon as they are no longer needed.
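For instance, a minimal illustration of freeing a large intermediate explicitly:

>>> buffer = torch.randn(64, 512, 1024, device="cuda")  # a large intermediate, ~128 MB in fp32
>>> result = buffer.mean(dim=-1)  # reduce to the small tensor we actually need
>>> del buffer  # drop the reference so the allocator can reuse the block
>>> torch.cuda.empty_cache()  # optionally return cached memory to the driver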
Functionality-specific Memory:

Your software may also have special memory needs. For example, when generating text with beam search, the software needs to maintain multiple copies of inputs and outputs.
forward vs backward Execution Speed:
For convolutions and linear layers, the backward pass has 2x the FLOPS of the forward pass, which generally translates into roughly 2x slower execution (sometimes more, because sizes in the backward tend to be more awkward). Activations are usually bandwidth-limited, and it is typical for an activation to have to read more data in the backward than in the forward (e.g. the activation forward reads once and writes once, while the activation backward reads twice, the gradOutput and the output of the forward, and writes once, the gradInput).
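A minimal timing sketch for a single linear layer illustrates the gap; this is an illustration rather than a careful benchmark, and the exact ratio varies with the GPU and the tensor shapes:

>>> import time

>>> lin = torch.nn.Linear(4096, 4096).to("cuda")
>>> x = torch.randn(4096, 4096, device="cuda", requires_grad=True)

>>> def timeit(fn, iters=50):
...     torch.cuda.synchronize()  # finish pending GPU work before timing
...     start = time.perf_counter()
...     for _ in range(iters):
...         fn()
...     torch.cuda.synchronize()
...     return (time.perf_counter() - start) / iters

>>> fwd = timeit(lambda: lin(x))
>>> fwd_bwd = timeit(lambda: lin(x).sum().backward())
>>> print(f"forward: {fwd * 1e3:.2f} ms, forward+backward: {fwd_bwd * 1e3:.2f} ms")

The backward alone is roughly fwd_bwd - fwd, which for a linear layer like this one typically comes out to about twice the forward time.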
As you can see, there are potentially a few places where we could save GPU memory or speed up operations. Now that you understand what affects GPU utilization and computation speed, refer to the documentation page on methods and tools for efficient training on a single GPU to learn about performance optimization techniques.