Example source: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
The complete code is as follows:
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follows:
#   init_method="file:///f:/libtmp/some_file"
#   dist.init_process_group(
#       "gloo",
#       rank=rank,
#       init_method=init_method,
#       world_size=world_size)
# For TcpStore, same way as on Linux.
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    print(f'setup rank{rank} world_size={world_size}')

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank} world_size {world_size}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))  # should be: outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),  # args must be passed this way; rank is supplied by the for loop over nprocs inside start_processes
             nprocs=world_size,
             join=True)

# def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
#     _python_version_check()
#     mp = multiprocessing.get_context(start_method)
#     error_queues = []
#     processes = []
#     for i in range(nprocs):
#         error_queue = mp.SimpleQueue()
#         process = mp.Process(
#             target=_wrap,
#             args=(fn, i, args, error_queue),
#             daemon=daemon,
#         )
#         process.start()
#         error_queues.append(error_queue)
#         processes.append(process)
# demo_basic(rank, world_size) takes two parameters, but as the
# start_processes source above shows, the call
#     mp.spawn(demo_fn,
#              args=(world_size,),  # args can only be passed this way
#              nprocs=world_size,
#              join=True)
# only forwards world_size explicitly: the rank argument of
# demo_basic(rank, world_size) is supplied by the for loop over
# nprocs inside start_processes, which passes the loop index i
# as the first argument of fn.

if __name__ == '__main__':
    demo_fn = demo_basic
    world_size = 4
    run_demo(demo_fn, world_size)
Output:
Running basic DDP example on rank 2 world_size 4.
Running basic DDP example on rank 3 world_size 4.
Running basic DDP example on rank 0 world_size 4.
Running basic DDP example on rank 1 world_size 4.
setup rank0 world_size=4
setup rank1 world_size=4
setup rank2 world_size=4
setup rank3 world_size=4
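
To make the rank-passing behavior concrete, here is a minimal sketch (separate from the DDP example above; the function name worker and the message argument are made up for illustration) showing that mp.spawn calls the target function with the process index as its first argument and unpacks args after it:

import torch.multiprocessing as mp

def worker(rank, message):
    # mp.spawn supplies `rank` (0..nprocs-1) automatically;
    # `message` comes from the args tuple passed to mp.spawn.
    print(f"worker rank={rank} message={message}")

if __name__ == '__main__':
    # spawns 4 processes; each one runs worker(i, "hello") for i in range(4)
    mp.spawn(worker, args=("hello",), nprocs=4, join=True)

This is why run_demo only puts world_size into args: the rank seen by demo_basic is the spawn index added by start_processes, not something the caller passes.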