PyTorch에서 Multi-GPU 학습하는 과정
1.
Model을 여러 GPU에 복사해서 할당
2.
Iteration 마다 batch를 GPU 개수만큼 나누기 (scatter)
3.
GPU forward
4.
Model의 출력들을 하나의 GPU로 모음 (gather)
Loss 계산 시
model.module.compute_loss()
Python
복사
Model을 GPU로 보내기 (DataParallel)
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
어떤 module을 병렬화하는데, 이때 batch dimension으로만 chunking 하고 다른 object들은 각 device마다 하나씩 copy 한다.
from torch.nn import DataParallel
# if using pyg: from torch_geometric.nn.data_parallel import DataParallel
def to_cuda(model, args):
if args.gpu >= 0:
model = model.cuda(args.gpu)
if args.num_gpu > 1:
device_ids = [args.gpu + i for i in range(args.num_gpu)]
model = DataParallel(model, device_ids=device_ids)
return model
Python
복사