找回密码
 立即注册
首页 业界区 业界 【语义分割专栏】:FCN实战篇(附上完整可运行的代码pyto ...

【语义分割专栏】:FCN实战篇(附上完整可运行的代码pytorch)

窟聿湎 2025-6-3 00:11:22
目录

  • 前言
  • FCN全流程代码

    • 模型搭建(model)
    • 数据处理(dataloader)
    • 评价指标(metric)
    • 训练流程(train)
    • 模型测试(test)

  • 效果图
  • 结语

前言

FCN原理篇讲解:【语义分割专栏】:FCN原理篇 - carpell - 博客园
代码地址,下载可复现:fouen6/FCN_semantic-segmentation
本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!
上篇文章已经带大家学习过了FCN的原理,相信大家对于原理应该有了比较深的了解。本文将会带大家去手动复现属于自己的一个语义分割模型。将会深入代码进行讲解,如果有讲错的地方欢迎大家批评指正!
其实所有的深度学习模型的搭建我认为可以总结成五部分:模型的构建,数据集的处理,评价指标的设定,训练流程,测试。其实感觉有点深度学习代码八股文的那种意思。本篇同样的也会按照这样的方式进行讲解,希望大家能够深入代码去进行了解学习。
请记住:只懂原理不懂代码,你就算有了很好的想法创新点,你也难以去实现,所以希望大家能够深入去了解,最好能够参考着本文自己复现一下。
1.png

FCN全流程代码

模型搭建(model)

我们这里根据原论文一样采用VGG作为我们的特征提取网络,如果你对VGG网络还不太了解的话,可以先去看看我对VGG网络的讲解。
我们都知道VGG采用了一些重复的结构,所以我们根据maxpool出现的位置将其划分为5个stage。这样我们可以同时用不同深度的VGG的网络,VGG11到VGG19都可以使用,因为其结构是一样的。
  1.                 backbone = get_backbone(backbone=backbone, pretrained=True)
  2.         features = list(backbone.features.children())
  3.         pool_indices = [i + 1 for i, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)]
  4.         pool_indices = [0] + pool_indices + [len(features)]
  5.         # 划分阶段
  6.         self.stage1 = nn.Sequential(*features[pool_indices[0]:pool_indices[1]])
  7.         self.stage2 = nn.Sequential(*features[pool_indices[1]:pool_indices[2]])
  8.         self.stage3 = nn.Sequential(*features[pool_indices[2]:pool_indices[3]])
  9.         self.stage4 = nn.Sequential(*features[pool_indices[3]:pool_indices[4]])
  10.         self.stage5 = nn.Sequential(*features[pool_indices[4]:pool_indices[5]])
复制代码
然后一个非常重要的,我们采用我们的双线性插值来初始化我们的反卷积,使用双线性插值来初始化,可以在训练初期保证模型有一个比较好的输出然后在通过训练调整。
  1. def _make_bilinear_weights(size,num_channels):
  2.     factor = (size+1)//2
  3.     if size % 2 == 1:
  4.         center = factor - 1
  5.     else:
  6.         center = factor - 0.5
  7.     og = torch.FloatTensor(size, size)
  8.     for i in range(size):
  9.         for j in range(size):
  10.             og[i, j] = (1-abs((i-center)/factor)) * (1-abs((j-center)/factor))
  11.     filter = torch.zeros(num_channels,num_channels,size,size)
  12.     for i in range(num_channels):
  13.         filter[i,i] = og
  14.     return filter
