在本教程中,我们将展示如何使用 Albumentations 对带有关键点的图像进行数据增强。任何像素级增强(例如调整亮度、对比度或颜色)都可以直接用于带有关键点的图像,因为像素级增强只改变像素值,不改变图像的几何结构,因此不会影响关键点的位置。
注意:默认情况下,作用于关键点的增强在变换后不会改变关键点的标签。如果关键点的标签具有方向性(例如区分左右),这可能会造成问题。例如,如果您有一个名为 left arm 的关键点,并应用 HorizontalFlip 增强,该关键点在翻转后仍然带有 left arm 标签,但在翻转后的图像中它实际上看起来像一个右臂(right arm)关键点。
如果您使用的是这类关键点,可以考虑使用 albumentations-experimental 中的 FlipSymmetricKeypoints 增强,它正是为处理这种情况而创建的实验性增强。安装命令:pip install -U albumentations_experimental;导入方式:from albumentations_experimental import FlipSymmetricKeypoints。
1.导入相关包
import random
import cv2
from matplotlib import pyplot as plt
import albumentations as A
# Default marker color used by vis_keypoints: (0, 255, 0), which is green
# when the image is in RGB channel order (the demo image is converted
# from BGR to RGB right after loading).
KEYPOINT_COLOR = (0, 255, 0)
2.定义一个在图像上可视化关键点的函数
def vis_keypoints(image, keypoints, color=KEYPOINT_COLOR, diameter=15):
    """Show ``image`` with every keypoint drawn as a filled circle.

    The caller's image is not modified: circles are drawn on a copy, which
    is then displayed in a new 8x8-inch matplotlib figure with axes hidden.

    Args:
        image: image array accepted by ``cv2.circle`` and ``plt.imshow``
            (the images in this notebook are RGB).
        keypoints: iterable of (x, y) pairs; coordinates are truncated to
            integer pixel positions before drawing.
        color: marker color, in the same channel order as ``image``.
            Defaults to ``KEYPOINT_COLOR``.
        diameter: passed to ``cv2.circle`` as its *radius* argument, in
            pixels (the parameter name is kept for backward compatibility).
    """
    canvas = image.copy()
    for x, y in keypoints:
        # Honor the ``color`` argument instead of a hard-coded color;
        # thickness -1 makes OpenCV draw a filled circle.
        cv2.circle(canvas, (int(x), int(y)), diameter, color, -1)

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(canvas)
3.读取图像及其关键点标注
我们使用 xy 格式表示关键点坐标:每个关键点由两个坐标 (x, y) 定义,x 是关键点在 x 轴(水平方向)上的像素位置,y 是其在 y 轴(竖直方向)上的像素位置。
# Load the demo image. cv2.imread does not raise when the file is missing
# or unreadable - it returns None - which would otherwise surface later as a
# cryptic assertion inside cv2.cvtColor. Fail early with a clear message.
image = cv2.imread('keypoints_image.jpg')
if image is None:
    raise FileNotFoundError("Could not read image file 'keypoints_image.jpg'")
