在图像分割、实例分割等需要图像和对应掩码共同参与训练的任务中,数据增强时如果图像和掩码应用了不同的随机变换,会导致二者的空间位置不匹配,使得掩码标注失去意义,直接影响模型的训练效果。因此确保二者应用完全相同的随机变换是这类任务数据预处理的核心要求。

核心思路:复用随机变换参数
随机变换的不确定性来自随机参数的生成,比如随机水平翻转的概率、随机裁剪的起始坐标、随机旋转的角度等。只要让图像和掩码使用同一套随机参数执行变换,就能保证二者的变换完全一致。下面介绍两种常用的实现方式。
方法一:自定义同步变换类
我们可以把随机变换的逻辑封装到一个自定义类中,在初始化时生成随机参数,然后分别对图像和掩码执行相同的变换操作。以下是一个同步随机水平翻转和随机旋转的示例:
import torch
import torchvision.transforms.functional as TF
import random
class SyncRandomTransform:
def __init__(self, hflip_prob=0.5, rotate_degree_range=(-10, 10)):
self.hflip_prob = hflip_prob
self.rotate_degree_range = rotate_degree_range
def __call__(self, image, mask):
# 生成随机水平翻转参数
hflip_flag = random.random() < self.hflip_prob
# 生成随机旋转角度
rotate_degree = random.uniform(self.rotate_degree_range[0], self.rotate_degree_range[1])
# 对图像执行变换
if hflip_flag:
image = TF.hflip(image)
if rotate_degree != 0:
image = TF.rotate(image, rotate_degree)
# 对掩码执行完全相同的变换
if hflip_flag:
mask = TF.hflip(mask)
if rotate_degree != 0:
mask = TF.rotate(mask, rotate_degree)
return image, mask
# 使用示例
transform = SyncRandomTransform(hflip_prob=0.5, rotate_degree_range=(-15, 15))
image = torch.randn(3, 256, 256) # 模拟3通道256x256图像
mask = torch.randn(1, 256, 256) # 模拟单通道256x256掩码
aug_image, aug_mask = transform(image, mask)
print("增强后图像形状:", aug_image.shape)
print("增强后掩码形状:", aug_mask.shape)
方法二:复用torchvision变换的随机状态
如果使用torchvision内置的随机变换类,可以通过先生成随机参数,再手动调用变换函数的方式实现同步。以随机裁剪为例:
import torch
import torchvision.transforms.functional as TF
import random
def sync_random_crop(image, mask, crop_size=(224, 224)):
# 获取图像的高度和宽度
_, h, w = image.shape
crop_h, crop_w = crop_size
# 生成随机裁剪的起始坐标
top = random.randint(0, h - crop_h)
left = random.randint(0, w - crop_w)
# 对图像和掩码执行相同的裁剪操作
image_crop = TF.crop(image, top, left, crop_h, crop_w)
mask_crop = TF.crop(mask, top, left, crop_h, crop_w)
return image_crop, mask_crop
# 使用示例
image = torch.randn(3, 256, 256)
mask = torch.randn(1, 256, 256)
crop_image, crop_mask = sync_random_crop(image, mask, crop_size=(224, 224))
print("裁剪后图像形状:", crop_image.shape)
print("裁剪后掩码形状:", crop_mask.shape)
组合多种变换的实现
如果需要对图像和掩码同时应用多种随机变换,可以把所有变换的参数生成逻辑放在一个统一的调用流程中,依次执行相同的变换步骤。以下是一个组合水平翻转、随机旋转、随机裁剪的完整示例:
import torch
import torchvision.transforms.functional as TF
import random
def sync_augmentation(image, mask, hflip_prob=0.5, rotate_range=(-10, 10), crop_size=(224, 224)):
# 1. 随机水平翻转
hflip_flag = random.random() < hflip_prob
if hflip_flag:
image = TF.hflip(image)
mask = TF.hflip(mask)
# 2. 随机旋转
rotate_degree = random.uniform(rotate_range[0], rotate_range[1])
if rotate_degree != 0:
image = TF.rotate(image, rotate_degree)
mask = TF.rotate(mask, rotate_degree)
# 3. 随机裁剪
_, h, w = image.shape
crop_h, crop_w = crop_size
if h >= crop_h and w >= crop_w:
top = random.randint(0, h - crop_h)
left = random.randint(0, w - crop_w)
image = TF.crop(image, top, left, crop_h, crop_w)
mask = TF.crop(mask, top, left, crop_h, crop_w)
return image, mask
# 使用示例
image = torch.randn(3, 256, 256)
mask = torch.randn(1, 256, 256)
aug_image, aug_mask = sync_augmentation(image, mask)
print("最终增强图像形状:", aug_image.shape)
print("最终增强掩码形状:", aug_mask.shape)
注意事项
- 掩码通常是单通道的整数张量,旋转、裁剪等变换的参数需要和图像完全一致,避免尺寸或位置偏差。
- 如果使用了归一化这类只针对图像的变换,不需要对掩码执行,避免破坏掩码的标注值。
- 在自定义Dataset的
__getitem__方法中调用同步变换逻辑,可以保证每个样本的图像和掩码变换一致。
通过上述方法,就可以在PyTorch中稳定实现图像与掩码的同步随机数据增强,保障训练数据的标注有效性,提升分割类任务的训练效果。