主要为记录自己学习实践mmsegmentation框架的过程,并顺便为一起学习的同学们提供参考,分享一下自己学习到的一些知识和所踩的坑,与大家共勉!
我个人主要是想要使用mmsegmentation框架训练自己的数据集,一开始跟着网上的教程使用了PspNet网络,但是可能由于数据集过小最后达到的效果不尽人意,因此考虑使用更新的、性能更好的SegFormer进行尝试,也是看到了SegFormer在各种数据集上的准确率都相较传统的神经网络有了较大提升,所以比较心动。
SegFormer在ADE20K数据集上的表现
那么让我们现在开始吧(这里默认大家都配置好mmsegmentation了):
首先对自己的数据集进行处理,我比较习惯于处理voc类型的数据集,因此这里主要介绍voc类型数据集的处理结构:
-------ImageSets-----------Segmentation----------------train.txt #训练集图片的文件名----------------trainval.txt #训练验证集图片的文件名----------------val.txt #验证集图片的文件名-------JPEGImages #存放训练与测试的所有图片文件-------SegmentationClass #存放图像分割结果图
然后是部署我们自己的配置文件,由于mmsegmentation的SegFormer并没有针对voc数据集的配置文件,因此需要我们自己对其进行修改以适配voc类型数据集
一、首先修改mmseg\.mim\configs\_base_\datasets\pascal_voc12.py文件(建议把mmseg文件夹复制到自己的项目文件夹下,以便于修改)
dataset_type = 'PascalVOCDataset'data_root = 'data/VOCdevkit/VOC2012' #修改为自己数据集的路径,推荐使用绝对路径
dict( type='MultiScaleFlipAug', # img_scale=(2048, 512), img_scale=(640, 640), #这里的图片大小按照自己数据集的图片进行修改 # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ])
二、然后修改mmseg\datasets\voc.py
主要将类别修改为自己的数据集类别以及想要为分割的各类别显示的颜色
CLASSES = ('sky', 'tree', 'road', 'grass', 'background') #写你实际的类别名就好了,最后再加上一个backgroundPALETTE = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34], [0, 11, 123]] #数量与类别数相对应
三、接着修改_base_\models\segformer.py(没有则创建一个)
# model settingsnorm_cfg = dict(type='BN', requires_grad=True) # 单卡改为BNfind_unused_parameters = Truemodel = dict( type='EncoderDecoder', pretrained=True, backbone=dict( type='MixVisionTransformer', in_channels=3, embed_dims=32, num_stages=4, num_layers=[2, 2, 2, 2], num_heads=[1, 2, 5, 8], patch_sizes=[7, 3, 3, 3], sr_ratios=[8, 4, 2, 1], out_indices=(0, 1, 2, 3), mlp_ratio=4, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1), decode_head=dict( type='SegformerHead', in_channels=[32, 64, 160, 256], in_index=[0, 1, 2, 3], channels=256, dropout_ratio=0.1, num_classes=2, # 与数据集类别数量相同 norm_cfg=norm_cfg, align_corners=False, loss_decode=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0)), # focal loss使用更多 # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole'))
四、再创建总体配置文件
我这里将该文件拷贝到了项目文件夹中了,包括_base_文件夹,便于路径读取和修改,创建segformer_mit-b5.py总配置文件,然后更改继承的数据集类型:
_base_ = [ './_base_/models/segformer.py', './_base_/datasets/pascal_voc12_aug.py', './_base_/default_runtime.py', './_base_/schedules/schedule_160k.py']# model settingsnorm_cfg = dict(type='BN', requires_grad=True) # 单卡BNfind_unused_parameters = Truemodel = dict( type='EncoderDecoder', pretrained='mit_b5.pth', # 配置好pth路径 backbone=dict( type='MixVisionTransformer', in_channels=3, embed_dims=32, num_stages=4, num_layers=[2, 2, 2, 2], num_heads=[1, 2, 5, 8], patch_sizes=[7, 3, 3, 3], sr_ratios=[8, 4, 2, 1], out_indices=(0, 1, 2, 3), mlp_ratio=4, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1), decode_head=dict( type='SegformerHead', in_channels=[32, 64, 160, 256], in_index=[0, 1, 2, 3], channels=256, dropout_ratio=0.1, num_classes=2, norm_cfg=norm_cfg, align_corners=False, loss_decode=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0)), # model training and testing settings train_cfg=dict(), test_cfg=dict(mode='whole'))# optimizeroptimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.), 'norm': dict(decay_mult=0.), 'head': dict(lr_mult=10.) }))lr_config = dict(_delete_=True, policy='poly', warmup='linear', warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, by_epoch=False)evaluation = dict(interval=16000, metric='mIoU')
五、下载对应的预训练模型
直接上链接:
链接:https://pan.baidu.com/s/1c-d5ghbVyLWqDvylJ24VSw” />
预测结果
后续大规模训练后会继续更新。。。
坑1:libpng warning: iCCP: known incorrect sRGB profile报错
原因是新版的libpng增强了ICC profiles检查,发出警告。此警告可以忽略,我在此也没有对其进行操作,可以使用其他方法(如skimage)读取的方式避免该类报错。
坑2:ValueError: expected 4D input (got 3D input)报错
这是一个困扰我许久的问题
通过上网才发现问题所在是因为使用了不正确的BatchNorm函数,快速解决的方法就是不需要在模型的backbone添加’norm_cfg’
未完待续。。。