"""Process-group management for sequence parallelism on top of torch.distributed."""
# NOTE(review): replaced non-Python web-scrape residue (page header, line-number
# gutter, commit hash) that preceded the code with a module docstring.
import os
import torch
import torch.distributed as dist
from .utils import is_dist_avail_and_initialized, get_rank
# Module-level state populated by the init_* functions below; all stay None
# until initialization.
SEQ_PARALLEL_GROUP = None  # this rank's sequence-parallel ProcessGroup
SEQ_PARALLEL_SIZE = None  # number of ranks per sequence-parallel group
SEQ_PARALLEL_PROC_NUM = None # how many processes participate in sequence parallel (-1 arg maps to world size)
SYNC_INPUT_GROUP = None  # this rank's input-synchronization ProcessGroup
SYNC_INPUT_SIZE = None  # number of ranks per sync-input group (args.max_frames)
def is_sequence_parallel_initialized():
    """Report whether init_sequence_parallel_group() has already run."""
    return SEQ_PARALLEL_GROUP is not None
def init_sequence_parallel_group(args):
    """Create the sequence-parallel process groups and cache this rank's group.

    Ranks are partitioned into consecutive, non-overlapping groups of
    ``args.sp_group_size`` ranks each, covering the first
    ``SEQ_PARALLEL_PROC_NUM`` ranks of the world.

    Args:
        args: namespace providing ``sp_group_size`` (ranks per group) and
            ``sp_proc_num`` (total ranks used for sequence parallelism;
            ``-1`` means the whole world).

    Raises:
        AssertionError: if called twice, if torch.distributed is not
            initialized, or if the process count is not divisible by the
            group size.
    """
    global SEQ_PARALLEL_GROUP
    global SEQ_PARALLEL_SIZE
    global SEQ_PARALLEL_PROC_NUM
    assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
    SEQ_PARALLEL_SIZE = args.sp_group_size
    print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    if args.sp_proc_num == -1:
        SEQ_PARALLEL_PROC_NUM = world_size
    else:
        SEQ_PARALLEL_PROC_NUM = args.sp_proc_num
    assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The process needs to be evenly divided"
    # torch.distributed.new_group is collective: EVERY rank must call it for
    # EVERY group, including groups it is not a member of.  The previous
    # implementation broke out of the loop after finding its own group, which
    # desynchronizes ranks in different groups and can deadlock — so iterate
    # over all groups without breaking.
    for i in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
        ranks = list(range(i, i + SEQ_PARALLEL_SIZE))
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            SEQ_PARALLEL_GROUP = group
def init_sync_input_group(args):
    """Create the input-synchronization process groups and cache this rank's group.

    Ranks are partitioned into consecutive groups of ``args.max_frames``
    ranks each, spanning the whole world.

    Args:
        args: namespace providing ``max_frames`` (ranks per sync group).

    Raises:
        AssertionError: if called twice or if torch.distributed is not
            initialized.
    """
    global SYNC_INPUT_GROUP
    global SYNC_INPUT_SIZE
    assert SYNC_INPUT_GROUP is None, "parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
    SYNC_INPUT_SIZE = args.max_frames
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    # new_group is collective — all ranks must create all groups, so no early
    # break once the local group is found (mirrors init_sequence_parallel_group).
    for i in range(0, world_size, SYNC_INPUT_SIZE):
        ranks = list(range(i, i + SYNC_INPUT_SIZE))
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            SYNC_INPUT_GROUP = group
def get_sequence_parallel_group():
    """Return this rank's sequence-parallel process group.

    Raises:
        AssertionError: if init_sequence_parallel_group() was never called.
    """
    group = SEQ_PARALLEL_GROUP
    assert group is not None, "sequence parallel group is not initialized"
    return group
def get_sync_input_group():
    """Return this rank's sync-input process group (None if never initialized)."""
    return SYNC_INPUT_GROUP
def get_sequence_parallel_world_size():
    """Return the number of ranks in each sequence-parallel group.

    Raises:
        AssertionError: if sequence parallelism was never initialized.
    """
    size = SEQ_PARALLEL_SIZE
    assert size is not None, "sequence parallel size is not initialized"
    return size
def get_sequence_parallel_rank():
    """Return this rank's position inside its sequence-parallel group.

    Raises:
        AssertionError: if sequence parallelism was never initialized.
    """
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    return get_rank() % SEQ_PARALLEL_SIZE
def get_sequence_parallel_group_rank():
    """Return the index of the sequence-parallel group this rank belongs to.

    Raises:
        AssertionError: if sequence parallelism was never initialized.
    """
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    return get_rank() // SEQ_PARALLEL_SIZE
def get_sequence_parallel_proc_num():
    """Return how many processes participate in sequence parallelism (None if uninitialized)."""
    return SEQ_PARALLEL_PROC_NUM