图像识别模型训练 | 苹果花识别(unet 像素级识别)Apple/Strawflower 分割训练全流程(Windows 版)

图像识别模型训练 | 苹果花识别(unet 像素级识别)Apple/Strawflower 分割训练全流程(Windows 版)

王先生
2025-10-14 / 0 评论 / 4 阅读 / 正在检测是否收录...

Apple/Strawflower 分割训练全流程(Windows 版)


一、项目概述

数据集使用:https://agdatacommons.nal.usda.gov/articles/dataset/Data_from_Multi-species_fruit_flower_detection_using_a_refined_semantic_segmentation_network/24852636
上框资源中的 AppleA
  • 目标:利用 USDA 公开数据集,训练一个像素级「花/背景」二分类分割模型
  • 数据:174 张 2K 分辨率图像 + 130 张单通道掩码(255=花,0=背景)
    2k 分辨率原图
    2025-10-14T03:29:21.png
    单通道掩码
    2025-10-14T03:29:49.png
  • 方案:U-Net + ResNet34 预训练,Windows 本地 GPU/CPU 均可跑通
  • 语言:Python ≥3.8,PyTorch ≥1.12

二、目录结构

flower_seg/
├─ data/
│  ├─ images/        IMG_0248.JPG  …
│  └─ masks/         248.png  …
├─ checkpoints/      best.pth
├─ train.py          # 一站式训练脚本
└─ README.md         # 本文档

三、快速开始

  1. 下载数据集并放到对应文件夹

    数据集下载地址:https://agdatacommons.nal.usda.gov/articles/dataset/Data_from_Multi-species_fruit_flower_detection_using_a_refined_semantic_segmentation_network/24852636
    上框资源中的 AppleA
  2. 创建环境

    conda create -n unet python=3.10 -y
    conda activate unet
    pip install torch torchvision segmentation-models-pytorch albumentations tqdm
  3. 训练

    cd flower_seg
    python train.py
  4. 推理见「六、推理示例」

四、核心踩坑与解决

| 问题 | 报错提示 | 解决 |
| --- | --- | --- |
| Win 多进程 | RuntimeError: ...spawn... | 把训练代码包进 main() + if __name__ == '__main__':,并设 num_workers=0 |
| 通道检查失败 | ...shape consistency... | is_check_shapes=False |
| RandomCrop 越界 | Values for crop should be non negative... | 改用 A.Resize(512,512) 或先 PadIfNeeded |
| IoU 指标移除 | no attribute 'utils' | 手动计算:(inter+1e-7)/(union+1e-7) |

五、训练脚本(train.py)

#  -*- coding: utf-8 -*-
import os, glob, random, cv2, torch
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

# ---------- Hyper-parameters & paths ----------
# Paths are relative: run the script from inside flower_seg/.
DATA_DIR   = r'data'
IMAGE_DIR  = os.path.join(DATA_DIR, 'images')
MASK_DIR   = os.path.join(DATA_DIR, 'masks')
CHECK_DIR  = r'checkpoints'     # best.pth is saved here by main()
os.makedirs(CHECK_DIR, exist_ok=True)

DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2      # images per training step
EPOCHS     = 60
LR         = 1e-3   # initial learning rate; decayed by cosine schedule in main()
IMG_SIZE   = 512    # square size all images/masks are resized to

# ---------- Dataset ----------
class FlowerDS(Dataset):
    """Flower/background segmentation dataset.

    Pairs RGB images with single-channel masks; the mask is binarized at
    127 on the way out, so 255 = flower -> 1 and 0 = background -> 0.
    """

    def __init__(self, img_paths, mask_paths, transform=None):
        self.imgs = img_paths
        self.masks = mask_paths
        self.tf = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        # OpenCV loads BGR; convert to RGB for the ImageNet-pretrained encoder.
        image = cv2.cvtColor(cv2.imread(self.imgs[idx]), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks[idx], cv2.IMREAD_GRAYSCALE)
        if self.tf:
            augmented = self.tf(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        # Threshold the mask into a {0, 1} long tensor.
        return image, (mask > 127).long()

def get_paths(image_dir=None, mask_dir=None):
    """Collect matching (image, mask) path pairs.

    An image like ``IMG_0248.JPG`` is matched to a mask named ``248.*`` by
    stripping the extension and taking the trailing number after the last
    underscore.  Images without a parsable number or without a mask are
    skipped with a message.

    Args:
        image_dir: directory holding the images (defaults to IMAGE_DIR).
        mask_dir: directory holding the masks (defaults to MASK_DIR).

    Returns:
        Two equal-length lists (image paths, mask paths), aligned
        index-by-index and sorted by image path.
    """
    image_dir = IMAGE_DIR if image_dir is None else image_dir
    mask_dir = MASK_DIR if mask_dir is None else mask_dir
    img_ext = ('*.jpg', '*.png', '*.JPG', '*.PNG')
    pairs = []
    for ext in img_ext:
        for img_p in glob.glob(os.path.join(image_dir, ext)):
            name = os.path.basename(img_p)
            try:
                number = int(name.split('.')[0].split('_')[-1])
            except ValueError:
                # Robustness: non-numeric filenames used to crash int().
                print(f'[Skip] unparsable name -> {name}')
                continue
            mask_cand = glob.glob(os.path.join(mask_dir, f'{number}.*'))
            if mask_cand:
                pairs.append((img_p, mask_cand[0]))
            else:
                print(f'[Skip] missing mask -> {name}')
    # BUG FIX: the original sorted images and masks *independently*, which can
    # scramble the pairing ('10.png' < '2.png' lexicographically even though
    # 'IMG_0002' < 'IMG_0010').  Sort the pairs together to keep alignment.
    pairs.sort(key=lambda pair: pair[0])
    if not pairs:
        return [], []
    imgs_ok, masks_ok = zip(*pairs)
    return list(imgs_ok), list(masks_ok)

# ---------- Augmentations ----------
# Training pipeline: resize to IMG_SIZE, random flips/rotations and light
# color jitter, then ImageNet normalization (matches the pretrained ResNet34
# encoder statistics) and conversion to a CHW tensor.
tf_train = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(),
    A.ColorJitter(0.1, 0.1, 0.1, 0.05),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
], is_check_shapes=False)  # skip image/mask shape consistency check (see section 4)