复制代码
最后我们先搭建我们的FCN32s模型,首先我们加载预训练的VGG16,当然别的VGG也行,你自己选择就行了。我们只需要其全连接层前的所有层,划分不同的stage,然后构建fcn检测头,先7x7的卷积,等效于对7×7感受野做全连接,输出4096通道,然后再1x1的卷积,提取语义特征,最后在1x1卷积输出每个空间位置上的类别分布(如21类)。FCN32s是直接从最后的层进行32倍上采样,当然了这样的结果就比较粗糙了。所以效果不会太好。
这里有个细节哈,x = x[:, :, :input_size[0], :input_size[1]],我们裁剪了保证初始大小,因为上采样过程中可能会造成图像的尺度超出一点点的,比如上采样后应该是224,然后最后是225,所以裁剪保证与初始大小相同。
  1. class FCN32s(nn.Module):
  2.     def __init__(self,num_classes = 21,backbone='vgg16'):
  3.         super(FCN32s, self).__init__()
  4.         self.num_classes = num_classes
  5.         backbone = get_backbone(backbone=backbone, pretrained=True)
  6.         features = list(backbone.features.children())
  7.         pool_indices = [i + 1 for i, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)]
  8.         pool_indices = [0] + pool_indices + [len(features)]
  9.         # 划分阶段
  10.         self.stage1 = nn.Sequential(*features[pool_indices[0]:pool_indices[1]])
  11.         self.stage2 = nn.Sequential(*features[pool_indices[1]:pool_indices[2]])
  12.         self.stage3 = nn.Sequential(*features[pool_indices[2]:pool_indices[3]])
  13.         self.stage4 = nn.Sequential(*features[pool_indices[3]:pool_indices[4]])
  14.         self.stage5 = nn.Sequential(*features[pool_indices[4]:pool_indices[5]])
  15.         self.fcn_head = nn.Sequential(
  16.             nn.Conv2d(512, 4096, kernel_size=7,padding=3),
  17.             nn.ReLU(inplace=True),
  18.             nn.Dropout(),
  19.             nn.Conv2d(4096,4096,kernel_size=1),
  20.             nn.ReLU(inplace=True),
  21.             nn.Dropout(),
  22.             nn.Conv2d(4096, self.num_classes, kernel_size=1),
  23.         )
  24.         self.upsample32 = nn.ConvTranspose2d(self.num_classes,self.num_classes,kernel_size = 64,stride = 32,padding = 16,bias = False)
  25.         for m in self.modules():
  26.             if isinstance(m, nn.ConvTranspose2d):
  27.                 m.weight.data.zero_()
  28.                 m.weight.data = _make_bilinear_weights(m.kernel_size[0], m.out_channels)
  29.     def forward(self, x):
  30.         input_size = x.size()[2:]
  31.         x = self.stage1(x)
  32.         x = self.stage2(x)
  33.         x = self.stage3(x)
  34.         x = self.stage4(x)
  35.         x = self.stage5(x)
  36.         x = self.fcn_head(x)
  37.         x = self.upsample32(x)
  38.         x = x[:, :, :input_size[0], :input_size[1]]
  39.         return x
复制代码
然后就是FCN16s了,我们通过将stage4的输出作为我们的pool4,同时我们将pool4经过卷积输出通道到变为类别分数,到时候方便跟最终的输出做跳跃连接。经过fcn_head输出的x先上采样两倍到跟pool4相同的shape,然后两者做跳跃连接相加后再上采样16倍到输入图像的shape大小。
  1. class FCN16s(nn.Module):
  2.     def __init__(self,num_classes = 21,backbone='vgg16'):
  3.         super(FCN16s, self).__init__()
  4.         self.num_classes = num_classes
  5.         backbone = get_backbone(backbone=backbone, pretrained=True)
  6.         features = list(backbone.features.children())
  7.         pool_indices = [i + 1 for i, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)]
  8.         pool_indices = [0] + pool_indices + [len(features)]
  9.         # 划分阶段
  10.         self.stage1 = nn.Sequential(*features[pool_indices[0]:pool_indices[1]])
  11.         self.stage2 = nn.Sequential(*features[pool_indices[1]:pool_indices[2]])
  12.         self.stage3 = nn.Sequential(*features[pool_indices[2]:pool_indices[3]])
  13.         self.stage4 = nn.Sequential(*features[pool_indices[3]:pool_indices[4]])
  14.         self.stage5 = nn.Sequential(*features[pool_indices[4]:pool_indices[5]])
  15.         self.fcn_head = nn.Sequential(
  16.             nn.Conv2d(512, 4096, kernel_size=7,padding=3),
  17.             nn.ReLU(inplace=True),
  18.             nn.Dropout(),
  19.             nn.Conv2d(4096,4096,kernel_size=1),
  20.             nn.ReLU(inplace=True),
  21.             nn.Dropout(),
  22.             nn.Conv2d(4096, self.num_classes, kernel_size=1),
  23.         )
  24.         self.pool4_score = nn.Conv2d(512,self.num_classes, kernel_size=1)
  25.         self.upsample2 = nn.ConvTranspose2d(self.num_classes,self.num_classes,kernel_size = 4,stride = 2,padding = 1,
  26.                                             bias = False)
  27.         self.upsample16 = nn.ConvTranspose2d(self.num_classes, self.num_classes, kernel_size=32, stride=16, padding=8,
  28.                                             bias=False)
  29.         for m in self.modules():
  30.             if isinstance(m, nn.ConvTranspose2d):
  31.                 m.weight.data.zero_()
  32.                 m.weight.data = _make_bilinear_weights(m.kernel_size[0], m.out_channels)
  33.     def forward(self, x):
  34.         input_size = x.size()[2:]
  35.         x = self.stage1(x)
  36.         x = self.stage2(x)
  37.         x = self.stage3(x)
  38.         x = self.stage4(x)
  39.         pool4 = x
  40.         x = self.stage5(pool4)
  41.         x = self.fcn_head(x)
  42.         x = self.upsample2(x)
  43.         pool4_score = self.pool4_score(pool4)
  44.         pool4_score = pool4_score[:, :, :x.size()[2], :x.size()[3]]
  45.         x = x + pool4_score
  46.         x = self.upsample16(x)
  47.         x = x[:, :, :input_size[0], :input_size[1]]
  48.         return x
