找回密码
 立即注册
首页 业界区 业界 【语义分割专栏】3:Segnet实战篇(附上完整可运行的代码 ...

【语义分割专栏】3:Segnet实战篇(附上完整可运行的代码pytorch)

度阡舅 2025-6-9 10:32:12
目录

  • 前言
  • Segnet全流程代码

    • 模型搭建(model)
    • 数据处理(dataloader)
    • 评价指标(metric)
    • 训练流程(train)
    • 模型测试(test)

  • 效果图
  • 结语

前言

Segnet原理篇讲解:【语义分割专栏】3:Segnet原理篇 - carpell - 博客园
代码地址,下载可复现:fouen6/Segnet_semantic-segmentation: 用于学习理解segnet原理
本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!(数据集文中提供了下载地址,下载不到可在评论区要取)
上篇文章已经带大家学习过了Segnet的原理,相信大家对于原理应该有了比较深的了解。本文将会带大家去手动复现属于自己的一个语义分割模型。将会深入代码进行讲解,如果有讲错的地方欢迎大家批评指正!
其实所有的深度学习模型的搭建我认为可以总结成五部分:模型的构建,数据集的处理,评价指标的设定,训练流程,测试。其实感觉有点深度学习代码八股文的那种意思。本篇同样的也会按照这样的方式进行讲解,希望大家能够深入代码去进行了解学习。
请记住:只懂原理不懂代码,你就算有了很好的想法创新点,你也难以去实现,所以希望大家能够深入去了解,最好能够参考着本文自己复现一下。
1.png

Segnet全流程代码

模型搭建(model)

首先是我们的crop函数,为什么需要用到这个,因为在测试的时候,我们不会对图像进行resize操作的,所以其就不一定是32的倍数,在下采样的过程中可能会出现从45->22的情况,但是上采样过程中就会变成22->44,这样就会造成shape的不匹配,所以需要对齐两者的shape大小。
  1. def crop(upsampled, bypass):
  2.     h1, w1 = upsampled.shape[2], upsampled.shape[3]
  3.     h2, w2 = bypass.shape[2], bypass.shape[3]
  4.     # 计算差值
  5.     deltah = h2 - h1
  6.     deltaw = w2 - w1
  7.     # 计算填充的起始和结束位置
  8.     # 对于高度
  9.     pad_top = deltah // 2
  10.     pad_bottom = deltah - pad_top
  11.     # 对于宽度
  12.     pad_left = deltaw // 2
  13.     pad_right = deltaw - pad_left
  14.     # 对 upsampled 进行中心填充
  15.     upsampled_padded = F.pad(upsampled, (pad_left, pad_right, pad_top, pad_bottom), "constant", 0)
  16.     return upsampled_padded
复制代码
然后就是我们的Segnet模型代码了。其实还是非常好理解的,其编码器的结构就是VGG的结构,只不过其在maxpooling的时候需要保存索引,然后就是解码器的结构,其实就是对编码器做个对称就行了。写好模型参数之后,非常重要的,记得要进行参数的初始化哈,这样能够利于之后的训练过程。
  1. class SegNet(nn.Module):
  2.     def __init__(self,num_classes=12):
  3.         super(SegNet, self).__init__()
  4.         self.encoder1 = nn.Sequential(
  5.             nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
  6.             nn.BatchNorm2d(64),
  7.             nn.ReLU(),
  8.             nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
  9.             nn.BatchNorm2d(64),
  10.             nn.ReLU(),
  11.         )
  12.         self.encoder2 = nn.Sequential(
  13.             nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
  14.             nn.BatchNorm2d(128),
  15.             nn.ReLU(),
  16.             nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
  17.             nn.BatchNorm2d(128),
  18.             nn.ReLU(),
  19.         )
  20.         self.encoder3 = nn.Sequential(
  21.             nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
  22.             nn.BatchNorm2d(256),
  23.             nn.ReLU(),
  24.             nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
  25.             nn.BatchNorm2d(256),
  26.             nn.ReLU(),
  27.             nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
  28.             nn.BatchNorm2d(256),
  29.             nn.ReLU(),
  30.         )
  31.         self.encoder4 = nn.Sequential(
  32.             nn.Conv2d(256,512,kernel_size=3,stride=1,padding=1),
  33.             nn.BatchNorm2d(512),
  34.             nn.ReLU(),
  35.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  36.             nn.BatchNorm2d(512),
  37.             nn.ReLU(),
  38.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  39.             nn.BatchNorm2d(512),
  40.             nn.ReLU(),
  41.         )
  42.         self.encoder5 = nn.Sequential(
  43.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  44.             nn.BatchNorm2d(512),
  45.             nn.ReLU(),
  46.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  47.             nn.BatchNorm2d(512),
  48.             nn.ReLU(),
  49.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  50.             nn.BatchNorm2d(512),
  51.             nn.ReLU(),
  52.         )
  53.         self.decoder1 = nn.Sequential(
  54.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  55.             nn.BatchNorm2d(512),
  56.             nn.ReLU(),
  57.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  58.             nn.BatchNorm2d(512),
  59.             nn.ReLU(),
  60.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  61.             nn.BatchNorm2d(512),
  62.             nn.ReLU(),
  63.         )
  64.         self.decoder2 = nn.Sequential(
  65.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  66.             nn.BatchNorm2d(512),
  67.             nn.ReLU(),
  68.             nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
  69.             nn.BatchNorm2d(512),
  70.             nn.ReLU(),
  71.             nn.Conv2d(512,256,kernel_size=3,stride=1,padding=1),
  72.             nn.BatchNorm2d(256),
  73.             nn.ReLU(),
  74.         )
  75.         self.decoder3 = nn.Sequential(
  76.             nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
  77.             nn.BatchNorm2d(256),
  78.             nn.ReLU(),
  79.             nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
  80.             nn.BatchNorm2d(256),
  81.             nn.ReLU(),
  82.             nn.Conv2d(256,128,kernel_size=3,stride=1,padding=1),
  83.             nn.BatchNorm2d(128),
  84.             nn.ReLU(),
  85.         )
  86.         self.decoder4 = nn.Sequential(
  87.             nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
  88.             nn.BatchNorm2d(128),
  89.             nn.ReLU(),
  90.             nn.Conv2d(128,64,kernel_size=3,stride=1,padding=1),
  91.             nn.BatchNorm2d(64),
  92.             nn.ReLU(),
  93.         )
  94.         self.decoder5 = nn.Sequential(
  95.             nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
  96.             nn.BatchNorm2d(64),
  97.             nn.ReLU(),
  98.             nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
  99.             nn.BatchNorm2d(64),
  100.             nn.ReLU(),
  101.             nn.Conv2d(64,num_classes,kernel_size=1),
  102.         )
  103.         self.max_pool = nn.MaxPool2d(2,2,return_indices=True)
  104.         self.max_uppool = nn.MaxUnpool2d(2,2)
  105.         self.initialize_weights()
  106.     def initialize_weights(self):
  107.         for m in self.modules():
  108.             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
  109.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  110.                 if m.bias is not None:
  111.                     nn.init.constant_(m.bias, 0)
  112.             elif isinstance(m, nn.BatchNorm2d):
  113.                 nn.init.constant_(m.weight, 1)
  114.                 nn.init.constant_(m.bias, 0)
  115.     def forward(self, x):
  116.         x1 = self.encoder1(x)
  117.         x,pool_indices1 = self.max_pool(x1)
  118.         x2 = self.encoder2(x)
  119.         x,pool_indices2 = self.max_pool(x2)
  120.         x3 = self.encoder3(x)
  121.         x,pool_indices3 = self.max_pool(x3)
  122.         x4 = self.encoder4(x)
  123.         x,pool_indices4 = self.max_pool(x4)
  124.         x5 = self.encoder5(x)
  125.         x,pool_indices5 = self.max_pool(x5)
  126.         x = self.max_uppool(x,pool_indices5)
  127.         x = crop(x, x5)
  128.         x = self.decoder1(x)
  129.         x = self.max_uppool(x,pool_indices4)
  130.         x = crop(x, x4)
  131.         x = self.decoder2(x)
  132.         x = self.max_uppool(x,pool_indices3)
  133.         x = crop(x, x3)
  134.         x = self.decoder3(x)
  135.         x = self.max_uppool(x,pool_indices2)
  136.         x = crop(x, x2)
  137.         x = self.decoder4(x)
  138.         x = self.max_uppool(x,pool_indices1)
  139.         x = crop(x, x1)
  140.         x = self.decoder5(x)
  141.         return x
复制代码
数据处理(dataloader)

数据集名称:CamVid
数据集下载地址:Object Recognition in Video Dataset
2.png

在这里进行下载,CamVid数据集有两种,一种是官方的就是上述的下载地址的,总共有32种类别,划分的会更加的细致。但是一般官网的太难打开了,所以我们可以通过Kaggle中的CamVid (Cambridge-Driving Labeled Video Database)进行下载。
还有一种就是11类别的(不包括背景),会将一些语义相近的内容进行合并,就划分的没有这么细致,任务难度也会比较低一些。(如果你在网上找不到的话,可以在评论区发言或是私聊我要取)
CamVid 数据集主要用于自动驾驶场景中的语义分割,包含驾驶场景中的道路、交通标志、车辆等类别的标注图像。该数据集旨在推动自动驾驶系统在道路场景中的表现。
数据特点

  • 图像数量:包括701帧视频序列图像,分为训练集、验证集和测试集。
  • 类别:包含32个类别(也有包含11个类别的),包括道路、建筑物、车辆、行人等。
  • 挑战:由于数据集主要来自城市交通场景,因此面临着动态变化的天气、光照、交通密度等挑战
这里我已经专门发了一篇博客对语义分割任务常用的数据集做了深入的介绍,已经具体讲解了其实现的处理代码。如果你对语义分割常用数据集有不了解的话,可以先去我的语义分割专栏中进行了解哦!!  我这里就直接附上代码了。
  1. import os
  2. from PIL import Image
  3. import albumentations as A
  4. from albumentations.pytorch.transforms import ToTensorV2
  5. from torch.utils.data import Dataset, DataLoader
  6. import numpy as np
  7. import torch
  8. # 11类
  9. Cam_CLASSES = [ "Unlabelled","Sky","Building","Pole",
  10.                 "Road","Sidewalk", "Tree","SignSymbol",
  11.                 "Fence","Car","Pedestrian","Bicyclist"]
  12. # 用于做可视化
  13. Cam_COLORMAP = [
  14.     [0, 0, 0],[128, 128, 128],[128, 0, 0],[192, 192, 128],
  15.     [128, 64, 128],[0, 0, 192],[128, 128, 0],[192, 128, 128],
  16.     [64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192]
  17. ]
  18. # 转换RGB mask为类别id的函数
  19. def mask_to_class(mask):
  20.     mask_class = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)
  21.     for idx, color in enumerate(Cam_COLORMAP):
  22.         color = np.array(color)
  23.         # 每个像素和当前颜色匹配
  24.         matches = np.all(mask == color, axis=-1)
  25.         mask_class[matches] = idx
  26.     return mask_class
  27. class CamVidDataset(Dataset):
  28.     def __init__(self, image_dir, label_dir):
  29.         self.image_dir = image_dir
  30.         self.label_dir = label_dir
  31.         self.transform = A.Compose([
  32.             A.Resize(224, 224),
  33.             A.HorizontalFlip(),
  34.             A.VerticalFlip(),
  35.             A.Normalize(),
  36.             ToTensorV2(),
  37.         ])
  38.         self.images = sorted(os.listdir(image_dir))
  39.         self.labels = sorted(os.listdir(label_dir))
  40.         assert len(self.images) == len(self.labels), "Images and labels count mismatch!"
  41.     def __len__(self):
  42.         return len(self.images)
  43.     def __getitem__(self, idx):
  44.         img_path = os.path.join(self.image_dir, self.images[idx])
  45.         label_path = os.path.join(self.label_dir, self.labels[idx])
  46.         image = np.array(Image.open(img_path).convert("RGB"))
  47.         label_rgb = np.array(Image.open(label_path).convert("RGB"))
  48.         # RGB转类别索引
  49.         mask = mask_to_class(label_rgb)
  50.         #mask = torch.from_numpy(np.array(mask)).long()
  51.         # Albumentations 需要 (H, W, 3) 和 (H, W)
  52.         transformed = self.transform(image=image, mask=mask)
  53.         return transformed['image'], transformed['mask'].long()
  54. def get_dataloader(data_path, batch_size=4, num_workers=4):
  55.     train_dir = os.path.join(data_path, 'train')
  56.     val_dir = os.path.join(data_path, 'val')
  57.     trainlabel_dir = os.path.join(data_path, 'train_labels')
  58.     vallabel_dir = os.path.join(data_path, 'val_labels')
  59.     train_dataset = CamVidDataset(train_dir, trainlabel_dir)
  60.     val_dataset = CamVidDataset(val_dir, vallabel_dir)
  61.     train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
  62.     val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
  63.     return train_loader, val_loader
