|
|
|
import torch |
|
import torch.nn as nn |
|
import math |
|
|
|
from utils import ( |
|
get_context_parallel_group, |
|
get_context_parallel_rank, |
|
get_context_parallel_world_size, |
|
get_context_parallel_group_rank, |
|
) |
|
|
|
|
|
def _conv_split(input_, dim=2, kernel_size=1):
    """Return this context-parallel rank's slice of ``input_`` along ``dim``.

    The length along ``dim`` minus ``kernel_size`` is divided across the
    context-parallel world with floor division. Rank 0 additionally keeps the
    leading ``kernel_size`` entries; every other rank's slice is offset by
    ``kernel_size``. If that length is not divisible by the world size, the
    trailing remainder entries are not assigned to any rank.

    Args:
        input_: Full tensor to shard.
        dim: Dimension to split along.
        kernel_size: Extra leading entries kept by rank 0 on top of its share.

    Returns:
        A contiguous slice for this rank, or ``input_`` itself when the
        context-parallel world size is 1.
    """
    world_size = get_context_parallel_world_size()

    # Without context parallelism the caller already holds its whole share.
    if world_size == 1:
        return input_

    rank = get_context_parallel_rank()
    share = (input_.size()[dim] - kernel_size) // world_size

    # Rank 0 covers [0, share + k); rank r > 0 covers [r * share + k, (r + 1) * share + k).
    begin = 0 if rank == 0 else rank * share + kernel_size
    end = (rank + 1) * share + kernel_size

    along_first = input_.transpose(dim, 0)
    return along_first[begin:end].transpose(dim, 0).contiguous()
|
|
|
|
|
def _conv_gather(input_, dim=2, kernel_size=1):
    """All-gather context-parallel shards and concatenate them along ``dim``.

    Rank 0 contributes its whole shard. Every other rank first drops its
    leading ``max(kernel_size - 1, 0)`` entries along ``dim``. Contributions
    are concatenated in rank order, so every rank returns the same tensor.

    Args:
        input_: This rank's shard.
        dim: Dimension to gather along.
        kernel_size: Sets how many leading entries non-zero ranks drop
            (``kernel_size - 1``) and how the receive buffers are sized.

    Returns:
        The concatenated tensor, or ``input_`` unchanged when the
        context-parallel world size is 1.

    NOTE(review): ``all_gather`` needs every rank to agree on the shape of
    every slot, but the receive buffers are sized from the *local* shard.
    This is consistent only when each shard has at least ``kernel_size``
    entries along ``dim`` and rank 0's shard is exactly one entry longer than
    every other rank's. ``_conv_split`` with ``kernel_size=1`` produces such
    shards; confirm for other callers.
    """
    cp_world_size = get_context_parallel_world_size()

    # Without context parallelism the local tensor already is the full tensor.
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    # Put `dim` first so the slicing below reads naturally; these are views, not copies.
    seq = input_.transpose(0, dim)
    head_len = seq[:kernel_size].size(0)

    if cp_rank == 0:
        # Rank 0 sends its whole shard. Make one fresh contiguous copy, which is
        # equivalent to re-concatenating head and tail, so the buffer handed to
        # all_gather (which may write into this rank's own slot) is not the
        # caller's tensor.
        local = input_.clone(memory_format=torch.contiguous_format)
        first_shape = local.shape
        # Peers' contributions are sized like rank 0's shard minus its leading kernel_size entries.
        other_shape = seq[kernel_size:].transpose(0, dim).shape
    else:
        # Drop the leading kernel_size - 1 entries. Presumably these duplicate
        # the previous rank's data; confirm with the callers.
        local = seq[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
        other_shape = local.shape
        # Rank 0's slot is (head_len + local length) along `dim`. It is computed
        # arithmetically instead of materialising a throwaway torch.cat just for its shape.
        first_shape = list(local.shape)
        first_shape[dim] += head_len

    # Fresh receive buffers for peers; this rank's slot holds its own contribution.
    tensor_list = [
        local if r == cp_rank else local.new_empty(first_shape if r == 0 else other_shape)
        for r in range(cp_world_size)
    ]
    torch.distributed.all_gather(tensor_list, local, group=group)

    return torch.cat(tensor_list, dim=dim).contiguous()
|
|
|
|
|
def _cp_pass_from_previous_rank(input_, dim, kernel_size): |
|
|
|
if kernel_size == 1: |
|
return input_ |
|
|
|
group = get_context_parallel_group() |
|
cp_rank = get_context_parallel_rank() |
|
cp_group_rank = get_context_parallel_group_rank() |
|
cp_world_size = get_context_parallel_world_size() |
|
|
|
|
|
|
|
global_rank = torch.distributed.get_rank() |
|
global_world_size = torch.distributed.get_world_size() |
|
|
|
input_ = input_.transpose(0, dim) |
|
|
|
|
|
send_rank = global_rank + 1 |
|
recv_rank = global_rank - 1 |
|
if send_rank % cp_world_size == 0: |
|
send_rank -= cp_world_size |
|
if recv_rank % cp_world_size == cp_world_size - 1: |
|
recv_rank += cp_world_size |
|
|
|
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() |
|
if cp_rank < cp_world_size - 1: |
|
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) |
|
if cp_rank > 0: |
|
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) |
|
|
|
if cp_rank == 0: |
|
input_ = torch.cat([torch.zeros_like(input_[:1])] * (kernel_size - 1) + [input_], dim=0) |
|
else: |
|
req_recv.wait() |
|
input_ = torch.cat([recv_buffer, input_], dim=0) |
|
|
|
input_ = input_.transpose(0, dim).contiguous() |
|
return input_ |
|
|
|
|
|
def _drop_from_previous_rank(input_, dim, kernel_size): |
|
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) |
|
return input_ |
|
|
|
|
|
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
    """Autograd op: keep this rank's shard in forward, gather the gradient in backward.

    Forward delegates to ``_conv_split`` and backward to ``_conv_gather``, both
    with the same ``dim`` and ``kernel_size``. Those two are plain (non-tensor)
    arguments, so backward returns ``None`` in their gradient slots.
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        # Stash the split geometry so backward gathers with the same settings.
        ctx.dim, ctx.kernel_size = dim, kernel_size
        shard = _conv_split(input_, dim, kernel_size)
        return shard

    @staticmethod
    def backward(ctx, grad_output):
        full_grad = _conv_gather(grad_output, ctx.dim, ctx.kernel_size)
        no_grad_dim, no_grad_kernel_size = None, None
        return full_grad, no_grad_dim, no_grad_kernel_size
|
|
|
|
|
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
    """Autograd op: gather shards in forward, keep this rank's gradient shard in backward.

    Forward delegates to ``_conv_gather`` and backward to ``_conv_split``, both
    with the same ``dim`` and ``kernel_size``. Those two are plain (non-tensor)
    arguments, so backward returns ``None`` in their gradient slots.
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        # Stash the gather geometry so backward splits with the same settings.
        ctx.dim, ctx.kernel_size = dim, kernel_size
        gathered = _conv_gather(input_, dim, kernel_size)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        local_grad = _conv_split(grad_output, ctx.dim, ctx.kernel_size)
        no_grad_dim, no_grad_kernel_size = None, None
        return local_grad, no_grad_dim, no_grad_kernel_size
