参考博文: https://blog.csdn.net/qq_37652891/article/details/123932772
数据集准备
遥感图像多类别语义分割,总共分为7类(包括背景)
image:
label_rgb
label(这里并不是全黑,其中的类别取值为
0,1,2,3,4,5,6
),此后的训练使用的也是这样的数据数据地址
百度云:https://pan.baidu.com/s/1zZHnZfBgVWxs6TJW4yjeeQ
提取码:2022
数据集处理
数据集的 image
和 label
,这个数据集应该提供了 rgb
格式标签和包含 0,1,2,3,4,5,6
值的标签, SwinUNet
使用的是包含 0,1,2,3,4,5,6
的标签图像;
1. 数据集
数据集存放在 SwinUNet
根目录下, image
中是原图像, label
中是标签图像(共7类,其标签取值为 0,1,2,3,4,5,6,7
);
如果使用其他数据集,要注意标签的取值。比如如果是二分类。即标签 0
或 255
,需要换成 0
或 1
—SwinUNet
2. 在 SwinUnet
根目录下创建 npz.py
文件,运行 npz.py
文件
import glob
import cv2
import numpy as np
import os
def npz(im, la, s):
images_path = im
labels_path = la
path2 = s
images = os.listdir(images_path)
for s in images:
image_path = os.path.join(images_path, s)
label_path = os.path.join(labels_path, s)
image = cv2.imread(image_path)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
label = cv2.imread(label_path, flags=0)
np.savez(path2+s[:-4]+".npz",image=image,label=label)
npz('./img_datas/train/image/', './img_datas/train/label/', './data/Synapse/train_npz')
npz('./img_datas/test/image/', './img_datas/test/label/', './data/Synapse/test_vol_h5')
3. 在 SwinUnet
根目录下创建 txt.py
文件,运行 txt.py
文件
目的是生成 ./list/list_Synapse/train.txt
和 ./list/list_Synapse/test_vol.txt
文件
import os
def write_name(np, tx):
files = os.listdir(np)
f = open(tx, 'w')
for i in files:
name = i[:-4]+'\n'
f.write(name)
write_name('./data/Synapse/train_npz', './lists/lists_Synapse/train.txt')
write_name('./data/Synapse/test_vol_h5', './lists/lists_Synapse/test_vol.txt')
4. 下载预训练权重,放在 SwinUnet
目录下的 pretrained_ckpt
文件夹下
链接:https://pan.baidu.com/s/1-hYwJRlr95Fv08e9AEARww
提取码: 2022
; 修改网络
1. 修改 train.py
文件
比较重要的是 类别数量,其他视情况而定
; 2. 修改 ./datasets/dataset_synapse.py
文件
3. 修改 trainer.py
文件
此处不知道为什么
; 4. 运行代码
这些信息可以作为超参传入,如果不能,那么可以使用 default=
的方式写入默认值
如果设置好啦默认值,那么运行
python train.py
就可以啦Original: https://blog.csdn.net/weixin_44669966/article/details/125623961
Author: 我是一个小稻米
Title: 使用SwinUnet训练自己的数据集
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/614489/
转载文章受原作者版权保护。转载请注明原作者出处!