复制代码
然后就是FCN8s了,我们通过将stage3和stage4的输出作为我们的pool3和pool4,同时我们将分别将pool3和pool4经过卷积输出将通道到变为类别分数,到时候方便跟最终的输出做跳跃连接。经过fcn_head输出的x先上采样两倍到跟pool4相同的shape,然后两者做跳跃连接相加后再上采样2倍到pool3的shape大小。再与pool3做跳跃连接相加,上采样8倍数到输出图像的shape大小。
  1. class FCN8s(nn.Module):
  2.     def __init__(self,num_classes = 21,backbone='vgg16'):
  3.         super(FCN8s, self).__init__()
  4.         self.num_classes = num_classes
  5.         backbone = get_backbone(backbone=backbone,pretrained=True)
  6.         features = list(backbone.features.children())
  7.         pool_indices = [i +1 for i, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)]
  8.         pool_indices = [0] + pool_indices + [len(features)]
  9.         # 划分阶段
  10.         self.stage1 = nn.Sequential(*features[pool_indices[0]:pool_indices[1]])
  11.         self.stage2 = nn.Sequential(*features[pool_indices[1]:pool_indices[2]])
  12.         self.stage3 = nn.Sequential(*features[pool_indices[2]:pool_indices[3]])
  13.         self.stage4 = nn.Sequential(*features[pool_indices[3]:pool_indices[4]])
  14.         self.stage5 = nn.Sequential(*features[pool_indices[4]:pool_indices[5]])
  15.         self.fcn_head = nn.Sequential(
  16.             nn.Conv2d(512, 4096, kernel_size=7,padding=3),
  17.             nn.ReLU(inplace=True),
  18.             nn.Dropout2d(),
  19.             nn.Conv2d(4096,4096,kernel_size=1),
  20.             nn.ReLU(inplace=True),
  21.             nn.Dropout2d(),
  22.             nn.Conv2d(4096, self.num_classes, kernel_size=1),
  23.         )
  24.         self.pool3_score = nn.Conv2d(256,self.num_classes, kernel_size=1)
  25.         self.pool4_score = nn.Conv2d(512, self.num_classes, kernel_size=1)
  26.         self.upsample2_1 = nn.ConvTranspose2d(self.num_classes,self.num_classes,kernel_size = 4,stride = 2,padding = 1,
  27.                                             bias = False)
  28.         self.upsample2_2 = nn.ConvTranspose2d(self.num_classes, self.num_classes, kernel_size=4, stride=2, padding=1,
  29.                                               bias=False)
  30.         self.upsample8 = nn.ConvTranspose2d(self.num_classes, self.num_classes, kernel_size=16, stride=8, padding=4,
  31.                                             bias=False)
  32.         for m in self.modules():
  33.             if isinstance(m, nn.ConvTranspose2d):
  34.                 m.weight.data.zero_()
  35.                 m.weight.data=_make_bilinear_weights(m.kernel_size[0],m.out_channels)
  36.     def forward(self, x):
  37.         input_size = x.size()[2:]
  38.         x = self.stage1(x)
  39.         x = self.stage2(x)
  40.         pool3 = self.stage3(x)
  41.         pool4 = self.stage4(pool3)
  42.         x = self.stage5(pool4)
  43.         x = self.fcn_head(x)
  44.         x = self.upsample2_1(x)
  45.         pool4_score = self.pool4_score(pool4)
  46.         pool4_score = pool4_score[:, :, :x.size()[2], :x.size()[3]]
  47.         x = x + pool4_score
  48.         x = self.upsample2_2(x)
  49.         pool3_score = self.pool3_score(pool3)
  50.         pool3_score = pool3_score[:, :, :x.size()[2], :x.size()[3]]
  51.         x = x + pool3_score
  52.         x = self.upsample8(x)
  53.         x = x[:, :, :input_size[0], :input_size[1]]
  54.         return x
