Commit d0ad94e
Parent(s): f345cd5

Upload folder using huggingface_hub

Files changed:
- README.md +1 -1
- __pycache__/config_tiny_mistral.cpython-310.pyc +0 -0
- __pycache__/dataloader.cpython-310.pyc +0 -0
- __pycache__/modeling_mistral.cpython-310.pyc +0 -0
- config_tiny_mistral.py +3 -2
- dataloader.py +1 -1
- modeling_mistral.py +7 -7
- run_train.py +2 -3
README.md
CHANGED
@@ -16,4 +16,4 @@ python config_tiny_mistral.py
 # Run training
 export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
 torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
-```
+```
__pycache__/config_tiny_mistral.cpython-310.pyc
ADDED
Binary file (3.99 kB)

__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (2.81 kB)

__pycache__/modeling_mistral.cpython-310.pyc
ADDED
Binary file (24.7 kB)
config_tiny_mistral.py
CHANGED
@@ -6,6 +6,8 @@ python config_tiny_mistral.py
 ```
 """
 import os
+from dataclasses import dataclass
+from typing import Optional
 
 from nanotron.config import (
     CheckpointsArgs,
@@ -23,8 +25,6 @@ from nanotron.config import (
     TokensArgs,
 )
 from nanotron.logging import human_format
-from dataclasses import dataclass
-from typing import Optional
 
 
 @dataclass
@@ -58,6 +58,7 @@ class MistralConfig:
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
 
+
 model_config = MistralConfig(
     # Config for a tiny model model with 1.62M parameters
     bos_token_id=1,
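
The only behavioral lines in this hunk are 58-59: when num_key_value_heads is left unset, the config falls back to num_attention_heads, i.e. plain multi-head attention rather than grouped-query attention. A minimal, self-contained sketch of that fallback, assuming it lives in the dataclass's __post_init__ (only the fields visible in the diff are included; the defaults are placeholders):

from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyMistralConfigSketch:
    # Hypothetical subset of MistralConfig; the real config has many more fields.
    bos_token_id: int = 1
    num_attention_heads: int = 4  # placeholder default
    num_key_value_heads: Optional[int] = None

    def __post_init__(self):
        # Same fallback as config_tiny_mistral.py lines 58-59:
        # no explicit KV-head count means one KV head per attention head.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


cfg = TinyMistralConfigSketch(num_attention_heads=8)
assert cfg.num_key_value_heads == 8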
dataloader.py
CHANGED
@@ -1,3 +1,4 @@
+from nanotron import logging
 from nanotron.config import (
     PretrainDatasetsArgs,
 )
@@ -13,7 +14,6 @@ from nanotron.trainer import DistributedTrainer
 from nanotron.utils import (
     main_rank_first,
 )
-from nanotron import logging
 
 try:
     from huggingface_hub import __version__ as hf_hub_version
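
The only change to dataloader.py is moving `from nanotron import logging` into the sorted import block at the top of the file. The surrounding context also shows the guarded huggingface_hub version import used for logging; a small sketch of that pattern follows (the except branch is outside the hunk, so its handling below is an assumption):

# Guarded version probe, as in dataloader.py lines 18-19.
# Assumption: on ImportError the script simply records that the package is absent.
try:
    from huggingface_hub import __version__ as hf_hub_version
except ImportError:
    hf_hub_version = "not installed"

print(f"huggingface_hub: {hf_hub_version}")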
modeling_mistral.py
CHANGED
@@ -23,16 +23,13 @@ from flash_attn.flash_attn_interface import (
     flash_attn_with_kvcache,
 )
 from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
-from torch import nn
-from transformers import MistralConfig
-from transformers.activations import ACT2FN
-
 from nanotron import distributed as dist
 from nanotron import logging
 from nanotron.config import ParallelismArgs, RecomputeGranularity
-from nanotron.
+from nanotron.generation.generate_store import AttachableStore
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
+from nanotron.nn.layer_norm import TritonRMSNorm
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.parameters import NanotronParameter
 from nanotron.parallel.pipeline_parallel.block import (
@@ -49,7 +46,9 @@ from nanotron.parallel.tensor_parallel.nn import (
 )
 from nanotron.random import RandomStates
 from nanotron.utils import checkpoint_method
-from
+from torch import nn
+from transformers import MistralConfig
+from transformers.activations import ACT2FN
 
 logger = logging.get_logger(__name__)
 
@@ -852,6 +851,7 @@ class MistralForTraining(NanotronModel):
     ):
         super().__init__()
         import warnings
+
         warnings.warn("This is just a Llama Model, not a Mistral one for demo purpose. Please fix implementation")
         self.model = MistralModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
         self.loss = PipelineBlock(
@@ -1120,4 +1120,4 @@ def get_flops(
     else:
         raise ValueError("recompute_granularity must be one of 'full' or 'selective'")
 
-    return model_flops, hardware_flops
+    return model_flops, hardware_flops
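
Apart from regrouping the imports (and pulling in AttachableStore and TritonRMSNorm from nanotron), the edits are an added blank line before the constructor's warnings.warn call and what appears to be a whitespace-only change on the final return. A stripped-down sketch of that warn-on-construction pattern from lines 851-857, with the nanotron-specific arguments omitted:

import warnings


class MistralForTrainingSketch:
    # Stand-in for MistralForTraining.__init__: warn at construction time that the
    # implementation is a placeholder, then build the model (omitted here).
    def __init__(self):
        warnings.warn(
            "This is just a Llama Model, not a Mistral one for demo purpose. "
            "Please fix implementation"
        )
        self.model = None  # the real code builds MistralModel and a PipelineBlock loss


MistralForTrainingSketch()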
run_train.py
CHANGED
@@ -9,11 +9,10 @@ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
 """
 import argparse
 
-from
+from config_tiny_mistral import MistralConfig
 from dataloader import get_dataloader
+from modeling_mistral import MistralForTraining
 from nanotron.trainer import DistributedTrainer
-from config_tiny_mistral import MistralConfig
-
 
 
 def get_args():
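
run_train.py now imports the custom MistralConfig, MistralForTraining, and get_dataloader modules alongside nanotron's DistributedTrainer. Only the import block and the get_args() stub are visible in the hunk; below is a minimal sketch of the --config-file parsing implied by the README's torchrun command (how the parsed path is then fed to DistributedTrainer is not shown in this commit, so it is left as a comment):

import argparse


def get_args():
    # Parse the flag used in the README:
    #   torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-file",
        type=str,
        required=True,
        help="Path to the YAML config (e.g. config_tiny_mistral.yaml)",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    # The real script presumably hands args.config_file to DistributedTrainer and
    # wires in MistralConfig / MistralForTraining / get_dataloader; that wiring is
    # not visible in this diff.
    print(args.config_file)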