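- PyTorch's default cross-entropy function is called in the form
loss(pred=float, target=integer): pred holds float logits of shape (N, C) and target holds integer class indices of shape (N,).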
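import torch
import torch.nn as nn

# Target given as class indices: one integer in [0, 5) per sample
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)  # float logits, shape (N=3, C=5)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()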
# Target given as class probabilities: each row sums to 1 (requires PyTorch >= 1.10)
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()
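The two target formats above describe the same thing when the probability target is one-hot. The following is a minimal sketch of that check, not part of the original notes; the names `logits`, `index_target`, and `onehot_target` are chosen here for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the comparison is reproducible
loss = nn.CrossEntropyLoss()

logits = torch.randn(3, 5)              # float predictions, shape (N, C)
index_target = torch.tensor([1, 0, 4])  # integer class indices, shape (N,)
# One-hot rows are valid class probabilities, shape (N, C)
onehot_target = F.one_hot(index_target, num_classes=5).float()

loss_from_indices = loss(logits, index_target)
loss_from_onehot = loss(logits, onehot_target)  # probability targets need PyTorch >= 1.10

# Both targets encode the same distribution, so the two losses should agree.
print(torch.allclose(loss_from_indices, loss_from_onehot))
```

With the default `reduction='mean'`, both calls average the per-sample negative log-probability of the true class, which is why they are expected to match.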
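# Hand-written equivalent for soft (probability) targets
def cross_entropy(pred, soft_targets):
    logsoftmax = nn.LogSoftmax(dim=1)  # normalize over the class dimension
    return torch.mean(torch.sum(-soft_targets * logsoftmax(pred), dim=1))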
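As a sanity check, the hand-written formula can be compared against the built-in loss on soft targets. This is only a sketch under stated assumptions: `soft_cross_entropy` is a hypothetical name for a copy of the `cross_entropy` function above (with the softmax dimension made explicit), and the built-in call assumes PyTorch 1.10 or later, where probability targets are accepted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_cross_entropy(pred, soft_targets):
    # Hypothetical helper: mean over samples of -sum_c q_c * log_softmax(pred)_c
    return torch.mean(torch.sum(-soft_targets * F.log_softmax(pred, dim=1), dim=1))

torch.manual_seed(0)  # fixed seed so the comparison is reproducible
pred = torch.randn(3, 5)                         # float logits, shape (N, C)
soft_targets = torch.randn(3, 5).softmax(dim=1)  # each row sums to 1

manual = soft_cross_entropy(pred, soft_targets)
builtin = nn.CrossEntropyLoss()(pred, soft_targets)  # assumes PyTorch >= 1.10

# The two values should agree up to floating-point error.
matches = torch.allclose(manual, builtin, atol=1e-6)
print(matches)
```

On older PyTorch versions that reject float targets, the hand-written version is the practical way to train against soft labels.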
Reference:
- https://discuss.pytorch.org/t/how-should-i-implement-cross-entropy-loss-with-continuous-target-outputs/10720/18
- https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
- https://blog.csdn.net/weixin_39529413/article/details/123122330