复制代码
数据处理(dataloader)

数据集名称:VOC2012
数据集下载地址:The PASCAL Visual Object Classes Challenge 2012 (VOC2012)
2.png

在这里下载哈,2GB的那个。
这里我已经专门发了一篇博客对语义分割任务常用的数据集做了深入的介绍,已经具体讲解了其实现的处理代码。如果你对语义分割常用数据集有不了解的话,可以先去我的语义分割专栏中进行了解哦!!  我这里就直接附上代码了。
  1. import torch
  2. import numpy as np
  3. from PIL import Image
  4. from torch.utils.data import Dataset, DataLoader
  5. import os
  6. import random
  7. import torchvision.transforms as T
  8. VOC_CLASSES = [
  9.     'background','aeroplane','bicycle','bird','boat','bottle','bus',
  10.     'car','cat','chair','cow','diningtable','dog','horse',
  11.     'motorbike','person','potted plant','sheep','sofa','train','tv/monitor'
  12. ]
  13. class VOCSegmentation(Dataset):
  14.     def __init__(self, root, split='train', img_size=320, augment=True):
  15.         super(VOCSegmentation, self).__init__()
  16.         self.root = root
  17.         self.split = split
  18.         self.img_size = img_size
  19.         self.augment = augment
  20.         img_dir = os.path.join(root, 'JPEGImages')
  21.         mask_dir = os.path.join(root, 'SegmentationClass')
  22.         split_file = os.path.join(root, 'ImageSets', 'Segmentation', f'{split}.txt')
  23.         if not os.path.exists(split_file):
  24.             raise FileNotFoundError(split_file)
  25.         with open(split_file, 'r') as f:
  26.             file_names = [x.strip() for x in f.readlines()]
  27.         self.images = [os.path.join(img_dir, x + '.jpg') for x in file_names]
  28.         self.masks = [os.path.join(mask_dir, x + '.png') for x in file_names]
  29.         assert len(self.images) == len(self.masks)
  30.         print(f"✅ {split} set loaded: {len(self.images)} samples")
  31.         self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
  32.                                      std=[0.229, 0.224, 0.225])
  33.     def __getitem__(self, index):
  34.         img = Image.open(self.images[index]).convert('RGB')
  35.         mask = Image.open(self.masks[index])  # mask为P模式(0~20的类别)
  36.         # Resize
  37.         img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
  38.         mask = mask.resize((self.img_size, self.img_size), Image.NEAREST)
  39.         # 转Tensor
  40.         img = T.functional.to_tensor(img)
  41.         mask = torch.from_numpy(np.array(mask)).long()  # 0~20
  42.         # 数据增强
  43.         if self.augment:
  44.             if random.random() > 0.5:
  45.                 img = T.functional.hflip(img)
  46.                 mask = T.functional.hflip(mask)
  47.             if random.random() > 0.5:
  48.                 img = T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)(img)
  49.         img = self.normalize(img)
  50.         return img, mask
  51.     def __len__(self):
  52.         return len(self.images)
  53. def get_dataloader(data_path, batch_size=4, img_size=320, num_workers=4):
  54.     train_dataset = VOCSegmentation(root=data_path, split='train', img_size=img_size, augment=True)
  55.     val_dataset = VOCSegmentation(root=data_path, split='val', img_size=img_size, augment=False)
  56.     train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
  57.     val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
  58.     return train_loader, val_loader
