STANet_pytorch代码问题汇总、附上裁剪图片代码(有问留言必答)

STANet_pytorch代码问题汇总、附上裁剪图片代码

*
一、 STANet
二、问题汇总与解答(如列不全,请留言)

+ 1.安装虚拟环境与相关库的的条件:
+ 2. 查看代码中的readme.md文件(里面有跑代码的方式、模型与数据集链接)
+ 3. python demo.py 的问题
+ 4.跑BAM、PAM代码中常见的Out of memory 2 (64、256)GiB 问题
三、裁剪代码实现

+ 说明:

一、 STANet

因有一部分实验用到STANet网络,在网上找到相应的代码,花了大概一周一步步跳入坑、填坑的过程,苦于将其跑通,遂记录如下心得,希望能够帮助有需要的小伙伴避开”雷区”!

  • 文章源于:
  • 代码源于:
  • 大致看了一下该篇文章,网上有很多解读的博客,不做过多介绍,简而言之,该文章通过利用自注意力机制模块(BAM)和(多个BAM集成的PAM块),对遥感影像进行特征提取与训练, 通过对比两张不同时期的遥感图像,以深度学习的方法训练模型,最后能够”自动比对”找出同一区域,不同时间的变化情况。下图是STANet文章的截图。
  • 文章能够显著检测出遥感影像中变化的建筑物,可以应用于违章建筑拓展监测、乡村扶贫振兴和生态移民居住保障的风貌变化程度。
  • 相关的数据集包括(train、每train一轮epch之后紧接着验证val集,还有训练结束之后,将保存的model进行测试的test集 (PS: 文章代码的测试部分,称为val,python val.py 就是测试,而不是验证))。
  • 每一个数据集中包括:
  • ———–| A:前一段时间的遥感图像(1024 * 1024);
  • ———–| B:后一段时间的相同区域的遥感图像(1024*1024) ;
  • ———–| label:标注好两幅遥感图像之间存在的变化,因为数据中考虑一个类别(建筑物)的变化情况,以二值图形式(黑白)进行展示(1024*1024))。
    !命名一定要一致!

; 二、问题汇总与解答(如列不全,请留言)

  • *1.安装虚拟环境与相关库的的条件:

STANet_pytorch代码问题汇总、附上裁剪图片代码(有问留言必答)
  • visdom=0.1.8.1 或者 修改可视化版本visdom=0.1.8.8;
    *
不然可能在测试的时候,会出现:AssertionError: X and Y should be the same shape
  • scipy=1.1.0:因为1.2.0版本的scipy没有 imread,也会报错。

2. 查看代码中的readme.md文件(里面有跑代码的方式、模型与数据集链接)

  • 如果开始想python demo.py,先下载文章训练好的模型、LEVIR-CD数据集(README.md中有百度网盘、谷歌云盘这两种形式的链接)添加到相应的位置。
  • 在运行代码过程中,多半会出现no file 报错,就按照报错的提示,

3. python demo.py 的问题

  • TypeError: Cannot handle this data type: (1, 1, 64), |u :听说是因为Python版本问题:我的python=3.6.12没有问题。

4.跑BAM、PAM代码中常见的Out of memory 2 (64、256)GiB 问题

  • 首先:【Out of memory 2 GIB】主要是显存不够,很有效的做法就是减低 batch_size 8 –>4;
  • 其次:降batch size 8 为 4 之后,运行代码,实验跑1 个epoch后, 紧跟的val就会出现 【Out of memory 256GiB】, 因为val验证的代码没有将1024裁剪为256, 服务器的计算资源不够。需要 分别裁剪val 文件中的 A B label ,然后更改python train.py 后面的 val_data_path的路径到裁剪的val 文件夹(如 val_256)即可,代码后续放出。
  • 记得将后面测试的 test文件夹的图片也裁剪, 同样地,分别裁剪 A B label,不裁剪可能会 【Out of memory 64 GiB】。

三、裁剪代码实现


import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
import time
from shutil import get_terminal_size

sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))

