Method 1
Move the model, the loss function, and the data to the GPU with .cuda(), guarded by torch.cuda.is_available()
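import torch
from torch import nn

# `model` is your own network class (an nn.Module subclass defined elsewhere)
model1 = model()
if torch.cuda.is_available():
    model1 = model1.cuda()  # move the model's parameters and buffers to the GPU

bloss = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    bloss = bloss.cuda()  # the loss module can be moved to the GPU as well

# Training loop: every batch must be on the same device as the model
for data in train_dataloader:
    imgs, targets = data
    if torch.cuda.is_available():
        imgs = imgs.cuda()  # Tensor.cuda() is not in-place, so reassign the result
        targets = targets.cuda()
    # ... forward pass, loss, backward pass, optimizer step ...

# Evaluation loop: the test data needs the same treatment
for data in test_dataloader:
    imgs, targets = data
    if torch.cuda.is_available():
        imgs = imgs.cuda()
        targets = targets.cuda()
    # ... forward pass and metric computation ...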
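To see Method 1 run end to end, here is a minimal sketch. It assumes a hypothetical `TinyNet` class in place of `model`, random tensors wrapped in a `TensorDataset` in place of the real training set, and an SGD optimizer; none of these come from the original notes. The `is_available()` guards let it fall back to the CPU on machines without a GPU.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyNet(nn.Module):
    """Hypothetical stand-in for the user's `model` class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)  # 4 input features, 3 classes

    def forward(self, x):
        return self.fc(x)


# Hypothetical random data: 32 samples, 4 features, labels in {0, 1, 2}
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
train_dataloader = DataLoader(dataset, batch_size=8)

model1 = TinyNet()
bloss = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    model1 = model1.cuda()
    bloss = bloss.cuda()

# Build the optimizer only after the model has been moved
optimizer = torch.optim.SGD(model1.parameters(), lr=0.1)

for imgs, targets in train_dataloader:
    if torch.cuda.is_available():
        imgs = imgs.cuda()
        targets = targets.cuda()
    outputs = model1(imgs)  # inputs and weights now live on the same device
    loss = bloss(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Parameters, inputs, and outputs all share one device
print(next(model1.parameters()).device, outputs.device)
```

The optimizer is constructed after `model1.cuda()` because the PyTorch optimizer documentation recommends moving a model to the GPU before building optimizers for it.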
Method 2
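import torch

# Pick the target device once; fall back to the CPU when no GPU is present
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1 = model()
model1 = model1.to(device)  # .to(device) works for modules and tensors alike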
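Handle the loss function and the data the same way: call .to(device) on them, e.g. bloss = bloss.to(device) and imgs = imgs.to(device).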
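A minimal sketch of Method 2 applied to the model, the loss function, and the data in an evaluation loop. It assumes a single `nn.Linear` layer as a hypothetical stand-in for `model()` and random tensors as a hypothetical test set; the accuracy computation is illustrative only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU so the sketch also runs on machines without a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for `model()`: one linear layer, 4 features -> 3 classes
model1 = nn.Linear(4, 3).to(device)
bloss = nn.CrossEntropyLoss().to(device)

# Hypothetical random test data: 20 samples in batches of 5
torch.manual_seed(0)
test_dataloader = DataLoader(
    TensorDataset(torch.randn(20, 4), torch.randint(0, 3, (20,))), batch_size=5
)

correct = 0
model1.eval()
with torch.no_grad():  # no gradients are needed during evaluation
    for imgs, targets in test_dataloader:
        imgs = imgs.to(device)  # the same call as for the model and the loss
        targets = targets.to(device)
        outputs = model1(imgs)
        loss = bloss(outputs, targets)
        correct += (outputs.argmax(dim=1) == targets).sum().item()

accuracy = correct / len(test_dataloader.dataset)
print(f"device={device}, accuracy={accuracy:.2f}")
```

Because `.to(device)` returns the object unchanged when it is already on the target device, the same code runs unmodified on CPU-only and GPU machines, which is the main convenience of Method 2 over the repeated `is_available()` checks of Method 1.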