复制代码
评价指标(metric)

我们这里语义分割采用的评价指标为:PA(像素准确率),CPA(类别像素准确率),MPA(类别平均像素准确率),IoU(交并比),mIoU(平均交并比),FWIoU(频率加权交并比),mF1(平均F1分数)。
这里我已经专门发了一篇博客对这些平均指标做了深入的介绍,已经具体讲解了其实现的代码。如果你对这些评价指标有不了解的话,可以先去我的语义分割专栏中进行了解哦!!  我这里就直接附上代码了。
  1. import numpy as np
  2. __all__ = ['SegmentationMetric']
  3. class SegmentationMetric(object):
  4.     def __init__(self, numClass):
  5.         self.numClass = numClass
  6.         self.confusionMatrix = np.zeros((self.numClass,) * 2)
  7.     def genConfusionMatrix(self, imgPredict, imgLabel):
  8.         mask = (imgLabel >= 0) & (imgLabel < self.numClass)
  9.         label = self.numClass * imgLabel[mask] + imgPredict[mask]
  10.         count = np.bincount(label, minlength=self.numClass ** 2)
  11.         confusionMatrix = count.reshape(self.numClass, self.numClass)
  12.         return confusionMatrix
  13.     def addBatch(self, imgPredict, imgLabel):
  14.         assert imgPredict.shape == imgLabel.shape
  15.         self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)
  16.         return self.confusionMatrix
  17.     def pixelAccuracy(self):
  18.         acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
  19.         return acc
  20.     def classPixelAccuracy(self):
  21.         denominator = self.confusionMatrix.sum(axis=1)
  22.         denominator = np.where(denominator == 0, 1e-12, denominator)
  23.         classAcc = np.diag(self.confusionMatrix) / denominator
  24.         return classAcc
  25.     def meanPixelAccuracy(self):
  26.         classAcc = self.classPixelAccuracy()
  27.         meanAcc = np.nanmean(classAcc)
  28.         return meanAcc
  29.     def IntersectionOverUnion(self):
  30.         intersection = np.diag(self.confusionMatrix)
  31.         union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
  32.             self.confusionMatrix)
  33.         union = np.where(union == 0, 1e-12, union)
  34.         IoU = intersection / union
  35.         return IoU
  36.     def meanIntersectionOverUnion(self):
  37.         mIoU = np.nanmean(self.IntersectionOverUnion())
  38.         return mIoU
  39.     def Frequency_Weighted_Intersection_over_Union(self):
  40.         denominator1 = np.sum(self.confusionMatrix)
  41.         denominator1 = np.where(denominator1 == 0, 1e-12, denominator1)
  42.         freq = np.sum(self.confusionMatrix, axis=1) / denominator1
  43.         denominator2 = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
  44.             self.confusionMatrix)
  45.         denominator2 = np.where(denominator2 == 0, 1e-12, denominator2)
  46.         iu = np.diag(self.confusionMatrix) / denominator2
  47.         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
  48.         return FWIoU
  49.     def classF1Score(self):
  50.         tp = np.diag(self.confusionMatrix)
  51.         fp = self.confusionMatrix.sum(axis=0) - tp
  52.         fn = self.confusionMatrix.sum(axis=1) - tp
  53.         precision = tp / (tp + fp + 1e-12)
  54.         recall = tp / (tp + fn + 1e-12)
  55.         f1 = 2 * precision * recall / (precision + recall + 1e-12)
  56.         return f1
  57.     def meanF1Score(self):
  58.         f1 = self.classF1Score()
  59.         mean_f1 = np.nanmean(f1)
  60.         return mean_f1
  61.     def reset(self):
  62.         self.confusionMatrix = np.zeros((self.numClass, self.numClass))
  63.     def get_scores(self):
  64.         scores = {
  65.             'Pixel Accuracy': self.pixelAccuracy(),
  66.             'Class Pixel Accuracy': self.classPixelAccuracy(),
  67.             'Intersection over Union': self.IntersectionOverUnion(),
  68.             'Class F1 Score': self.classF1Score(),
  69.             'Frequency Weighted Intersection over Union': self.Frequency_Weighted_Intersection_over_Union(),
  70.             'Mean Pixel Accuracy': self.meanPixelAccuracy(),
  71.             'Mean Intersection over Union(mIoU)': self.meanIntersectionOverUnion(),
  72.             'Mean F1 Score': self.meanF1Score()
  73.         }
  74.         return scores