# OpenCV loads images in BGR channel order; convert to RGB so matplotlib
# shows the colors correctly.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Keypoint annotations in 'xy' format: one (x, y) pixel position per point.
keypoints = [
    (100, 100),
    (720, 410),
    (1100, 400),
    (1700, 30),
    (300, 650),
    (1570, 590),
    (560, 800),
    (1300, 750),
    (900, 1000),
    (910, 780),
    (670, 670),
    (830, 670),
    (1000, 670),
    (1150, 670),
    (820, 900),
    (1000, 900),
]
4.可视化带有关键点的原始图像
# Reference view: the unaugmented image with its annotated keypoints.
vis_keypoints(image=image, keypoints=keypoints)
5.定义一个简单的数据增强管道
# Single-transform pipeline: a horizontal flip that is always applied (p=1).
# Passing keypoint_params tells Compose that 'xy' keypoints accompany the
# image, so they are transformed together with it.
transform = A.Compose(
    transforms=[A.HorizontalFlip(p=1)],
    keypoint_params=A.KeypointParams(format='xy'),
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(image=transformed['image'], keypoints=transformed['keypoints'])
6.下面是一些数据增强管道的例子
# Vertical flip, always applied (p=1). Deterministic, so no seed is needed.
transform = A.Compose(
[A.VerticalFlip(p=1)],
keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
# Seed Python's global RNG so the randomly chosen crop is repeatable.
# NOTE(review): this assumes the installed Albumentations version draws its
# random parameters from Python's `random` module - confirm for your version.
random.seed(7)
# Random 768x768 crop, always applied. Keypoints that fall outside the crop
# are presumably dropped by KeypointParams' default settings - confirm.
transform = A.Compose(
[A.RandomCrop(width=768, height=768, p=1)],
keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
# Re-seed so this example is reproducible independently of the one above.
random.seed(7)
# Random rotation applied with probability 0.5, so depending on the seed
# the output may be identical to the input.
transform = A.Compose(
[A.Rotate(p=0.5)],
keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
# Deterministic 512x512 crop from the image center, always applied.
transform = A.Compose(
[A.CenterCrop(height=512, width=512, p=1)],
keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
# Re-seed, then a combined shift/scale/rotate applied with probability 0.5.
random.seed(7)
transform = A.Compose(
[A.ShiftScaleRotate(p=0.5)],
keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
7.一个复杂的增强管道的例子
# Seed Python's global RNG so the sampled augmentations are repeatable.
# NOTE(review): assumes Albumentations draws from Python's `random` module in
# the installed version - confirm.
random.seed(7)
transform = A.Compose([
# With probability 0.5: crop a random region whose height is in [256, 1025]
# and resize it to 512x512.
A.RandomSizedCrop(min_max_height=(256, 1025), height=512, width=512, p=0.5),
A.HorizontalFlip(p=0.5),
# OneOf with p=1 always runs and applies one of its children; the
# children's own p values presumably act as selection weights - confirm.
A.OneOf([
A.HueSaturationValue(p=0.5),
A.RGBShift(p=0.7)
], p=1),
# Pixel-level change only; keypoint positions are unaffected.
A.RandomBrightnessContrast(p=0.5)
],
keypoint_params=A.KeypointParams(format='xy'),
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
8.附加内容(BONUS):在 Keras 数据管道中使用 Albumentations 进行关键点增强
import numpy as np
import imageio
import os
import matplotlib.pyplot as plt
import pandas as pd
import albumentations as A
import cv2
import json
from tensorflow.python.keras.utils.data_utils import Sequence
def extract_coordinates(df):
    """Parse point annotations into an array of (x, y) coordinates.

    Every entry of ``df['region_shape_attributes']`` is decoded as JSON and
    its ``cx`` and ``cy`` values are collected, one row per entry, in order.

    Args:
        df: table-like object (e.g. a pandas DataFrame) whose
            'region_shape_attributes' column holds JSON strings with
            ``cx`` and ``cy`` keys.

    Returns:
        float32 NumPy array with one [cx, cy] row per entry (shape (N, 2)
        for N >= 1 entries).
    """
    decoded = (json.loads(raw) for raw in df['region_shape_attributes'])
    return np.array([[shape['cx'], shape['cy']] for shape in decoded], dtype=np.float32)
def rescale_image(image):
    """Linearly rescale an image so that its maximum value becomes 255.

    The image is divided by its own maximum and multiplied by 255, so a
    non-negative image ends up in [0, 255]. The result is always float32.

    An all-zero image has nothing to rescale; it is returned as an all-zero
    float32 array instead of the NaNs a 0/0 division would produce.

    Args:
        image: numeric NumPy array.

    Returns:
        float32 array with the same shape as ``image``.
    """
    peak = np.max(image)
    if peak == 0:
        # Avoid 0/0 -> NaN (and the accompanying RuntimeWarning).
        return np.zeros_like(image, dtype=np.float32)
    return (image / peak * 255.).astype(np.float32)
class CustomSeq(Sequence):
    """Keras ``Sequence`` yielding batches of images and their (x, y) keypoints.

    Image file names come from ``df['filename']`` and are read from
    ``path2imgs``; the keypoint of each row is parsed from the same
    DataFrame by ``extract_coordinates`` (one [x, y] pair per row). If an
    augmentation pipeline is given, it is applied per sample to the image
    and its keypoint together.

    Args:
        path2imgs: directory containing the image files.
        df: annotation DataFrame with 'filename' and
            'region_shape_attributes' columns.
        batch_size: number of samples per batch.
        augmentations: optional callable accepting ``image=`` and
            ``keypoints=`` keyword arguments and returning a dict with
            'image' and 'keypoints' entries (e.g. an ``A.Compose`` pipeline).
        mode: 'train' (case-insensitive) shuffles the sample order every
            epoch; any other value keeps the original row order.
    """

    def __init__(self, path2imgs, df, batch_size, augmentations=None, mode='train'):
        self.path2imgs = path2imgs
        self.df = df
        self.img_list = self.df['filename']
        self.y = extract_coordinates(self.df)
        self.batch_size = batch_size
        self.augmentations = augmentations
        self.mode = mode.lower()
        # Establish the sample order (shuffled in train mode) so batches can
        # be requested before Keras first calls on_epoch_end().
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        return int(np.ceil(len(self.df) / float(self.batch_size)))

    def on_epoch_end(self):
        """Reset the sample order, reshuffling it in train mode."""
        self.indexes = range(len(self.img_list))
        if self.mode == 'train':
            self.indexes = random.sample(self.indexes, k=len(self.indexes))

    def _batch_indexes(self, idx):
        """Return the row positions (into df / self.y) that form batch ``idx``."""
        return list(self.indexes[idx * self.batch_size: (idx + 1) * self.batch_size])

    def get_batch_labels(self, idx, shapes):
        """Return the keypoints of batch ``idx`` as a float32 array.

        ``shapes`` is accepted for interface compatibility and is unused.
        Integer-array indexing returns a copy, so callers (``__getitem__``
        writes augmented keypoints into it) cannot corrupt ``self.y``.
        """
        return self.y[self._batch_indexes(idx)]

    def get_batch_images(self, idx):
        """Load and rescale the images of batch ``idx``.

        Returns:
            (x_batch, shapes): a list of float32 image arrays and a NumPy
            array holding each image's shape.
        """
        x_batch = []
        shapes = []
        for row in self._batch_indexes(idx):
            # .iloc: positional lookup, matching the row order of self.y.
            img_name = self.img_list.iloc[row]
            image = imageio.imread(os.path.join(self.path2imgs, img_name))
            image = rescale_image(image)
            x_batch.append(image)
            shapes.append(image.shape)
        return x_batch, np.array(shapes)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x_batch, y_batch)``, augmented if configured."""
        x_batch, shapes = self.get_batch_images(idx)
        y_batch = self.get_batch_labels(idx, shapes)
        if self.augmentations:
            for i, (x_item, y_item) in enumerate(zip(x_batch, y_batch)):
                # The pipeline takes a sequence of keypoints: wrap the single
                # (x, y) pair, then unwrap the first returned keypoint.
                transformed = self.augmentations(image=x_item, keypoints=np.expand_dims(y_item, axis=0))
                x_batch[i], y_batch[i] = transformed['image'], transformed['keypoints'][0]
        return x_batch, y_batch
# Augmentations for the Keras example: independent random horizontal and
# vertical flips (probability 0.5 each), then ToFloat with max_value=255.,
# since rescale_image maps each image's maximum to 255.
transform = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ToFloat(max_value=255.),
    ],
    keypoint_params=A.KeypointParams(format='xy'),
)
# Build a sequence over the annotated images in the current directory,
# one sample per batch, augmented with the pipeline defined above.
path2png_imgs = os.getcwd()
df = pd.read_csv('vgg_annotate_crop.csv', header=0)
data = CustomSeq(path2png_imgs, df, 1, augmentations=transform)
images, point = data[0]

# cv2.circle needs the center as a tuple of Python ints. np.int was removed
# in NumPy 1.24, so convert with the builtin int (truncates toward zero,
# exactly like the old np.int alias).
center = tuple(int(v) for v in point[0])
# After ToFloat(max_value=255.) the rescaled pixel values lie in [0, 1]
# (assuming non-negative input), so a color of 1 draws at full intensity.
images[0] = cv2.circle(images[0], center, 30, 1, -1)
plt.figure(figsize=(10, 10))
plt.imshow(images[0], cmap='gray')
参考资料
https://github.com/albumentations-team/albumentations_examples/blob/master/notebooks/example_keypoints.ipynb