复制代码
评价指标(metric)

我们这里语义分割采用的评价指标为:PA(像素准确率),CPA(类别像素准确率),MPA(类别平均像素准确率),IoU(交并比),mIoU(平均交并比),FWIoU(频率加权交并比),mF1(平均F1分数)。
这里我已经专门发了一篇博客对这些平均指标做了深入的介绍,已经具体讲解了其实现的代码。如果你对这些评价指标有不了解的话,可以先去我的语义分割专栏中进行了解哦!!  我这里就直接附上代码了。
  1. import numpy as np
  2. __all__ = ['SegmentationMetric']
  3. class SegmentationMetric(object):
  4.     def __init__(self, numClass):
  5.         self.numClass = numClass
  6.         self.confusionMatrix = np.zeros((self.numClass,) * 2)
  7.     def genConfusionMatrix(self, imgPredict, imgLabel):
  8.         mask = (imgLabel >= 0) & (imgLabel < self.numClass)
  9.         label = self.numClass * imgLabel[mask] + imgPredict[mask]
  10.         count = np.bincount(label, minlength=self.numClass ** 2)
  11.         confusionMatrix = count.reshape(self.numClass, self.numClass)
  12.         return confusionMatrix
  13.     def addBatch(self, imgPredict, imgLabel):
  14.         assert imgPredict.shape == imgLabel.shape
  15.         self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)
  16.         return self.confusionMatrix
  17.     def pixelAccuracy(self):
  18.         acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
  19.         return acc
  20.     def classPixelAccuracy(self):
  21.         denominator = self.confusionMatrix.sum(axis=1)
  22.         denominator = np.where(denominator == 0, 1e-12, denominator)
  23.         classAcc = np.diag(self.confusionMatrix) / denominator
  24.         return classAcc
  25.     def meanPixelAccuracy(self):
  26.         classAcc = self.classPixelAccuracy()
  27.         meanAcc = np.nanmean(classAcc)
  28.         return meanAcc
  29.     def IntersectionOverUnion(self):
  30.         intersection = np.diag(self.confusionMatrix)
  31.         union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
  32.             self.confusionMatrix)
  33.         union = np.where(union == 0, 1e-12, union)
  34.         IoU = intersection / union
  35.         return IoU
  36.     def meanIntersectionOverUnion(self):
  37.         mIoU = np.nanmean(self.IntersectionOverUnion())
  38.         return mIoU
  39.     def Frequency_Weighted_Intersection_over_Union(self):
  40.         denominator1 = np.sum(self.confusionMatrix)
  41.         denominator1 = np.where(denominator1 == 0, 1e-12, denominator1)
  42.         freq = np.sum(self.confusionMatrix, axis=1) / denominator1
  43.         denominator2 = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
  44.             self.confusionMatrix)
  45.         denominator2 = np.where(denominator2 == 0, 1e-12, denominator2)
  46.         iu = np.diag(self.confusionMatrix) / denominator2
  47.         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
  48.         return FWIoU
  49.     def classF1Score(self):
  50.         tp = np.diag(self.confusionMatrix)
  51.         fp = self.confusionMatrix.sum(axis=0) - tp
  52.         fn = self.confusionMatrix.sum(axis=1) - tp
  53.         precision = tp / (tp + fp + 1e-12)
  54.         recall = tp / (tp + fn + 1e-12)
  55.         f1 = 2 * precision * recall / (precision + recall + 1e-12)
  56.         return f1
  57.     def meanF1Score(self):
  58.         f1 = self.classF1Score()
  59.         mean_f1 = np.nanmean(f1)
  60.         return mean_f1
  61.     def reset(self):
  62.         self.confusionMatrix = np.zeros((self.numClass, self.numClass))
  63.     def get_scores(self):
  64.         scores = {
  65.             'Pixel Accuracy': self.pixelAccuracy(),
  66.             'Class Pixel Accuracy': self.classPixelAccuracy(),
  67.             'Intersection over Union': self.IntersectionOverUnion(),
  68.             'Class F1 Score': self.classF1Score(),
  69.             'Frequency Weighted Intersection over Union': self.Frequency_Weighted_Intersection_over_Union(),
  70.             'Mean Pixel Accuracy': self.meanPixelAccuracy(),
  71.             'Mean Intersection over Union(mIoU)': self.meanIntersectionOverUnion(),
  72.             'Mean F1 Score': self.meanF1Score()
  73.         }
  74.         return scores
