train model
Browse files- scripts/train_model.py +2 -0
scripts/train_model.py
CHANGED
@@ -8,6 +8,7 @@ from transformers import DataCollatorForLanguageModeling
|
|
8 |
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
|
|
11 |
|
12 |
|
13 |
x = input('Are you sure? [y/N] ')
|
@@ -17,6 +18,7 @@ if x not in ('y', 'Y', 'yes'):
|
|
17 |
|
18 |
|
19 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
20 |
|
21 |
|
22 |
def _batch_iterator():
|
|
|
8 |
|
9 |
import torch
|
10 |
from torch.utils.data import DataLoader
|
11 |
+
import torch.multiprocessing as mp
|
12 |
|
13 |
|
14 |
x = input('Are you sure? [y/N] ')
|
|
|
18 |
|
19 |
|
20 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
+
mp.set_start_method('spawn', force=True)
|
22 |
|
23 |
|
24 |
def _batch_iterator():
|