复制代码
训练流程(train)

到这里,所有的前期准备都已经就绪,我们就要开始训练我们的模型了。
  1. def parse_arguments():
  2.     parser = argparse.ArgumentParser()
  3.     parser.add_argument('--data_root', type=str, default='../../data/CamVid/CamVid(11)', help='Dataset root path')
  4.     parser.add_argument('--data_name', type=str, default='CamVid', help='Dataset class names')
  5.     parser.add_argument('--model', type=str, default='Segnet', help='Segmentation model')
  6.     parser.add_argument('--num_classes', type=int, default=12, help='Number of classes')
  7.     parser.add_argument('--epochs', type=int, default=50, help='Epochs')
  8.     parser.add_argument('--lr', type=float, default=0.005, help='Learning rate')
  9.     parser.add_argument('--momentum', type=float, default=0.9, help='Momentum')
  10.     parser.add_argument('--weight-decay', type=float, default=1e-4, help='Weight decay')
  11.     parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
  12.     parser.add_argument('--checkpoint', type=str, default='./checkpoint', help='Checkpoint directory')
  13.     parser.add_argument('--resume', type=str, default=None, help='Resume checkpoint path')
  14.     return parser.parse_args()
复制代码
首先来看看我们的一些参数的设定,一般我们都是这样放在最前面,能够让人更加快速的了解其代码的一些核心参数设置。首先就是我们的数据集位置(data_root),然后就是我们的数据集名称(classes_name),这个暂时没什么用,因为我们目前只用了CamVid数据集,然后就是检测模型的选择(model),我们选择Segnet模型,数据集的类别数(num_classes),训练epoch数,这个你设置大一点也行,因为我们会在训练过程中保存最好结果的模型的。学习率(lr),动量(momentum),权重衰减(weight-decay),这些都属于模型超参数,大家可以尝试不同的数值,多试试,就会有个大致的了解的,批量大小(batch_size)根据自己电脑性能来设置,一般都是为2的倍数,保存权重的文件夹(checkpoint),是否继续训练(resume)。
[code]def train(args):    if not os.path.exists(args.checkpoint):        os.makedirs(args.checkpoint)    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    n_gpu = torch.cuda.device_count()    print(f"Device: {device}, GPUs available: {n_gpu}")    # Dataloader    train_loader, val_loader = get_dataloader(args.data_root, batch_size=args.batch_size)    train_dataset_size = len(train_loader.dataset)    val_dataset_size = len(val_loader.dataset)    print(f"Train samples: {train_dataset_size}, Val samples: {val_dataset_size}")    # Model    model = get_model(num_classes=args.num_classes)    model.to(device)    # Loss + Optimizer + Scheduler    criterion = nn.CrossEntropyLoss(ignore_index=0)    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)    scaler = torch.cuda.amp.GradScaler()    # Resume    start_epoch = 0    best_miou = 0.0    if args.resume and os.path.isfile(args.resume):        print(f"Loading checkpoint '{args.resume}'")        checkpoint = torch.load(args.resume)        start_epoch = checkpoint['epoch']        best_miou = checkpoint['best_miou']        model.load_state_dict(checkpoint['model_state_dict'])        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])        print(f"Loaded checkpoint (epoch {start_epoch})")    # Training history    history = {        'train_loss': [],        'val_loss': [],        'pixel_accuracy': [],        'miou': []    }    print(f"
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册