RandLA-Net Source Code Walkthrough

Preface

Paper: https://arxiv.org/abs/1911.11236
GitHub: https://github.com/QingyongHu/RandLA-Net

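  • This article uses the S3DIS dataset as its running example.

  • This article is not a guide to running the code; it is only a set of notes taken while reading the source.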

Data preparation

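Download the dataset and extract it into the /data/S3DIS/Stanford3dDataset_v1.2_Aligned_Version folder.
The Stanford3dDataset_v1.2_Aligned_Version folder contains several Area_* folders (areas).
Each Area folder contains several scene folders.
Each scene folder (e.g. Area_1\conferenceRoom_1) mainly contains:
  • a point-cloud file for the whole scene, conferenceRoom_1.txt;
  • an Annotations folder with the point clouds split by class.
conferenceRoom_1.txt stores the point cloud of the entire scene directly in XYZRGB format, for example: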

-15.609 39.505 2.214 71 64 54
-15.634 39.518 2.198 68 64 52
-15.622 39.514 2.195 70 61 52
-15.621 39.510 2.215 72 65 55
-15.606 39.505 2.211 71 63 52
-15.657 39.524 2.213 76 70 58
-15.549 39.484 2.206 63 53 44

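The Annotations folder stores the point cloud split by class. Each file is named after its class (e.g. chair_1.txt; the preprocessing script takes the part before the first underscore as the class name). Its contents are also XYZRGB points.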


data_prepare_s3dis.py

Once the dataset is in place, run data_prepare_s3dis.py to preprocess it.
Preprocessing creates two additional folders under /data/S3DIS: original_ply and input_0.040.

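The program first converts every scene under Stanford3dDataset_v1.2_Aligned_Version into a file named AreaName_SceneName.ply (e.g. Area_1_conferenceRoom_1.ply) and places it in original_ply. The coordinates are translated by subtracting the scene's minimum XYZ, so X, Y and Z are all ≥ 0. The data are stored as XYZRGBL, where L is the label, i.e. the index of the class.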

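It then does the following for every scene of every area:

  • Grid-subsamples the points (grid size 0.04) and normalizes the colors (RGB/255). The result is saved to input_0.040/AreaName_SceneName.ply.
  • Builds a KD-tree on the subsampled coordinates and saves it to input_0.040/AreaName_SceneName_KDTree.pkl.
  • Uses the KD-tree to find, for every original point, the index of its nearest subsampled point. This index list has one entry per original point. It is saved together with the original labels to input_0.040/AreaName_SceneName_proj.pkl.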
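DP.grid_sub_sampling wraps a compiled routine, so its internals are not visible here. The sketch below approximates the three steps in NumPy. It assumes each grid cell keeps the mean coordinates and mean color of its points and their most frequent label; that is my assumption about the compiled code. grid_subsample is a hypothetical helper. The KD-tree part uses sklearn's KDTree the same way the script does.

```python
import numpy as np
from sklearn.neighbors import KDTree


def grid_subsample(xyz, colors, labels, grid_size):
    """Hypothetical NumPy stand-in for DP.grid_sub_sampling.

    Points are bucketed into cubic cells of edge grid_size; each cell keeps the
    mean XYZ and mean color of its points and their most frequent label.
    """
    cells = np.floor(xyz / grid_size).astype(np.int64)
    # inverse[i] is the id of the cell that point i falls into
    _, inverse, counts = np.unique(cells, axis=0, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)
    n_cells = counts.shape[0]

    sub_xyz = np.zeros((n_cells, 3))
    np.add.at(sub_xyz, inverse, xyz)
    sub_xyz /= counts[:, None]

    sub_colors = np.zeros((n_cells, colors.shape[1]))
    np.add.at(sub_colors, inverse, colors)
    sub_colors /= counts[:, None]

    # majority vote of the labels inside each cell
    votes = np.zeros((n_cells, int(labels.max()) + 1), dtype=np.int64)
    np.add.at(votes, (inverse, labels.astype(np.int64)), 1)
    sub_labels = votes.argmax(axis=1)

    return sub_xyz.astype(np.float32), sub_colors.astype(np.float32), sub_labels.astype(np.uint8)


rng = np.random.default_rng(0)
xyz = (rng.random((2000, 3)) * 0.5).astype(np.float32)
colors = rng.integers(0, 256, size=(2000, 3)).astype(np.uint8)
labels = rng.integers(0, 13, size=2000).astype(np.uint8)

sub_xyz, sub_colors, sub_labels = grid_subsample(xyz, colors, labels, 0.04)
sub_colors = sub_colors / 255.0                       # same normalization as the script
search_tree = KDTree(sub_xyz)                         # what gets pickled to *_KDTree.pkl
proj_idx = np.squeeze(search_tree.query(xyz, return_distance=False)).astype(np.int32)
print(len(xyz), '->', len(sub_xyz), 'points; proj_idx shape', proj_idx.shape)
```

proj_idx has one entry per original point. That is why it can later carry predictions from the subsampled cloud back to the full-resolution cloud.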

Note: the code below was partly modified to run on Windows and is for reference only.

from sklearn.neighbors import KDTree
from os.path import join, exists, dirname, abspath
import numpy as np
import pandas as pd
import os, sys, glob, pickle

BASE_DIR = dirname(abspath(__file__))
ROOT_DIR = dirname(BASE_DIR)
sys.path.append(BASE_DIR)
sys.path.append(ROOT_DIR)
from helper_ply import write_ply
from helper_tool import DataProcessing as DP

dataset_path = join(ROOT_DIR, 'data/S3DIS/Stanford3dDataset_v1.2_Aligned_Version')
anno_paths = [line.rstrip() for line in open(join(BASE_DIR, 'meta/anno_paths.txt'))]
anno_paths = [join(dataset_path, p) for p in anno_paths]

gt_class = [x.rstrip() for x in open(join(BASE_DIR, 'meta/class_names.txt'))]
gt_class2label = {cls: i for i, cls in enumerate(gt_class)}

sub_grid_size = 0.04
original_pc_folder = join(dirname(dataset_path), 'original_ply')
sub_pc_folder = join(dirname(dataset_path), 'input_{:.3f}'.format(sub_grid_size))
os.mkdir(original_pc_folder) if not exists(original_pc_folder) else None
os.mkdir(sub_pc_folder) if not exists(sub_pc_folder) else None
out_format = '.ply'

def convert_pc2ply(anno_path, save_path):
    """
    Merge each scene into a single file under data/S3DIS/original_ply.
    Convert original dataset files to ply file (each line is XYZRGBL). L: label
    We aggregated all the points from each instance in the room.

    :param anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
    :param save_path: path to save original point clouds (each line is XYZRGBL)
    :return: None
    """
    data_list = []

    for f in glob.glob(join(anno_path, '*.txt')):
        class_name = os.path.basename(f).split('_')[0]
        if class_name not in gt_class:
            class_name = 'clutter'
        pc = pd.read_csv(f, header=None, sep=r'\s+').values  # whitespace-separated XYZRGB
        labels = np.ones((pc.shape[0], 1)) * gt_class2label[class_name]
        data_list.append(np.concatenate([pc, labels], 1))

    pc_label = np.concatenate(data_list, 0)
    xyz_min = np.amin(pc_label, axis=0)[0:3]
    pc_label[:, 0:3] -= xyz_min

    xyz = pc_label[:, :3].astype(np.float32)
    colors = pc_label[:, 3:6].astype(np.uint8)
    labels = pc_label[:, 6].astype(np.uint8)
    write_ply(save_path, (xyz, colors, labels), ['x', 'y', 'z', 'red', 'green', 'blue', 'class'])

    sub_xyz, sub_colors, sub_labels = DP.grid_sub_sampling(xyz, colors, labels, sub_grid_size)
    sub_colors = sub_colors / 255.0
    sub_ply_file = join(sub_pc_folder, save_path.replace('/', '\\').split('\\')[-1][:-4] + '.ply')
    write_ply(sub_ply_file, [sub_xyz, sub_colors, sub_labels], ['x', 'y', 'z', 'red', 'green', 'blue', 'class'])

    search_tree = KDTree(sub_xyz)
    kd_tree_file = join(sub_pc_folder, str(save_path.replace('/', '\\').split('\\')[-1][:-4]) + '_KDTree.pkl')
    with open(kd_tree_file, 'wb') as f:
        pickle.dump(search_tree, f)

    proj_idx = np.squeeze(search_tree.query(xyz, return_distance=False))
    proj_idx = proj_idx.astype(np.int32)
    proj_save = join(sub_pc_folder, str(save_path.replace('/', '\\').split('\\')[-1][:-4]) + '_proj.pkl')
    with open(proj_save, 'wb') as f:
        pickle.dump([proj_idx, labels], f)

if __name__ == '__main__':

    for annotation_path in anno_paths:
        print(annotation_path)
        elements = str(annotation_path).replace('/', '\\').split('\\')
        out_file_name = elements[-3] + '_' + elements[-2] + out_format
        convert_pc2ply(annotation_path, join(original_pc_folder, out_file_name))

main_S3DIS.py

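After preprocessing, run main_S3DIS.py. It takes four arguments:
  • the GPU index (default 0);
  • the test area (default 5, i.e. Area_5 is used for testing and the other areas for training);
  • the mode (train, test or vis);
  • a pretrained model, which is only used when mode=test.
For example: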

python -B main_S3DIS.py --gpu 0 --mode train --test_area 1

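Loading the data

The program first loads the dataset, splits the scenes into a training set and a validation set by Area, and organizes the following data: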


        self.val_proj = []
        self.val_labels = []
        self.possibility = {}
        self.min_possibility = {}
        self.input_trees = {'training': [], 'validation': []}
        self.input_colors = {'training': [], 'validation': []}
        self.input_labels = {'training': [], 'validation': []}
        self.input_names = {'training': [], 'validation': []}
        self.load_sub_sampled_clouds(cfg.sub_grid_size)
    def load_sub_sampled_clouds(self, sub_grid_size):
        tree_path = join(self.path, 'input_{:.3f}'.format(sub_grid_size))
        for i, file_path in enumerate(self.all_files):
            t0 = time.time()
            cloud_name = file_path.replace('/', '\\').split('\\')[-1][:-4]
            if self.val_split in cloud_name:
                cloud_split = 'validation'
            else:
                cloud_split = 'training'

            kd_tree_file = join(tree_path, '{:s}_KDTree.pkl'.format(cloud_name))
            sub_ply_file = join(tree_path, '{:s}.ply'.format(cloud_name))

            data = read_ply(sub_ply_file)
            sub_colors = np.vstack((data['red'], data['green'], data['blue'])).T
            sub_labels = data['class']

            with open(kd_tree_file, 'rb') as f:
                search_tree = pickle.load(f)

            self.input_trees[cloud_split] += [search_tree]
            self.input_colors[cloud_split] += [sub_colors]
            self.input_labels[cloud_split] += [sub_labels]
            self.input_names[cloud_split] += [cloud_name]

            size = sub_colors.shape[0] * 4 * 7
            print('{:s} {:.1f} MB loaded in {:.1f}s'.format(kd_tree_file.replace('/', '\\').split('\\')[-1], size * 1e-6, time.time() - t0))

        print('\nPreparing reprojected indices for testing')

        for i, file_path in enumerate(self.all_files):
            t0 = time.time()
            cloud_name = file_path.replace('/', '\\').split('\\')[-1][:-4]

            if self.val_split in cloud_name:
                proj_file = join(tree_path, '{:s}_proj.pkl'.format(cloud_name))
                with open(proj_file, 'rb') as f:
                    proj_idx, labels = pickle.load(f)
                self.val_proj += [proj_idx]
                self.val_labels += [labels]
                print('{:s} done in {:.1f}s'.format(cloud_name, time.time() - t0))
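val_proj and val_labels let predictions made on the subsampled cloud be scored against every original point. Below is a minimal sketch of that reprojection with made-up arrays; sub_probs, the class count and the point counts are all hypothetical. Indexing the subsampled predictions with proj_idx yields one prediction per original point. My reading is that this is what the evaluation code uses the reprojected indices for.

```python
import numpy as np

# Hypothetical class scores on a subsampled cloud of 6 points with 3 classes.
sub_probs = np.array([[0.8, 0.1, 0.1],
                      [0.2, 0.7, 0.1],
                      [0.1, 0.1, 0.8],
                      [0.6, 0.3, 0.1],
                      [0.3, 0.4, 0.3],
                      [0.1, 0.2, 0.7]])
# proj_idx and labels as stored in *_proj.pkl: one entry per ORIGINAL point (10 here).
proj_idx = np.array([0, 0, 1, 2, 2, 3, 4, 5, 5, 1], dtype=np.int32)
labels = np.array([0, 0, 1, 2, 2, 0, 1, 2, 1, 1], dtype=np.uint8)

full_probs = sub_probs[proj_idx]        # each original point copies its nearest subsampled point
full_preds = full_probs.argmax(axis=1)
acc = float(np.mean(full_preds == labels))
print(full_preds.tolist(), acc)         # [0, 0, 1, 2, 2, 0, 1, 2, 2, 1] 0.9
```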

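Initializing the input pipeline

This mainly consists of:
  • building the generators;
  • building tf.data datasets from them;
  • batching the datasets;
  • attaching the map function that does the preprocessing;
  • prefetching.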

    def init_input_pipeline(self):
        print('Initiating input pipelines')
        cfg.ignored_label_inds = [self.label_to_idx[ign_label] for ign_label in self.ignored_labels]
        gen_function, gen_types, gen_shapes = self.get_batch_gen('training')
        gen_function_val, _, _ = self.get_batch_gen('validation')
        self.train_data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes)
        self.val_data = tf.data.Dataset.from_generator(gen_function_val, gen_types, gen_shapes)

        self.batch_train_data = self.train_data.batch(cfg.batch_size)
        self.batch_val_data = self.val_data.batch(cfg.val_batch_size)
        map_func = self.get_tf_mapping2()

        self.batch_train_data = self.batch_train_data.map(map_func=map_func)
        self.batch_val_data = self.batch_val_data.map(map_func=map_func)

        self.batch_train_data = self.batch_train_data.prefetch(cfg.batch_size)
        self.batch_val_data = self.batch_val_data.prefetch(cfg.val_batch_size)

        iter = tf.data.Iterator.from_structure(self.batch_train_data.output_types, self.batch_train_data.output_shapes)
        self.flat_inputs = iter.get_next()
        self.train_init_op = iter.make_initializer(self.batch_train_data)
        self.val_init_op = iter.make_initializer(self.batch_val_data)

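Building the data generator

When the generator is initialized, it:

  1. Assigns every point a "possibility". The original post sets this word in bold because it means something different from probability in the usual sense. It is initialized to a tiny random value.
  2. Records the minimum possibility of each scene.

For each sample, the generator then:

  1. Picks the scene whose minimum possibility is lowest, and in it the point with the lowest possibility. This point is called the center point.
  2. Adds noise to the center point's coordinates.
  3. Uses the KD-tree to find the num_points points (hyperparameter, num_points = 40960) nearest to the noisy center point, the center point itself included. If the scene has fewer points than that, all of its points are taken. These are called the selected points.
  4. Shuffles the selected points.
  5. Gets the coordinates, colors and labels of the selected points.
  6. Centers the selected points by subtracting the noisy center point's coordinates from theirs.
  7. Increases the possibility of the selected points by delta = (1 - d^2 / max(d^2))^2, where d is the distance to the center. Points closer to the center get a larger increase, so they are less likely to be picked again.
  8. If step 3 could not find enough points, randomly duplicates selected points until the required number is reached.
  9. Outputs XYZ, RGB, labels, the indices of the selected points, and the scene index.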

    def get_batch_gen(self, split):
        if split == 'training':
            num_per_epoch = cfg.train_steps * cfg.batch_size
        elif split == 'validation':
            num_per_epoch = cfg.val_steps * cfg.val_batch_size

        self.possibility[split] = []
        self.min_possibility[split] = []

        for i, tree in enumerate(self.input_colors[split]):
            self.possibility[split] += [np.random.rand(tree.data.shape[0]) * 1e-3]
            self.min_possibility[split] += [float(np.min(self.possibility[split][-1]))]

        def spatially_regular_gen():

            for i in range(num_per_epoch):

                cloud_idx = int(np.argmin(self.min_possibility[split]))

                point_ind = np.argmin(self.possibility[split][cloud_idx])

                points = np.array(self.input_trees[split][cloud_idx].data, copy=False)

                center_point = points[point_ind, :].reshape(1, -1)

                noise = np.random.normal(scale=cfg.noise_init / 10, size=center_point.shape)
                pick_point = center_point + noise.astype(center_point.dtype)

                if len(points) < cfg.num_points:

                    queried_idx = self.input_trees[split][cloud_idx].query(pick_point, k=len(points))[1][0]
                else:

                    queried_idx = self.input_trees[split][cloud_idx].query(pick_point, k=cfg.num_points)[1][0]

                queried_idx = DP.shuffle_idx(queried_idx)

                queried_pc_xyz = points[queried_idx]
                queried_pc_xyz = queried_pc_xyz - pick_point
                queried_pc_colors = self.input_colors[split][cloud_idx][queried_idx]
                queried_pc_labels = self.input_labels[split][cloud_idx][queried_idx]

                dists = np.sum(np.square((points[queried_idx] - pick_point).astype(np.float32)), axis=1)
                delta = np.square(1 - dists / np.max(dists))
                self.possibility[split][cloud_idx][queried_idx] += delta
                self.min_possibility[split][cloud_idx] = float(np.min(self.possibility[split][cloud_idx]))

                if len(points) < cfg.num_points:
                    queried_pc_xyz, queried_pc_colors, queried_idx, queried_pc_labels = \
                        DP.data_aug(queried_pc_xyz, queried_pc_colors, queried_pc_labels, queried_idx, cfg.num_points)

                if True:
                    yield (queried_pc_xyz.astype(np.float32),
                           queried_pc_colors.astype(np.float32),
                           queried_pc_labels,
                           queried_idx.astype(np.int32),
                           np.array([cloud_idx], dtype=np.int32))

        gen_func = spatially_regular_gen
        gen_types = (tf.float32, tf.float32, tf.int32, tf.int32, tf.int32)
        gen_shapes = ([None, 3], [None, 3], [None], [None], [None])
        return gen_func, gen_types, gen_shapes
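The steps listed above can be condensed into a NumPy sketch for a single cloud. sample_patch is a hypothetical helper, not part of the repository. It drops the choice across scenes (np.argmin over min_possibility). Its padding branch is a simple stand-in for DP.data_aug, assuming that function repeats randomly chosen selected points.

```python
import numpy as np
from sklearn.neighbors import KDTree


def sample_patch(points, tree, possibility, num_points, noise_std, rng):
    """Simplified, single-cloud version of one spatially_regular_gen step.
    `possibility` is updated in place."""
    # Steps 1-2: the least-covered point becomes the center, then it is jittered
    center = points[np.argmin(possibility)].reshape(1, -1)
    pick = center + rng.normal(scale=noise_std, size=center.shape)
    # Step 3: k nearest neighbours of the jittered center (whole cloud if it is small)
    k = min(len(points), num_points)
    idx = tree.query(pick, k=k)[1][0]
    # Step 4: shuffle, so that keeping the first N/r points later is a random subset
    idx = rng.permutation(idx)
    # Step 6: center the patch on the jittered point
    xyz = points[idx] - pick
    # Step 7: raise the possibility, most strongly near the center
    dists = np.sum(np.square(points[idx] - pick), axis=1)
    possibility[idx] += np.square(1 - dists / np.max(dists))
    # Step 8: pad small clouds by repeating randomly chosen selected points
    if k < num_points:
        dup = rng.choice(k, num_points - k)
        xyz = np.concatenate([xyz, xyz[dup]], axis=0)
        idx = np.concatenate([idx, idx[dup]], axis=0)
    return xyz.astype(np.float32), idx.astype(np.int32)


rng = np.random.default_rng(0)
points = rng.random((500, 3))
tree = KDTree(points)
possibility = rng.random(500) * 1e-3   # tiny random initial values, as in get_batch_gen
p0 = possibility.copy()
xyz, idx = sample_patch(points, tree, possibility, num_points=64, noise_std=0.05, rng=rng)
```

Calling sample_patch repeatedly moves the center to whichever region has been covered least so far. This is why the sampling is "spatially regular" rather than uniform at random.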

MAP (data preprocessing)

The data produced by the generator go through a series of preprocessing steps:

  1. The points produced by the generator are called the original points.
  2. For every original point, get the indices of its K nearest original points (hyperparameter k_n = 16). Call this original-knn16-original (neighbour_idx).
  3. Randomly downsample the original points, keeping 1/sub_sampling_ratio of them, to obtain the downsampled points. The generator has already shuffled the points, so simply keeping the first N // sub_sampling_ratio[i] points is a random sample.
  4. Apply the same downsampling as in step 3 to original-knn16-original. This gives downsampled-knn16-original (pool_i): the 16 original-point neighbours of each downsampled point.
  5. For every original point, find the index of its nearest (K = 1) downsampled point. Call this original-knn1-downsampled (up_i).
  6. Set original points = downsampled points and go back to step 2. Repeat until the prescribed number of downsampling rounds (hyperparameter num_layers = 5) is done.

Each round records four things:
  • the original points;
  • original-knn16-original;
  • downsampled-knn16-original;
  • original-knn1-downsampled.

The function returns these records together with the generator's outputs: the features (XYZ concatenated with RGB), the labels, the indices of the selected points, and the scene index.

    @staticmethod
    def get_tf_mapping2():

        def tf_map(batch_xyz, batch_features, batch_labels, batch_pc_idx, batch_cloud_idx):
            batch_features = tf.concat([batch_xyz, batch_features], axis=-1)
            input_points = []
            input_neighbors = []
            input_pools = []
            input_up_samples = []

            for i in range(cfg.num_layers):
                neighbour_idx = tf.py_func(DP.knn_search, [batch_xyz, batch_xyz, cfg.k_n], tf.int32)
                sub_points = batch_xyz[:, :tf.shape(batch_xyz)[1] // cfg.sub_sampling_ratio[i], :]
                pool_i = neighbour_idx[:, :tf.shape(batch_xyz)[1] // cfg.sub_sampling_ratio[i], :]
                up_i = tf.py_func(DP.knn_search, [sub_points, batch_xyz, 1], tf.int32)
                input_points.append(batch_xyz)
                input_neighbors.append(neighbour_idx)
                input_pools.append(pool_i)
                input_up_samples.append(up_i)
                batch_xyz = sub_points

            input_list = input_points + input_neighbors + input_pools + input_up_samples
            input_list += [batch_features, batch_labels, batch_pc_idx, batch_cloud_idx]

            return input_list

        return tf_map
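The same pyramid can be reproduced for one cloud without TensorFlow. In this sketch, sklearn's KDTree replaces the compiled nearest-neighbour module behind DP.knn_search, and the batch dimension is dropped. sub_sampling_ratio = (4, 4, 4, 4, 2) is an assumed value for the S3DIS config. The knn_search argument order (support points first, query points second) follows the calls in tf_map.

```python
import numpy as np
from sklearn.neighbors import KDTree


def knn_search(support_pts, query_pts, k):
    """For every query point, indices of its k nearest support points (sklearn stand-in)."""
    return KDTree(support_pts).query(query_pts, k=k, return_distance=False).astype(np.int32)


def build_pyramid(xyz, k_n=16, sub_sampling_ratio=(4, 4, 4, 4, 2)):
    """Single-cloud NumPy version of the loop in tf_map (no batch dimension)."""
    input_points, input_neighbors, input_pools, input_up_samples = [], [], [], []
    for ratio in sub_sampling_ratio:
        neighbour_idx = knn_search(xyz, xyz, k_n)   # original-knn16-original
        n_sub = xyz.shape[0] // ratio
        sub_points = xyz[:n_sub]                     # random subset: xyz is already shuffled
        pool_i = neighbour_idx[:n_sub]               # downsampled-knn16-original
        up_i = knn_search(sub_points, xyz, 1)        # original-knn1-downsampled
        input_points.append(xyz)
        input_neighbors.append(neighbour_idx)
        input_pools.append(pool_i)
        input_up_samples.append(up_i)
        xyz = sub_points
    return input_points, input_neighbors, input_pools, input_up_samples


rng = np.random.default_rng(0)
pts, neigh, pools, ups = build_pyramid(rng.random((4096, 3)))
print([p.shape[0] for p in pts])   # [4096, 1024, 256, 64, 16]
```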

RandLANet.py

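This file defines the network structure and is called by main_S3DIS.py.
First, note that model snapshots are saved under the results folder, in results/Log_<timestamp> unless saving_path is set.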

flat_inputs = dataset.flat_inputs
self.config = config

if self.config.saving:
    if self.config.saving_path is None:
        self.saving_path = time.strftime('results/Log_%Y-%m-%d_%H-%M-%S', time.gmtime())
    else:
        self.saving_path = self.config.saving_path
    makedirs(self.saving_path) if not exists(self.saving_path) else None

Next, the network inputs and various parameters are set up.


with tf.variable_scope('inputs'):
    self.inputs = dict()
    num_layers = self.config.num_layers
    self.inputs['xyz'] = flat_inputs[:num_layers]
    self.inputs['neigh_idx'] = flat_inputs[num_layers: 2 * num_layers]
    self.inputs['sub_idx'] = flat_inputs[2 * num_layers:3 * num_layers]
    self.inputs['interp_idx'] = flat_inputs[3 * num_layers:4 * num_layers]
    self.inputs['features'] = flat_inputs[4 * num_layers]
    self.inputs['labels'] = flat_inputs[4 * num_layers + 1]
    self.inputs['input_inds'] = flat_inputs[4 * num_layers + 2]
    self.inputs['cloud_inds'] = flat_inputs[4 * num_layers + 3]
    self.labels = self.inputs['labels']
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.training_step = 1
    self.training_epoch = 0
    self.correct_prediction = 0
    self.accuracy = 0
    self.mIou_list = [0]
    self.class_weights = DP.get_class_weights(dataset.name)
    self.Log_file = open('log_train_' + dataset.name + str(dataset.val_split) + '.txt', 'a')
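tf_map returns one flat list, and the slicing above recovers its parts by position:
  • input_points become xyz;
  • input_neighbors become neigh_idx;
  • input_pools become sub_idx;
  • input_up_samples become interp_idx.

A plain-Python sketch with strings in place of tensors makes the layout explicit:

```python
num_layers = 5

# Stand-in for the list returned by tf_map (strings instead of tensors).
input_list = ([f'input_points_{i}' for i in range(num_layers)] +
              [f'input_neighbors_{i}' for i in range(num_layers)] +
              [f'input_pools_{i}' for i in range(num_layers)] +
              [f'input_up_samples_{i}' for i in range(num_layers)] +
              ['features', 'labels', 'pc_idx', 'cloud_idx'])

# Same slicing as in RandLANet.py
inputs = {
    'xyz': input_list[:num_layers],
    'neigh_idx': input_list[num_layers:2 * num_layers],
    'sub_idx': input_list[2 * num_layers:3 * num_layers],
    'interp_idx': input_list[3 * num_layers:4 * num_layers],
    'features': input_list[4 * num_layers],
    'labels': input_list[4 * num_layers + 1],
    'input_inds': input_list[4 * num_layers + 2],
    'cloud_inds': input_list[4 * num_layers + 3],
}
print(inputs['sub_idx'][0], inputs['interp_idx'][-1])  # input_pools_0 input_up_samples_4
```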

Core network structure

with tf.variable_scope('layers'):
    self.logits = self.inference(self.inputs, self.is_training)

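The core of the network is the inference() function.
It first applies a fully connected layer (fc0, followed by batch normalization and leaky ReLU) that maps the features to 8 channels. It then expands them to B*N*1*8.
After that, it stacks num_layers (hyperparameter, num_layers = 5) dilated_res_block() layers, each followed by random_sample().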

Encoder

def inference(self, inputs, is_training):
    d_out = self.config.d_out
    feature = inputs['features']
    feature = tf.layers.dense(feature, 8, activation=None, name='fc0')
    feature = tf.nn.leaky_relu(tf.layers.batch_normalization(feature, -1, 0.99, 1e-6, training=is_training))
    feature = tf.expand_dims(feature, axis=2)

    f_encoder_list = []
    for i in range(self.config.num_layers):
        f_encoder_i = self.dilated_res_block(feature, inputs['xyz'][i], inputs['neigh_idx'][i], d_out[i],
                                             'Encoder_layer_' + str(i), is_training)
        f_sampled_i = self.random_sample(f_encoder_i, inputs['sub_idx'][i])
        feature = f_sampled_i
        if i == 0:
            f_encoder_list.append(f_encoder_i)
        f_encoder_list.append(f_sampled_i)

    feature = helper_tf_util.conv2d(f_encoder_list[-1], f_encoder_list[-1].get_shape()[3].value, [1, 1],
                                    'decoder_0',
                                    [1, 1], 'VALID', True, is_training)

    ...
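To see what f_encoder_list holds, the sketch below traces its shapes. It assumes num_points = 40960, d_out = [16, 64, 128, 256, 512] and sub_sampling_ratio = [4, 4, 4, 4, 2]. I believe these are the S3DIS config values, but treat them as assumptions. dilated_res_block outputs d_out * 2 channels. The first layer contributes both its unsampled and its sampled output, so the list has num_layers + 1 entries.

```python
def encoder_shapes(num_points=40960, d_out=(16, 64, 128, 256, 512), ratios=(4, 4, 4, 4, 2)):
    """(points, channels) of every entry appended to f_encoder_list."""
    shapes = []
    n = num_points
    for i, (d, r) in enumerate(zip(d_out, ratios)):
        channels = 2 * d                  # dilated_res_block returns d_out * 2 channels
        if i == 0:
            shapes.append((n, channels))  # f_encoder_0, before random_sample
        n //= r                           # random_sample keeps N // sub_sampling_ratio[i] points
        shapes.append((n, channels))      # f_sampled_i
    return shapes


print(encoder_shapes())
# [(40960, 32), (10240, 32), (2560, 128), (640, 256), (160, 512), (80, 1024)]
```

The decoder reads f_encoder_list[-j - 2] for j = 0..4, i.e. every entry except the last one. The last entry only feeds decoder_0.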

dilated_res_block

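[Figure: dilated_res_block flowchart from the RandLA-Net paper]
This is the dilated_res_block flowchart given in the authors' paper. It omits some details. Read together with the building_block flowchart below, it makes the network structure much clearer. The default number of neighbours is k_n = 16.
[Figure: building_block flowchart]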
    def dilated_res_block(self, feature, xyz, neigh_idx, d_out, name, is_training):
        f_pc = helper_tf_util.conv2d(feature, d_out // 2, [1, 1], name + 'mlp1', [1, 1], 'VALID', True, is_training)
        f_pc = self.building_block(xyz, f_pc, neigh_idx, d_out, name + 'LFA', is_training)
        f_pc = helper_tf_util.conv2d(f_pc, d_out * 2, [1, 1], name + 'mlp2', [1, 1], 'VALID', True, is_training,
                                     activation_fn=None)
        shortcut = helper_tf_util.conv2d(feature, d_out * 2, [1, 1], name + 'shortcut', [1, 1], 'VALID',
                                         activation_fn=None, bn=True, is_training=is_training)
        return tf.nn.leaky_relu(f_pc + shortcut)

    def building_block(self, xyz, feature, neigh_idx, d_out, name, is_training):
        d_in = feature.get_shape()[-1].value
        f_xyz = self.relative_pos_encoding(xyz, neigh_idx)
        f_xyz = helper_tf_util.conv2d(f_xyz, d_in, [1, 1], name + 'mlp1', [1, 1], 'VALID', True, is_training)
        f_neighbours = self.gather_neighbour(tf.squeeze(feature, axis=2), neigh_idx)
        f_concat = tf.concat([f_neighbours, f_xyz], axis=-1)
        f_pc_agg = self.att_pooling(f_concat, d_out // 2, name + 'att_pooling_1', is_training)

        f_xyz = helper_tf_util.conv2d(f_xyz, d_out // 2, [1, 1], name + 'mlp2', [1, 1], 'VALID', True, is_training)
        f_neighbours = self.gather_neighbour(tf.squeeze(f_pc_agg, axis=2), neigh_idx)
        f_concat = tf.concat([f_neighbours, f_xyz], axis=-1)
        f_pc_agg = self.att_pooling(f_concat, d_out, name + 'att_pooling_2', is_training)
        return f_pc_agg

    def relative_pos_encoding(self, xyz, neigh_idx):
        neighbor_xyz = self.gather_neighbour(xyz, neigh_idx)
        xyz_tile = tf.tile(tf.expand_dims(xyz, axis=2), [1, 1, tf.shape(neigh_idx)[-1], 1])
        relative_xyz = xyz_tile - neighbor_xyz
        relative_dis = tf.sqrt(tf.reduce_sum(tf.square(relative_xyz), axis=-1, keepdims=True))
        relative_feature = tf.concat([relative_dis, relative_xyz, xyz_tile, neighbor_xyz], axis=-1)
        return relative_feature

    @staticmethod
    def gather_neighbour(pc, neighbor_idx):

        batch_size = tf.shape(pc)[0]
        num_points = tf.shape(pc)[1]
        d = pc.get_shape()[2].value
        index_input = tf.reshape(neighbor_idx, shape=[batch_size, -1])
        features = tf.batch_gather(pc, index_input)
        features = tf.reshape(features, [batch_size, num_points, tf.shape(neighbor_idx)[-1], d])
        return features
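Below is a NumPy sketch of relative_pos_encoding and gather_neighbour, useful for checking shapes. For every point and each of its k neighbours, it builds 10 channels:
  • the distance;
  • the relative offset;
  • the point's own XYZ;
  • the neighbour's XYZ.

Advanced indexing replaces tf.batch_gather. The random neigh_idx is only for illustration.

```python
import numpy as np


def gather_neighbour(pc, neighbor_idx):
    """pc: [B, N, d], neighbor_idx: [B, N, k] -> [B, N, k, d] (same as batch_gather + reshape)."""
    batch = np.arange(pc.shape[0])[:, None, None]
    return pc[batch, neighbor_idx]


def relative_pos_encoding(xyz, neigh_idx):
    neighbor_xyz = gather_neighbour(xyz, neigh_idx)                          # [B, N, k, 3]
    xyz_tile = np.repeat(xyz[:, :, None, :], neigh_idx.shape[-1], axis=2)   # [B, N, k, 3]
    relative_xyz = xyz_tile - neighbor_xyz
    relative_dis = np.sqrt(np.sum(np.square(relative_xyz), axis=-1, keepdims=True))
    # 1 + 3 + 3 + 3 = 10 channels per neighbour
    return np.concatenate([relative_dis, relative_xyz, xyz_tile, neighbor_xyz], axis=-1)


rng = np.random.default_rng(0)
xyz = rng.random((2, 50, 3))
neigh_idx = rng.integers(0, 50, size=(2, 50, 16))
f_xyz = relative_pos_encoding(xyz, neigh_idx)
print(f_xyz.shape)   # (2, 50, 16, 10)
```

The 10-channel f_xyz is what building_block feeds into its first 1x1 convolution (mlp1).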

random_sample

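random_sample is fairly simple. The random downsampling for every layer was already done during data preprocessing, so only indexing remains. For each retained point, it gathers the features of its k_n neighbours (pool_idx, i.e. downsampled-knn16-original) and takes the maximum over them.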

    @staticmethod
    def random_sample(feature, pool_idx):
"""
        :param feature: [B, N, d] input features matrix
        :param pool_idx: [B, N', max_num] N' < N, N' is the selected position after pooling
        :return: pool_features = [B, N', d] pooled features matrix
"""
        feature = tf.squeeze(feature, axis=2)
        num_neigh = tf.shape(pool_idx)[-1]
        d = feature.get_shape()[-1]
        batch_size = tf.shape(pool_idx)[0]
        pool_idx = tf.reshape(pool_idx, [batch_size, -1])
        pool_features = tf.batch_gather(feature, pool_idx)
        pool_features = tf.reshape(pool_features, [batch_size, -1, num_neigh, d])
        pool_features = tf.reduce_max(pool_features, axis=2, keepdims=True)
        return pool_features
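random_sample is a gather followed by a max over the neighbour axis. Here is a NumPy sketch with a tiny input that can be checked by hand. The shapes follow the code: feature is [B, N, 1, d] and pool_idx is [B, N', k].

```python
import numpy as np


def random_sample(feature, pool_idx):
    """feature: [B, N, 1, d]; pool_idx: [B, N', k] -> [B, N', 1, d]."""
    feature = feature[:, :, 0, :]                       # squeeze axis 2 -> [B, N, d]
    batch = np.arange(feature.shape[0])[:, None, None]
    pool_features = feature[batch, pool_idx]            # [B, N', k, d]
    return pool_features.max(axis=2, keepdims=True)     # max over the k neighbours


feature = np.arange(12, dtype=np.float32).reshape(1, 6, 1, 2)
pool_idx = np.array([[[0, 1, 2], [3, 4, 5]]])           # 2 kept points, 3 neighbours each
pooled = random_sample(feature, pool_idx)
# row 0 is the max over points 0-2, row 1 the max over points 3-5
print(pooled[0, :, 0, :])
```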

Decoder

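The decoder first reverses the encoder, layer by layer:
  • it upsamples the features to the denser point set by nearest-neighbour interpolation (using interp_idx);
  • it concatenates them with the matching encoder features;
  • it passes the result through a 1x1 transposed convolution.

Several more convolutions follow, including one dropout. They produce the final B*N*num_classes per-point classification matrix.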


    f_decoder_list = []
    for j in range(self.config.num_layers):
        f_interp_i = self.nearest_interpolation(feature, inputs['interp_idx'][-j - 1])
        f_decoder_i = helper_tf_util.conv2d_transpose(tf.concat([f_encoder_list[-j - 2], f_interp_i], axis=3),
                                                      f_encoder_list[-j - 2].get_shape()[-1].value, [1, 1],
                                                      'Decoder_layer_' + str(j), [1, 1], 'VALID', bn=True,
                                                      is_training=is_training)
        feature = f_decoder_i
        f_decoder_list.append(f_decoder_i)

    f_layer_fc1 = helper_tf_util.conv2d(f_decoder_list[-1], 64, [1, 1], 'fc1', [1, 1], 'VALID', True, is_training)
    f_layer_fc2 = helper_tf_util.conv2d(f_layer_fc1, 32, [1, 1], 'fc2', [1, 1], 'VALID', True, is_training)
    f_layer_drop = helper_tf_util.dropout(f_layer_fc2, keep_prob=0.5, is_training=is_training, scope='dp1')
    f_layer_fc3 = helper_tf_util.conv2d(f_layer_drop, self.config.num_classes, [1, 1], 'fc', [1, 1], 'VALID', False,
                                        is_training, activation_fn=None)
    f_out = tf.squeeze(f_layer_fc3, [2])
    return f_out
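The decoder's nearest_interpolation is not shown in this article. interp_idx holds, for every denser point, the index of its nearest coarser point, so its effect should amount to a gather. The sketch below is a NumPy approximation under that assumption. It is followed by the same skip concatenation used in the decoder loop.

```python
import numpy as np


def nearest_interpolation(feature, interp_idx):
    """feature: [B, N', 1, d] coarse features; interp_idx: [B, N, 1] -> [B, N, 1, d]."""
    feature = feature[:, :, 0, :]                        # squeeze axis 2 -> [B, N', d]
    batch = np.arange(feature.shape[0])[:, None]
    up = feature[batch, interp_idx[:, :, 0]]             # copy the nearest coarse feature
    return up[:, :, None, :]


coarse = np.array([[[[1.0, 1.0]], [[2.0, 2.0]]]])        # B=1, N'=2, d=2
interp_idx = np.array([[[0], [1], [1], [0]]])            # nearest coarse point of 4 dense points
f_interp = nearest_interpolation(coarse, interp_idx)
skip = np.zeros((1, 4, 1, 3))                            # stand-in for f_encoder_list[-j - 2]
f_cat = np.concatenate([skip, f_interp], axis=3)         # input to the 1x1 transposed conv
print(f_cat.shape)                                       # (1, 4, 1, 5)
```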

loss

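If ignored classes are specified, points carrying those labels are removed before the loss is computed. The remaining labels are remapped to contiguous indices.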


with tf.variable_scope('loss'):
    self.logits = tf.reshape(self.logits, [-1, config.num_classes])
    self.labels = tf.reshape(self.labels, [-1])

    ignored_bool = tf.zeros_like(self.labels, dtype=tf.bool)
    for ign_label in self.config.ignored_label_inds:
        ignored_bool = tf.logical_or(ignored_bool, tf.equal(self.labels, ign_label))

    valid_idx = tf.squeeze(tf.where(tf.logical_not(ignored_bool)))
    valid_logits = tf.gather(self.logits, valid_idx, axis=0)
    valid_labels_init = tf.gather(self.labels, valid_idx, axis=0)

    reducing_list = tf.range(self.config.num_classes, dtype=tf.int32)
    inserted_value = tf.zeros((1,), dtype=tf.int32)
    for ign_label in self.config.ignored_label_inds:
        reducing_list = tf.concat([reducing_list[:ign_label], inserted_value, reducing_list[ign_label:]], 0)
    valid_labels = tf.gather(reducing_list, valid_labels_init)
    self.loss = self.get_loss(valid_logits, valid_labels, self.class_weights)
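The filtering and label remapping can be replayed in NumPy. The sketch mirrors the TensorFlow code above; filter_and_remap is a hypothetical name. With label 2 ignored, every label above 2 shifts down by one. The labels can then index a logit vector that has no column for the ignored class, which is my reading of why reducing_list exists.

```python
import numpy as np


def filter_and_remap(logits, labels, ignored_label_inds, num_classes):
    """NumPy mirror of the ignored-label handling in the 'loss' scope."""
    ignored = np.zeros_like(labels, dtype=bool)
    for ign in ignored_label_inds:
        ignored |= labels == ign
    valid_idx = np.flatnonzero(~ignored)
    valid_logits = logits[valid_idx]
    valid_labels_init = labels[valid_idx]

    # Insert a dummy 0 at each ignored position, so labels above it shift down by one.
    reducing_list = np.arange(num_classes, dtype=np.int32)
    for ign in ignored_label_inds:
        reducing_list = np.concatenate([reducing_list[:ign], [0], reducing_list[ign:]]).astype(np.int32)
    return valid_logits, reducing_list[valid_labels_init]


logits = np.arange(24, dtype=np.float64).reshape(6, 4)   # 4 output classes after ignoring one
labels = np.array([0, 1, 2, 3, 4, 2])                    # raw labels 0..4, label 2 ignored
valid_logits, valid_labels = filter_and_remap(logits, labels, [2], 4)
print(valid_labels.tolist())                             # [0, 1, 2, 3]
```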

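pre_cal_weights is derived from the precomputed number of points per class (via DP.get_class_weights) and is roughly inversely proportional to each class's frequency. Each point's loss is multiplied by the weight of its class. When one class makes up most of the point cloud, this keeps the network from settling into the local optimum of predicting only that class.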

    def get_loss(self, logits, labels, pre_cal_weights):

        class_weights = tf.convert_to_tensor(pre_cal_weights, dtype=tf.float32)
        one_hot_labels = tf.one_hot(labels, depth=self.config.num_classes)
        weights = tf.reduce_sum(class_weights * one_hot_labels, axis=1)
        unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=one_hot_labels)
        weighted_losses = unweighted_losses * weights
        output_loss = tf.reduce_mean(weighted_losses)
        return output_loss
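Below is a NumPy version of get_loss, together with a guess at how counts become weights. class_weights_from_counts assumes inverse-frequency weights of the form 1 / (freq + 0.02), which is how I remember DP.get_class_weights. Treat that formula and the 0.02 constant as assumptions.

```python
import numpy as np


def class_weights_from_counts(num_per_class, eps=0.02):
    """Assumed inverse-frequency weighting: rarer classes get larger weights."""
    freq = num_per_class / float(num_per_class.sum())
    return 1.0 / (freq + eps)


def get_loss(logits, labels, class_weights):
    """Weighted softmax cross-entropy averaged over points (NumPy mirror of get_loss)."""
    shifted = logits - logits.max(axis=1, keepdims=True)          # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    unweighted = -log_probs[np.arange(len(labels)), labels]        # per-point cross-entropy
    weights = class_weights[labels]                                # same as sum(w * one_hot)
    return float(np.mean(unweighted * weights))


w = class_weights_from_counts(np.array([900.0, 100.0]))
print(w)   # the rare class gets the larger weight
```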

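Optimizer configuration

Finally, the code does four things:
  • configures the Adam optimizer;
  • configures the accuracy metric;
  • configures the TensorBoard summaries;
  • creates the session.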

        with tf.variable_scope('optimizer'):
            self.learning_rate = tf.Variable(config.learning_rate, trainable=False, name='learning_rate')
            self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.variable_scope('results'):
            self.correct_prediction = tf.nn.in_top_k(valid_logits, valid_labels, 1)
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
            self.prob_logits = tf.nn.softmax(self.logits)

            tf.summary.scalar('learning_rate', self.learning_rate)
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)

        my_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.saver = tf.train.Saver(my_vars, max_to_keep=100)
        c_proto = tf.ConfigProto()
        c_proto.gpu_options.allow_growth = True
        self.sess = tf.Session(config=c_proto)
        self.merged = tf.summary.merge_all()
        self.train_writer = tf.summary.FileWriter(config.train_sum_dir, self.sess.graph)
        self.sess.run(tf.global_variables_initializer())

train

    def train(self, dataset):
        log_out('****EPOCH {}****'.format(self.training_epoch), self.Log_file)
        self.sess.run(dataset.train_init_op)
        while self.training_epoch < self.config.max_epoch:
            t_start = time.time()
            try:
                ops = [self.train_op,
                       self.extra_update_ops,
                       self.merged,
                       self.loss,
                       self.logits,
                       self.labels,
                       self.accuracy]
                _, _, summary, l_out, probs, labels, acc = self.sess.run(ops, {self.is_training: True})
                self.train_writer.add_summary(summary, self.training_step)
                t_end = time.time()
                if self.training_step % 50 == 0:
                    message = 'Step {:08d} L_out={:5.3f} Acc={:4.2f} ''---{:8.2f} ms/batch'
                    log_out(message.format(self.training_step, l_out, acc, 1000 * (t_end - t_start)), self.Log_file)
                self.training_step += 1

            except tf.errors.OutOfRangeError:

                m_iou = self.evaluate(dataset)
                if m_iou > np.max(self.mIou_list):

                    snapshot_directory = join(self.saving_path, 'snapshots')
                    makedirs(snapshot_directory) if not exists(snapshot_directory) else None
                    self.saver.save(self.sess, join(snapshot_directory, 'snap'), global_step=self.training_step)
                self.mIou_list.append(m_iou)
                log_out('Best m_IoU is: {:5.3f}'.format(max(self.mIou_list)), self.Log_file)

                self.training_epoch += 1
                self.sess.run(dataset.train_init_op)

                op = self.learning_rate.assign(tf.multiply(self.learning_rate,
                                                           self.config.lr_decays[self.training_epoch]))
                self.sess.run(op)
                log_out('****EPOCH {}****'.format(self.training_epoch), self.Log_file)

            except tf.errors.InvalidArgumentError as e:

                print('Caught a NaN error :')
                print(e.error_code)
                print(e.message)
                print(e.op)
                print(e.op.name)
                print([t.name for t in e.op.inputs])
                print([t.name for t in e.op.outputs])

                raise

        print('finished')
        self.sess.close()
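After each epoch, train() does two pieces of bookkeeping:
  • it saves a snapshot only if the new mIoU beats every previous one;
  • it multiplies the learning rate by lr_decays[epoch].

The sketch replays that logic on a made-up mIoU sequence. lr0 = 1e-2 and a constant decay of 0.95 are illustrative assumptions; simulate_epochs is a hypothetical helper.

```python
def simulate_epochs(m_ious, lr0=1e-2, decay=0.95):
    """Return the epochs at which a snapshot would be saved and the lr after each epoch."""
    miou_list = [0.0]            # mirrors self.mIou_list = [0]
    saved, lrs = [], []
    lr = lr0
    for epoch, m_iou in enumerate(m_ious):
        if m_iou > max(miou_list):
            saved.append(epoch)  # new best mIoU -> saver.save(...)
        miou_list.append(m_iou)
        lr *= decay              # learning_rate.assign(learning_rate * lr_decays[epoch])
        lrs.append(lr)
    return saved, lrs


saved, lrs = simulate_epochs([0.30, 0.50, 0.45, 0.62])
print(saved)                     # [0, 1, 3]
```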

Original: https://blog.csdn.net/DLW__/article/details/114281940
Author: 3D_DLW
Title: RandLA-Net源码解析 (RandLA-Net Source Code Walkthrough)