def main():
    mode = 'pair'
    opt = {}
    opt['n_thread'] = 20
    opt['compression_level'] = 3

    if mode == 'single':
        opt['input_folder'] = './data/DIV2K/DIV2K_train_HR'
        opt['save_folder'] = './data/DIV2K/DIV2K800_sub'
        opt['crop_sz'] = 480
        opt['step'] = 240
        opt['thres_sz'] = 48
        extract_signle(opt)

    elif mode == 'pair':
        GT_folder = '/home/cug210/data/Lover/code/STANet-master/LEVIR-CD/test/B'
        save_GT_folder = '/home/cug210/data/Lover/code/STANet-master/LEVIR-CD/test_256/B'
        crop_sz = 256
        step = 256
        thres_sz = 256

        img_GT_list = _get_paths_from_images(GT_folder)

        print('process GT...')
        opt['input_folder'] = GT_folder
        opt['save_folder'] = save_GT_folder
        opt['crop_sz'] = crop_sz
        opt['step'] = step
        opt['thres_sz'] = thres_sz
        extract_signle(opt)

    else:
        raise ValueError('Wrong mode.')

def extract_signle(opt):
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print('mkdir [{:s}] ...'.format(save_folder))
    else:
        print('Folder [{:s}] already exists. Exit...'.format(save_folder))
        sys.exit(1)
    img_list = _get_paths_from_images(input_folder)

    def update(arg):
        pbar.update(arg)

    pbar = ProgressBar(len(img_list))

    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(worker, args=(path, opt), callback=update)
    pool.close()
    pool.join()
    print('All subprocesses done.')

def worker(path, opt):
    crop_sz = opt['crop_sz']
    step = opt['step']
    thres_sz = opt['thres_sz']
    img_name = osp.basename(path)
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    n_channels = len(img.shape)
    if n_channels == 2:
        h, w = img.shape
    elif n_channels == 3:
        h, w, c = img.shape
    else:
        raise ValueError('Wrong image shape - {}'.format(n_channels))

    h_space = np.arange(0, h - crop_sz + 1, step)
    if h - (h_space[-1] + crop_sz) > thres_sz:
        h_space = np.append(h_space, h - crop_sz)
    w_space = np.arange(0, w - crop_sz + 1, step)
    if w - (w_space[-1] + crop_sz) > thres_sz:
        w_space = np.append(w_space, w - crop_sz)

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            if n_channels == 2:
                crop_img = img[x:x + crop_sz, y:y + crop_sz]
            else:
                crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
            crop_img = np.ascontiguousarray(crop_img)
            cv2.imwrite(
                osp.join(opt['save_folder'],
                         img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    return 'Processing {:s} ...'.format(img_name)

class ProgressBar(object):
    '''A progress bar which can print the progress
    modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
    '''

    def __init__(self, task_num=0, bar_width=50, start=True):
        self.task_num = task_num
        max_bar_width = self._get_max_bar_width()
        self.bar_width = (bar_width if bar_width  max_bar_width else max_bar_width)
        self.completed = 0
        if start:
            self.start()

    def _get_max_bar_width(self):
        terminal_width, _ = get_terminal_size()
        max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
        if max_bar_width < 10:
            print('terminal width is too small ({}), please consider widen the terminal for better '
                  'progressbar visualization'.format(terminal_width))
            max_bar_width = 10
        return max_bar_width

    def start(self):
        if self.task_num > 0:
            sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
                ' ' * self.bar_width, self.task_num, 'Start...'))
        else:
            sys.stdout.write('completed: 0, elapsed: 0s')
        sys.stdout.flush()
        self.start_time = time.time()

    def update(self, msg='In progress...'):
        self.completed += 1
        elapsed = time.time() - self.start_time + 1e-9
        fps = self.completed / elapsed
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            mark_width = int(self.bar_width * percentage)
            bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
            sys.stdout.write('\033[2F')
            sys.stdout.write('\033[J')
            sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
                bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
        else:
            sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
                self.completed, int(elapsed + 0.5), fps))
        sys.stdout.flush()

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def _get_paths_from_images(path):
    """get image path list from image folder"""
    assert osp.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            print("..fname is:",fname)

            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images

if __name__ == '__main__':
    main()

说明:

  • 只需要更改32-36行的信息:
  • 【原始文件夹路径】
  • 【保存的裁剪后图片的文件夹路径】
  • 【裁剪尺寸crop_size、位移尺寸step(两者相等,表示下一张图和第一张图没有重叠)】
  • 【阈值(thres_sz)设置为256,表示裁剪到最后,剩下不到256的残缺,就不裁剪了。】

Original: https://blog.csdn.net/qq_45041702/article/details/121851147
Author: 未知量0520
Title: STANet_pytorch代码问题汇总、附上裁剪图片代码(有问留言必答)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/705619/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球