# Either of the following two functions can be used:
def image_Tensor2ndarray(image_tensor: torch.Tensor):
    """Convert a single-image batch tensor into an HWC ``uint8`` NumPy array.

    The values are multiplied by 255, rounded to the nearest integer (by
    adding 0.5 before the truncating cast) and clamped to ``[0, 255]``, so
    inputs are expected to be scaled to ``[0, 1]``.  The channel order is
    left as-is; no RGB/BGR swap is performed.

    Args:
        image_tensor: Tensor of shape ``(1, C, H, W)``.  It may live on any
            device and may require grad; it is never modified.

    Returns:
        A C-contiguous ``numpy.ndarray`` of shape ``(H, W, C)`` and dtype
        ``uint8``.

    Raises:
        AssertionError: If the input is not 4-D with a batch size of 1.
    """
    assert len(image_tensor.shape) == 4 and image_tensor.shape[0] == 1, \
        "expected a tensor of shape (1, C, H, W)"
    # Drop ONLY the batch axis.  A bare squeeze() would also remove a
    # size-1 channel/height/width axis and make the permute below fail.
    image = image_tensor.detach().to(torch.device('cpu')).squeeze(0)
    # The out-of-place mul allocates a fresh tensor, so the following
    # in-place ops can never touch the caller's data.
    scaled = image.mul(255).add_(0.5).clamp_(0, 255)
    # contiguous() gives the array a plain row-major layout; the permuted
    # strides would otherwise be carried over into the NumPy array.
    return scaled.to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
def image_Tensor2np(image_tensor: torch.Tensor):
    """Convert a single-image batch tensor into an HWC ``uint8`` NumPy array.

    The image is rescaled so that its maximum value maps to 255, then cast
    to ``uint8`` (the cast truncates towards zero).  Negative values are
    clipped to 0.  The channel order is left as-is; no RGB/BGR swap is
    performed.

    Args:
        image_tensor: Tensor of shape ``(1, C, H, W)``.  It may live on any
            device and may require grad; it is never modified.

    Returns:
        A C-contiguous ``numpy.ndarray`` of shape ``(H, W, C)`` and dtype
        ``uint8``.  If the maximum of the input is not a positive number
        the result is all zeros.

    Raises:
        AssertionError: If the input is not 4-D with a batch size of 1.
    """
    assert len(image_tensor.shape) == 4 and image_tensor.shape[0] == 1, \
        "expected a tensor of shape (1, C, H, W)"
    # Drop ONLY the batch axis.  A bare squeeze() would also remove a
    # size-1 channel/height/width axis and make the transpose below fail.
    image = image_tensor.cpu().detach().squeeze(0).numpy()
    max_value = image.max()
    if max_value > 0:
        # Clip before the cast: negative floats would otherwise wrap
        # around (or be undefined) when converted to uint8.
        image = np.clip(image * 255 / max_value, 0, 255)
    else:
        # Nothing to normalise against; avoid dividing by zero, which
        # would fill the image with NaN.
        image = np.zeros_like(image)
    image_cv2 = np.uint8(image).transpose(1, 2, 0)
    # transpose() only returns a strided view; hand back a plain
    # row-major array instead.
    return np.ascontiguousarray(image_cv2)
|