Source code for torchpack.distributed.comm

import pickle
from typing import Any, List

import torch
import torch.distributed

from torchpack.distributed import context

# Public API of this module: the collective helpers defined below.
__all__ = ['allreduce', 'allgather', 'barrier']


def allreduce(data: Any, reduction: str = 'sum') -> Any:
    """Gather ``data`` from every rank and reduce the gathered values.

    Args:
        data: Value contributed by the calling rank. It is handed to
            ``allgather`` unchanged.
        reduction: Name of the reduction to apply. Only ``'sum'`` is
            supported.

    Returns:
        The built-in ``sum`` of the values gathered from all ranks.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'``.
    """
    # Reject unknown reductions before communicating: previously such a call
    # paid for a full gather and then silently returned ``None``.
    if reduction != 'sum':
        raise ValueError(f'unsupported reduction: {reduction!r}')
    return sum(allgather(data))


def allgather(data: Any) -> List:
    """Collect one picklable object from every rank.

    The object is pickled into a byte tensor on the current CUDA device,
    padded to the largest payload among the ranks, exchanged with
    ``torch.distributed.all_gather`` and unpickled on the receiving side.

    Args:
        data: Object contributed by the calling rank. It must be picklable
            unless the world size is 1, in which case it is returned as-is.

    Returns:
        A list with one entry per rank, in the order produced by
        ``torch.distributed.all_gather``. With a world size of 1 this is
        simply ``[data]``.
    """
    world_size = context.size()
    if world_size == 1:
        return [data]

    # Serialize the object into a 1-D byte tensor on the GPU.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).cuda()
    # Keep the local length as a plain int so later comparisons and size
    # arithmetic do not go through a CUDA tensor.
    local_size = tensor.numel()

    # Exchange payload lengths so every rank knows how much to read back.
    size_tensor = torch.LongTensor([local_size]).cuda()
    size_list = [torch.zeros_like(size_tensor) for _ in range(world_size)]
    torch.distributed.all_gather(size_list, size_tensor)
    sizes = [int(size.item()) for size in size_list]
    max_size = max(sizes)

    # Pad the local payload so every rank contributes a tensor of identical
    # length. The padding is zero-filled rather than left uninitialized so no
    # stray memory contents are sent to other ranks.
    if local_size != max_size:
        padding = tensor.new_zeros((max_size - local_size, ))
        tensor = torch.cat((tensor, padding), dim=0)

    # Receive buffers are allocated directly on the payload's device.
    received = [torch.empty_like(tensor) for _ in sizes]
    torch.distributed.all_gather(received, tensor)

    results = []
    for size, chunk in zip(sizes, received):
        # Drop the padding before the device-to-host copy.
        payload = chunk[:size].cpu().numpy().tobytes()
        # NOTE(security): pickle.loads executes arbitrary code embedded in
        # the payload; this is only safe when every peer rank is trusted.
        results.append(pickle.loads(payload))
    return results


def barrier() -> None:
    """Block until every rank reaches this call.

    Does nothing when the world size is 1; otherwise defers to
    ``torch.distributed.barrier``.
    """
    if context.size() != 1:
        torch.distributed.barrier()