# Validation pipeline: deterministic resize + normalize only.
tf_val = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
], is_check_shapes=False)

# ---------- training / validation ----------
def train_one_epoch(model, loader, loss_fn, optimizer, device, epoch):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    total = 0.
    progress = tqdm(loader, desc=f'Epoch {epoch}')
    for images, masks in progress:
        images = images.to(device)
        # Add a channel dim so targets match the (B, 1, H, W) logits.
        masks = masks.to(device).unsqueeze(1).float()
        logits = model(images)
        batch_loss = loss_fn(logits, masks)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        total += batch_loss.item()
        progress.set_postfix(loss=batch_loss.item())
    return total / len(loader)

@torch.no_grad()
def validate(model, loader, loss_fn, device):
    """Compute the mean per-batch IoU of thresholded predictions.

    Args:
        model: segmentation network producing one logit per pixel.
        loader: iterable of (image_batch, mask_batch) pairs.
        loss_fn: unused; kept for signature symmetry with train_one_epoch.
        device: torch device to evaluate on.

    Returns:
        Mean IoU averaged over batches (float).
    """
    model.eval()
    iou_sum = 0.
    for x, y in loader:
        x, y = x.to(device), y.to(device).unsqueeze(1).float()
        pred = torch.sigmoid(model(x)) > 0.5
        inter = (pred & y.bool()).sum()
        union = (pred | y.bool()).sum()
        # BUG FIX: the epsilon belongs in the numerator too (section 4's
        # formula: (inter+1e-7)/(union+1e-7)); otherwise an all-background
        # batch with a correct empty prediction scores 0 instead of 1.
        iou_sum += ((inter + 1e-7) / (union + 1e-7)).item()
    return iou_sum / len(loader)

# ---------- main ----------
def main():
    """Train U-Net (ResNet34 encoder) on the flower/background dataset.

    Splits the paired data 80/20 with a fixed seed, trains for EPOCHS
    epochs with Dice loss and cosine LR decay, and checkpoints the model
    with the best validation mIoU to CHECK_DIR/best.pth.
    """
    all_imgs, all_masks = get_paths()
    if len(all_imgs) == 0:
        print('No valid pairs!')
        return
    # Deterministic shuffle, then 80/20 train/val split.
    paired = list(zip(all_imgs, all_masks))
    random.seed(42)
    random.shuffle(paired)
    split = int(0.8 * len(paired))
    # Robustness: with very few pairs one side of the split would be empty
    # and zip(*[]) below would raise.
    if split == 0 or split == len(paired):
        print('Not enough pairs for a train/val split!')
        return
    train_img, train_msk = zip(*paired[:split])
    val_img, val_msk = zip(*paired[split:])

    train_ds = FlowerDS(train_img, train_msk, tf_train)
    val_ds = FlowerDS(val_img, val_msk, tf_val)
    # num_workers=0 avoids the Windows spawn multiprocessing error (section 4).
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=False)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=False)

    # BUG FIX: the original used the undefined lowercase name `device` here,
    # which raises NameError at runtime; the module-level constant is DEVICE.
    model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1, activation=None).to(DEVICE)
    loss_fn = smp.losses.DiceLoss('binary')
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_iou = 0.0
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_one_epoch(model, train_dl, loss_fn, optimizer, DEVICE, epoch)
        val_iou = validate(model, val_dl, loss_fn, DEVICE)
        scheduler.step()
        print(f'Epoch {epoch:02d}  |  train loss {train_loss:.4f}  |  val mIoU {val_iou:.4f}')
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), os.path.join(CHECK_DIR, 'best.pth'))
            print('  * best model saved')
    print('Training finished!')


if __name__ == '__main__':
    main()

六、推理示例

import cv2, torch, albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# encoder_weights=None: skip the ImageNet-weights download — they would be
# overwritten by the checkpoint loaded on the next line anyway.
model = smp.Unet('resnet34', encoder_weights=None, classes=1, activation=None).to(device)
model.load_state_dict(torch.load(r'checkpoints\best.pth', map_location=device))
model.eval()

# Preprocessing must match the validation transforms used during training.
tf = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()
])

img = cv2.cvtColor(cv2.imread('test.jpg'), cv2.COLOR_BGR2RGB)
x = tf(image=img)['image'].unsqueeze(0).to(device)
with torch.no_grad():
    mask = (torch.sigmoid(model(x)) > 0.5).cpu().numpy().squeeze(0).transpose(1, 2, 0)
# BUG FIX: cast the boolean mask to uint8 before scaling — `bool_array * 255`
# produces an int64 array, which cv2.imwrite rejects.
cv2.imwrite('mask.png', mask.astype('uint8') * 255)

七、性能参考(RTX-3060 12 G)

| 阶段 | 数值 |
| --- | --- |
| 训练 60 epoch | ≈ 3.5 min |
| 最高验证 mIoU | 0.87 |
| 512×512 推理 | 35 fps |

八、后续优化

  • 数据:CutMix / Mosaic / 外采更多图
  • 模型:Mask2Former、SegFormer、EfficientNet-B3 backbone
  • 部署:TensorRT 量化、ONNXRuntime、OpenVINO

Happy Training! 🌼未完待续...

评论 (0)

取消