Apple/Strawflower Segmentation: End-to-End Training on Windows
1. Project Overview
Dataset: https://agdatacommons.nal.usda.gov/articles/dataset/Data_from_Multi-species_fruit_flower_detection_using_a_refined_semantic_segmentation_network/24852636 (use the AppleA subset from that page)
- Goal: train a pixel-level flower/background binary segmentation model on the public USDA dataset
- Data: 174 2K-resolution images + 130 single-channel masks (255 = flower, 0 = background)
- Approach: U-Net with an ImageNet-pretrained ResNet34 encoder; runs locally on Windows, on GPU or CPU
- Stack: Python ≥ 3.8, PyTorch ≥ 1.12
2. Directory Structure
flower_seg/
├─ data/
│  ├─ images/      IMG_0248.JPG …
│  └─ masks/       248.png …
├─ checkpoints/    best.pth
├─ train.py        # one-stop training script
└─ README.md       # this document
3. Quick Start
Download the dataset
Grab the AppleA subset from the dataset page linked in Section 1, then put the images into data/images and the masks into data/masks.
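Optional: a two-line sanity check that the files landed where train.py expects them (paths assume the tree in Section 2; the counts should match the 174/130 figures from Section 1):

import glob
print(len(glob.glob('data/images/*')), 'images')  # expect 174
print(len(glob.glob('data/masks/*')), 'masks')    # expect 130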
Create the environment
conda create -n unet python=3.8
conda activate unet
pip install torch torchvision segmentation-models-pytorch albumentations tqdm
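Note: on Windows, a plain pip install torch typically pulls the CPU-only build. For GPU training, install the CUDA wheels first; cu118 below is just one example, pick the build matching your driver from pytorch.org:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118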
Train
cd flower_seg
python train.py
- For inference, see Section 6 "Inference Example"
4. Key Pitfalls and Fixes
Problem | Error message | Fix |
---|---|---|
Windows multiprocessing | RuntimeError: ...spawn... | Wrap the training code in main() behind if __name__ == '__main__':; set num_workers=0 |
Shape-consistency check fails | ...shape consistency... | Pass is_check_shapes=False to A.Compose |
RandomCrop out of bounds | Values for crop should be non negative... | Switch to A.Resize(512, 512), or run PadIfNeeded first (see the sketch below) |
IoU metric removed from smp | no attribute 'utils' | Compute it manually: (inter + 1e-7) / (union + 1e-7) |
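If you prefer random crops over a global resize (crops preserve the 2K detail), the PadIfNeeded fix from the table looks like this. A minimal sketch; the 512 crop size is an assumption matching IMG_SIZE:

import albumentations as A
tf_crop = A.Compose([
    # zero-pad images smaller than the crop window so RandomCrop never goes out of bounds
    A.PadIfNeeded(min_height=512, min_width=512, border_mode=0),
    A.RandomCrop(512, 512),
])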
5. Training Script (train.py)
# -*- coding: utf-8 -*-
import os, glob, random, cv2, torch
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
# ---------- Config ----------
DATA_DIR = r'data'
IMAGE_DIR = os.path.join(DATA_DIR, 'images')
MASK_DIR = os.path.join(DATA_DIR, 'masks')
CHECK_DIR = r'checkpoints'
os.makedirs(CHECK_DIR, exist_ok=True)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
EPOCHS = 60
LR = 1e-3
IMG_SIZE = 512
# ---------- Dataset ----------
class FlowerDS(Dataset):
    def __init__(self, img_paths, mask_paths, transform=None):
        self.imgs, self.masks, self.tf = img_paths, mask_paths, transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.imgs[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks[idx], cv2.IMREAD_GRAYSCALE)
        if self.tf:
            res = self.tf(image=img, mask=mask)
            img, mask = res['image'], res['mask']
        return img, (mask > 127).long()  # binarize: 255 -> 1 (flower), 0 -> 0 (background)
def get_paths():
    """Pair each image with its mask by the trailing number in the file name
    (IMG_0248.JPG <-> 248.png); skip images without a mask."""
    img_ext = ('*.jpg', '*.png', '*.JPG', '*.PNG')
    imgs_ok, masks_ok = [], []
    for ext in img_ext:
        for img_p in glob.glob(os.path.join(IMAGE_DIR, ext)):
            name = os.path.basename(img_p)
            number = name.split('.')[0].split('_')[-1]
            mask_cand = glob.glob(os.path.join(MASK_DIR, f'{int(number)}.*'))
            if mask_cand:
                imgs_ok.append(img_p)
                masks_ok.append(mask_cand[0])
            else:
                print(f'[Skip] missing mask -> {name}')
    # Sort as pairs, never the two lists independently,
    # otherwise the image/mask correspondence can break.
    pairs = sorted(zip(imgs_ok, masks_ok))
    return [p[0] for p in pairs], [p[1] for p in pairs]
# ---------- Augmentations ----------
tf_train = A.Compose([
A.Resize(IMG_SIZE, IMG_SIZE),
A.HorizontalFlip(p=0.5),
A.RandomRotate90(),
A.ColorJitter(0.1, 0.1, 0.1, 0.05),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2()
], is_check_shapes=False)
tf_val = A.Compose([
A.Resize(IMG_SIZE, IMG_SIZE),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2()
], is_check_shapes=False)
# ---------- Train / validate ----------
def train_one_epoch(model, loader, loss_fn, optimizer, device, epoch):
    model.train()
    running_loss = 0.
    pbar = tqdm(loader, desc=f'Epoch {epoch}')
    for x, y in pbar:
        x, y = x.to(device), y.to(device).unsqueeze(1).float()  # (B,H,W) -> (B,1,H,W)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        running_loss += loss.item()
        pbar.set_postfix(loss=loss.item())
    return running_loss / len(loader)
@torch.no_grad()
def validate(model, loader, loss_fn, device):
    model.eval()
    iou_sum = 0.
    for x, y in loader:
        x, y = x.to(device), y.to(device).unsqueeze(1).float()
        pred = torch.sigmoid(model(x)) > 0.5
        inter = (pred & y.bool()).sum()
        union = (pred | y.bool()).sum()
        iou_sum += (inter / (union + 1e-7)).item()  # manual per-batch IoU (smp dropped utils.metrics)
    return iou_sum / len(loader)
# ---------- main ----------
def main():
    all_imgs, all_masks = get_paths()
    if len(all_imgs) == 0:
        print('No valid pairs!'); return
    paired = list(zip(all_imgs, all_masks))
    random.seed(42); random.shuffle(paired)
    split = int(0.8 * len(paired))
    train_img, train_msk = zip(*paired[:split])
    val_img, val_msk = zip(*paired[split:])
    train_ds = FlowerDS(train_img, train_msk, tf_train)
    val_ds = FlowerDS(val_img, val_msk, tf_val)
    # num_workers=0 avoids the Windows spawn RuntimeError (see Section 4)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)
    model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1, activation=None).to(DEVICE)
    loss_fn = smp.losses.DiceLoss('binary')
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    best_iou = 0.0
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_dl, loss_fn, optimizer, DEVICE, epoch)
        val_iou = validate(model, val_dl, loss_fn, DEVICE)
        scheduler.step()
        print(f'Epoch {epoch:02d} | train loss {train_loss:.4f} | val mIoU {val_iou:.4f}')
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), os.path.join(CHECK_DIR, 'best.pth'))
            print(' * best model saved')
    print('Training finished!')

if __name__ == '__main__':
    main()
6. Inference Example
import cv2, torch, albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# encoder_weights=None: skip the ImageNet download, our checkpoint is loaded next
model = smp.Unet('resnet34', encoder_weights=None, classes=1, activation=None).to(device)
model.load_state_dict(torch.load(r'checkpoints\best.pth', map_location=device))
model.eval()
tf = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
img = cv2.cvtColor(cv2.imread('test.jpg'), cv2.COLOR_BGR2RGB)
x = tf(image=img)['image'].unsqueeze(0).to(device)
with torch.no_grad():
    mask = (torch.sigmoid(model(x)) > 0.5).squeeze().cpu().numpy().astype('uint8')
cv2.imwrite('mask.png', mask * 255)  # cast to uint8 first; cv2.imwrite rejects bool arrays
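The saved mask is 512×512, while the source photos are 2K; to overlay predictions on the original image, resize the mask back first. A minimal sketch, assuming img and mask from the snippet above (nearest-neighbor interpolation keeps the mask strictly binary):

h, w = img.shape[:2]  # original resolution
mask_full = cv2.resize(mask * 255, (w, h), interpolation=cv2.INTER_NEAREST)
cv2.imwrite('mask_fullres.png', mask_full)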
7. Performance Reference (RTX 3060, 12 GB)
Stage | Value |
---|---|
Training, 60 epochs | ≈ 3.5 min |
Best validation mIoU | 0.87 |
512×512 inference | 35 fps |
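To check the throughput figure on your own hardware, a rough timing sketch (assumes model and device from Section 6 are in scope; warm-up and synchronization are included so CUDA's asynchronous execution doesn't skew the numbers):

import time, torch
x = torch.randn(1, 3, 512, 512, device=device)
with torch.no_grad():
    for _ in range(10):  # warm-up iterations
        model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(100):  # timed iterations
        model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
print(f'{100 / (time.time() - t0):.1f} fps')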
8. Future Work
- Data: CutMix / Mosaic / collect more images
- Model: Mask2Former, SegFormer, EfficientNet-B3 backbone
- Deployment: TensorRT quantization, ONNX Runtime, OpenVINO (export sketch below)
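As a first step toward the deployment items above, the trained checkpoint can be exported to ONNX for ONNX Runtime / TensorRT / OpenVINO. A minimal sketch; opset 11 and the output file name flower_seg.onnx are assumptions:

import torch
import segmentation_models_pytorch as smp
model = smp.Unet('resnet34', encoder_weights=None, classes=1, activation=None)
model.load_state_dict(torch.load(r'checkpoints\best.pth', map_location='cpu'))
model.eval()
dummy = torch.randn(1, 3, 512, 512)  # fixed 512x512 input, matching training
torch.onnx.export(model, dummy, 'flower_seg.onnx', opset_version=11,
                  input_names=['image'], output_names=['logits'])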
Happy Training! 🌼 To be continued...