复制代码
训练流程(train)

到这里,所有的前期准备都已经就绪,我们就要开始训练我们的模型了。
  1. def parse_arguments():
  2.     parser = argparse.ArgumentParser()
  3.     parser.add_argument('--data_root', type=str, default='./datasets/VOC2012', help='Dataset root path')
  4.     parser.add_argument('--classes_name', type=str, default='VOC', help='Dataset class names')
  5.     parser.add_argument('--backbone', type=str, default='vgg16', help='Backbone model')
  6.     parser.add_argument('--head', type=str, default='fcn8s', help='Segmentation head')
  7.     parser.add_argument('--num_classes', type=int, default=21, help='Number of classes')
  8.     parser.add_argument('--epochs', type=int, default=50, help='Epochs')
  9.     parser.add_argument('--lr', type=float, default=0.005, help='Learning rate')
  10.     parser.add_argument('--momentum', type=float, default=0.9, help='Momentum')
  11.     parser.add_argument('--weight-decay', type=float, default=1e-4, help='Weight decay')
  12.     parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
  13.     parser.add_argument('--checkpoint', type=str, default='./checkpoint', help='Checkpoint directory')
  14.     parser.add_argument('--resume', type=str, default=None, help='Resume checkpoint path')
  15.     return parser.parse_args()
复制代码
首先来看看我们的一些参数的设定,一般我们都是这样放在最前面,能够让人更加快速的了解其代码的一些核心参数设置。首先就是我们的数据集位置(),然后就是我们的数据集名称(classes_name),这个暂时没什么用,因为我们目前只用了VOC数据集,然后就是特征提取网络的选择(backbone),这里我们可以选择不同深度的VGG网络,检测模型的选择(head),我们可以选择不同的fcn模型,数据集的类别数(num_classes),训练epoch数,这个你设置大一点也行,因为我们会在训练过程中保存最好结果的模型的。学习率(lr),动量(momentum),权重衰减(weight-decay),这些都属于模型超参数,大家可以尝试不同的数值,多试试,就会有个大致的了解的,批量大小(batch_size)根据自己电脑性能来设置,一般都是为2的倍数,保存权重的文件夹(checkpoint),是否继续训练(resume)。
[code]def train(args):    if not os.path.exists(args.checkpoint):        os.makedirs(args.checkpoint)    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    n_gpu = torch.cuda.device_count()    print(f"Device: {device}, GPUs available: {n_gpu}")    # Dataloader    train_loader, val_loader = get_dataloader(args.data_root, batch_size=args.batch_size)    train_dataset_size = len(train_loader.dataset)    val_dataset_size = len(val_loader.dataset)    print(f"Train samples: {train_dataset_size}, Val samples: {val_dataset_size}")    # Model    model = get_model(args.head, backbone=args.backbone, num_classes=args.num_classes)    model.to(device)    # Loss + Optimizer + Scheduler    criterion = nn.CrossEntropyLoss(ignore_index=255)    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)    scaler = torch.cuda.amp.GradScaler()    # Resume    start_epoch = 0    best_miou = 0.0    if args.resume and os.path.isfile(args.resume):        print(f"Loading checkpoint '{args.resume}'")        checkpoint = torch.load(args.resume)        start_epoch = checkpoint['epoch']        best_miou = checkpoint['best_miou']        model.load_state_dict(checkpoint['model_state_dict'])        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])        print(f"Loaded checkpoint (epoch {start_epoch})")    # Training history    history = {        'train_loss': [],        'val_loss': [],        'pixel_accuracy': [],        'miou': []    }    print(f"
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册