|
|
|
|
|
class _CPConvolutionPassFromPreviousRank(torch.autograd.Function):
    """Autograd op that prepends the previous rank's halo along ``dim``.

    Forward delegates to ``_cp_pass_from_previous_rank``: rank r > 0 prepends
    the last ``kernel_size - 1`` entries of rank r - 1's shard, and rank 0
    prepends zeros.

    Backward is the adjoint of that exchange. The gradient of the prepended
    halo belongs to the previous rank's trailing ``kernel_size - 1`` entries,
    so it is sent back and accumulated there rather than discarded. Rank 0's
    zero padding has no source, so its gradient is simply dropped. ``dim`` and
    ``kernel_size`` are plain arguments and get ``None`` gradients.
    """

    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _cp_pass_from_previous_rank(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        dim, kernel_size = ctx.dim, ctx.kernel_size
        cp_world_size = get_context_parallel_world_size()

        # Forward exchanged nothing in these cases, so backward only strips the halo slots.
        if kernel_size == 1 or cp_world_size == 1:
            return _drop_from_previous_rank(grad_output, dim, kernel_size), None, None

        group = get_context_parallel_group()
        cp_rank = get_context_parallel_rank()
        global_rank = torch.distributed.get_rank()
        halo = kernel_size - 1

        # Same neighbour computation as the forward, with traffic now flowing backwards.
        # NOTE(review): assumes each group is a contiguous, aligned block of global ranks.
        next_rank = global_rank + 1
        prev_rank = global_rank - 1
        if next_rank % cp_world_size == 0:
            next_rank -= cp_world_size
        if prev_rank % cp_world_size == cp_world_size - 1:
            prev_rank += cp_world_size

        grad_seq = grad_output.transpose(0, dim)
        # Clone so the in-place accumulation below never touches grad_output.
        grad_local = grad_seq[halo:].clone()

        req_send = None
        if cp_rank > 0:
            # The halo's gradient goes back to the rank that supplied the halo.
            req_send = torch.distributed.isend(grad_seq[:halo].contiguous(), prev_rank, group=group)
        if cp_rank < cp_world_size - 1:
            # The next rank used our trailing entries as its halo; add its gradient for them.
            recv_buffer = torch.empty_like(grad_local[-halo:]).contiguous()
            req_recv = torch.distributed.irecv(recv_buffer, next_rank, group=group)
            req_recv.wait()
            grad_local[-halo:] += recv_buffer
        if req_send is not None:
            req_send.wait()

        return grad_local.transpose(0, dim).contiguous(), None, None
|
|
|
|
|
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
    """Differentiably keep this rank's shard of ``input_`` along ``dim``.

    Thin wrapper around ``_ConvolutionScatterToContextParallelRegion.apply``.
    """
    scatter_fn = _ConvolutionScatterToContextParallelRegion.apply
    return scatter_fn(input_, dim, kernel_size)
|
|
|
|
|
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
    """Differentiably gather context-parallel shards of ``input_`` along ``dim``.

    Thin wrapper around ``_ConvolutionGatherFromContextParallelRegion.apply``.
    """
    gather_fn = _ConvolutionGatherFromContextParallelRegion.apply
    return gather_fn(input_, dim, kernel_size)
|
|
|
|
|
def cp_pass_from_previous_rank(input_, dim, kernel_size):
    """Differentiably prepend the previous context-parallel rank's halo along ``dim``.

    Thin wrapper around ``_CPConvolutionPassFromPreviousRank.apply``; see that
    class for the forward and gradient behavior.
    """
    pass_fn = _CPConvolutionPassFromPreviousRank.apply
    return pass_fn(input_, dim, kernel_size)
|
|
|
|
|
|
|
|
|
|
|
|