广播
import torch
import math
# Broadcasting demo: build the powers x, x^2, x^3 in one vectorised step.
# 2000 evenly spaced sample points covering [-pi, pi].
x = torch.linspace(start=-math.pi, end=math.pi, steps=2000)
y = x.sin()

# Exponents to apply to every sample point.
p = torch.tensor([1, 2, 3])

# x gets a trailing axis -> shape (2000, 1); raising it to p (shape (3,))
# broadcasts to shape (2000, 3), one column per exponent.
x_column = x.unsqueeze(-1)
xx = x_column ** p
函数
损失函数
各损失函数的参数基本一致。
比如 nn.MSELoss 的 reduction 参数:'none' 返回逐元素的损失,'sum' 返回所有元素损失之和,'mean'(默认)返回所有元素损失的平均值:
import torch
import torch.nn as nn
# Two 2x2 float tensors used as prediction / target for the demo.
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
b = torch.tensor([[3, 5], [8, 6]], dtype=torch.float32)

# One MSE criterion per reduction mode.
loss_fn1 = nn.MSELoss(reduction='none')  # keep the element-wise squared errors
loss_fn2 = nn.MSELoss(reduction='sum')   # add all squared errors together
loss_fn3 = nn.MSELoss(reduction='mean')  # average the squared errors

# a and b are already float32, so no extra cast is needed before the call.
loss1 = loss_fn1(a, b)
print(loss1)

loss2 = loss_fn2(a, b)
print(loss2)

loss3 = loss_fn3(a, b)
print(loss3)
|