码字不易,收藏之余,别忘了给我点个赞吧!
———Start
官方代码:https://github.com/Beckschen/TransUNet
目的:训练5个类别的汽车部件分割任务(测试在另一篇博客中)
实现效果:
; 1. github下载代码,并解压。
项目里的文件可能跟你下载的不一样,不急后面会讲到!
; 2. 配置数据集(尽最大努力还原官方数据集的格式)。
通常自己手上的数据集分images和labels文件夹,分别存放着原始图像和对应的mask图像,如下图所示; mask图像中的像素有0,1,2,3,4 分别代表背景,车身,轮子,车灯,窗户,一共五个类别,所以这里显示全黑色,肉眼看不出差别!通过阅读官方读取数据的代码,我们需要将一张图像和其对应的标签合并转化成一个.npz文件.
官方数据集格式,data文件夹,Synapse文件夹,test_vol_h5文件夹,train_npz文件夹手动创建!
转化数据集的代码如下,会将images中的图像和labels中的标签生成一个.npz文件。
def npz():
path = r'G:\dataset\car-segmentation\train\images\*.png'
path2 = r'G:\dataset\Unet\TransUnet-ori\data\Synapse\train_npz\\'
for i,img_path in enumerate(glob.glob(path)):
image = cv2.imread(img_path)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
label_path = img_path.replace('images','labels')
label = cv2.imread(label_path,flags=0)
np.savez(path2+str(i),image=image,label=label)
print('------------',i)
print('ok')
生成的文件在 data\Synapse\train_npz文件夹中,如下图,也可以自己定义生成的路径,然后把文件复制到data\Synapse\train_npz文件中。
data\Synapse\train_npz文件夹中存放的是训练集样本,按照同样的方式生成测试集样本,存放在data\Synapse\test_vol_h5文件夹中。
我的训练集203个样本,测试集3个样本。npz文件生成完成之后,找到train.txt和test_vol.txt,手动将文件里面的内容清空,split_data.py这个文件直接无视。自己写一个函数读取train_npz中所有的文件名称,然后将文件名称写入train.txt文件,一个名称一行,如下图所示。同理可完成test_vol.txt文件制作。
至此,数据集制作完毕!!!代码会先去train.txt文件中读取训练样本的名称,然后根据名称再去train_npz文件夹下读取npz文件。所以每一步都很重要,必须正确!
3. 下载预训练权重
进入网站后,点击imagenet21k文件夹。
下载这个权重文件即可。
手动创建如下多个文件夹,存放刚刚下载完毕的权重,注意名称跟我的保持一致!
至此,预训练权重已下载完毕。
; 4. 修改读取文件的方法
找到datasets/dataset_synapse.py文件中的Synapse_dataset类,修改__getitem__函数。
def __getitem__(self, idx):
if self.split == "train":
slice_name = self.sample_list[idx].strip('\n')
data_path = self.data_dir+"/"+slice_name+'.npz'
data = np.load(data_path)
image, label = data['image'], data['label']
else:
slice_name = self.sample_list[idx].strip('\n')
data_path = self.data_dir+"/"+slice_name+'.npz'
data = np.load(data_path)
image, label = data['image'], data['label']
image = torch.from_numpy(image.astype(np.float32))
image = image.permute(2,0,1)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
sample['case_name'] = self.sample_list[idx].strip('\n')
return sample
找到datasets/dataset_synapse.py文件中的RandomGenerator类,修改__call__函数。
def __call__(self, sample):
image, label = sample['image'], sample['label']
if random.random() > 0.5:
image, label = random_rot_flip(image, label)
elif random.random() > 0.5:
image, label = random_rotate(image, label)
x, y,_ = image.shape
if x != self.output_size[0] or y != self.output_size[1]:
image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y,1), order=3)
label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
image = torch.from_numpy(image.astype(np.float32))
image = image.permute(2,0,1)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label.long()}
return sample
至此,数据读取的部分已经修改完毕!
5. 配置训练参数
认真检查各个参数是否正确,这里的路径都是 ‘./'(当前目录下),不是”…/”,训练时,batch_size通常大于1,我这里设置有误!类别数可根据你的任务定!
图片大小设置,越大越耗显存。
; 6. 修改trainer.py文件
设置trainer.py文件中的DataLoader函数中的num_workers=0
至此,所有代码修改完毕!
总结:以上修改内容针对彩色图像的分割任务, 由于仅文字表述某些操作存在局限性,故只能简略应答,有任何问题可下方留言评论。
Original: https://blog.csdn.net/qq_37652891/article/details/123465472
Author: 小小小MaYi
Title: TransUnet官方代码训练自己数据集(彩色RGB3通道图像的分割)
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/668470/
转载文章受原作者版权保护。转载请注明原作者出处!