NeRF 源码分析解读(一)

NeRF 源码解读(一)

前言

NeRF 是三维视觉中新视图合成任务的启示性工作,最近领域内出现了许多基于 NeRF 的变种工作。本文以pytorch 版 NeRF 作为基础对 NeRF 的代码进行分析。
主要从以下方面开展:

  1. 数据的加载
  2. 光线的生成
  3. NeRF 网络架构
  4. 渲染过程

一、数据的加载

本文以加载合成数据集中 lego 图像为例。
首先我们观察 ./data/nerf_synthetic/lego 文件夹下的树结构:

NeRF 源码分析解读(一)
train、test、val 三个文件夹下包含了训练要用到的 .png 图像,每个文件夹下包含 100 个文件。.json 文件包含了相机的 camera2word 转置矩阵,下图展示了部分文件中的内容。关于此转置矩阵不再展开叙述,具体知识可查看 SLAM 14 讲。了解以上基本信息后解析数据加载的代码。
NeRF 源码分析解读(一)
frame 的值是一个列表,其中列表中的值是字典
def train():

    parser = config_parser()
    args = parser.parse_args()

    ...

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:

            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]

我们通过 load_blender_data() 函数得到了指定文件夹下的所有图像、pose、测试渲染的pose、宽高焦距以及分割数组。下面对数据加载函数进行分析。

def load_blender_data(basedir, half_res=False, testskip=1):
"""
    :param basedir: 数据文件夹路径
    :param half_res: 是否对图像进行半裁剪
    :param testskip: 挑选测试数据集的跳跃步长
    :return:
"""
    splits = ['train', 'val', 'test']
    metas = {}
    for s in splits:

        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]

    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []

        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip

        for frame in meta['frames'][::skip]:

            fname = os.path.join(basedir, frame['file_path'] + '.png')

            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))
        imgs = (np.array(imgs) / 255.).astype(np.float32)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])

        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
        for i, img in enumerate(imgs):
            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return imgs, poses, render_poses, [H, W, focal], i_split

通过对以上代码的分析,我们可以得到以下结果:

imgs : 根据 .json 文件加载到的所有图像数据。(N,H,W,4)N 代表用于 train、test、val 的总数量
poses : 转置矩阵。(N,4,4)
render_poses : 用于测试的 pose 。(40,4,4)
i_split : [[0:train], [train:val], [val:test]]

完成数据加载以后,就可以根据 image 数据模拟生成光线。具体代码解析见下一节:
NeRF源码分析解读(二)

Original: https://blog.csdn.net/qq_41071191/article/details/125440451
Author: 面里多加汤
Title: NeRF 源码分析解读(一)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/624294/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球