4、nerf(pytorch)

简介

4、nerf(pytorch)
2020年ECCV最佳论文Neural Radiance Field (NeRF) 将三维场景的隐表示方法推向了新的高度,其本质在于利用神经网络通过多视角2D图像进行3D场景重建,并进行渲染合成新视角的2D图像

视角合成方法通常使用一个中间3D场景表征作为中介来生成高质量的虚拟视角,如何对这个中间3D场景进行表征,分为了”显示表示”和”隐式表示”,然后再对这个中间3D场景进行渲染,生成照片级的视角。

“显示表示”3D场景包括Mesh,Point Cloud,Voxel,Volume等,它能够对场景进行显式建模,但是因为其是离散表示的,导致了不够精细化会造成重叠等伪影,更重要的是,它存储的三维场景表达信息数据量极大,对内存的消耗限制了高分辨率场景的应用。

“隐式表示”3D场景通常用一个函数来描述场景几何,可以理解为将复杂的三维场景表达信息存储在函数的参数中。因为往往是学习一种3D场景的描述函数,因此在表达大分辨率场景的时候它的参数量相对于”显示表示”是较少的,并且”隐式表示”函数是种连续化的表达,对于场景的表达会更为精细。

NeRF做到了利用”隐式表示”实现了照片级的视角合成效果,它选择了Volume作为中间3D场景表征,然后再通过Volume rendering实现了特定视角照片合成效果。可以说NeRF实现了从离散的照片集中学习出了一种隐式的Volume表达,然后在某个特定视角,利用该隐式Volume表达和体渲染得到该视角下的照片。

4、nerf(pytorch)

; 整体流程

数据准备

首先需要准备多张从不同角度拍摄的同场景静态2D图像

还需要各个像素点的位置(x,y,z)和方位视角(θ,φ)作为输入进行重建。为了获得这些数据,NeRF中采用了传统方法COLMAP进行参数估计。通过COLMAP可以得到场景的稀疏重建结果,其输出文件包括相机内参,相机外参和3D点的信息。

然后进一步利用LLFF开源代码中的imgs2poses文件将内外参整合到一个文件poses_boudns.npy中,该文件记录了相机的内参,包括图片分辨率(图片高与宽度)、焦距,共3个维度、外参(包括相机坐标到世界坐标转换的平移矩阵t与旋转矩阵r,其中旋转矩阵为 3×3 的矩阵,共9个维度,平移矩阵为 3×1 的矩阵,3个维度,因此该文件中的数据维度为 N x 17 另有两个维度为光线的始发深度与终止深度,通过COLMAP输出的3D点位置计算得到),其中N为图片样本数。

除了llff数据格式,还有blender、LINEMOD和deepvoxels

数据处理

经过上述数据准备,我们得到图片、相机内外参、旋转矩阵和平移矩阵。

这里使用到计算机图形学的图形变换,将上述参数与2d的图像信息结合,将2d坐标转换为3d坐标

坐标变换

4、nerf(pytorch)
O u v O_{uv}O u v ​为像素坐标系,O u o v o O_{u_ov_o}O u o ​v o ​​为物理坐标系,首先要将物理坐标系移到像素坐标系(u o , v o u_o,v_o u o ​,v o ​分别为物理坐标系在像素坐标系下的坐标值,我们一般可以用像素图像宽高一半表示其值,即可u o = w / 2 , v o = h / 2 u_o=w/2,v_o=h/2 u o ​=w /2 ,v o ​=h /2)
4、nerf(pytorch)
下图进行相机坐标系与图像物理坐标系转换 原理介绍
O c Y c X C O_{cY_cX_C}O c Y c ​X C ​​为相机坐标系,o x y o_{xy}o x y ​为图像物理坐标系,也可以理解为相机的观景窗口
4、nerf(pytorch)
4、nerf(pytorch)
在NeRF中,进行像素坐标到相机坐标的变换,属于透视投影变换的逆变换,即2D点到3D的变化,即我们需要根据像素点的坐标(u,v)计算该点在相机坐标系下的三维的坐标(X,Y,Z)
4、nerf(pytorch)
注意:COLMAP采用的是opencv定义的相机坐标系统,其中x轴向右,y轴向下,z轴向内;而Nerf pytorch采用的是OpenGL定义的相机坐标系统,其中x轴向右,y轴向上,z轴向外。在实际应用中需要在y与z轴进行相反数转换。

然后我们要进行摄像机变换,将相机坐标拓展到世界坐标系
原理介绍

4、nerf(pytorch)
4、nerf(pytorch)

通过平移矩阵,将相机坐标系原点与世界坐标系进行对齐

4、nerf(pytorch)
通过旋转矩阵对齐坐标轴的朝向
4、nerf(pytorch)
结合上述过程,世界物理坐标转换为相机坐标总公式为:
4、nerf(pytorch)

将所有矩阵总结后得到像素坐标到新建物理坐标的转换矩阵M如下

4、nerf(pytorch)
其中 n, f 是近剪切平面和远剪切平面,r 和 t 是近剪切平面上场景的右边界和上边界。(注意,这是在约定中,摄像机是面向 -z方向的。)

输入像素3d坐标进行转换

4、nerf(pytorch)
这样一来我们便通过相机的内外参数成功地将像素坐标转换为了统一的世界坐标下的光线始发点与方向向量,即将 (X,Y,Z)与 (θ,φ) 编码成了光线的方向作为MLP的输入,所以从形式上来讲MLP的输入并不是真正的点的三维坐标(X,Y,Z),而是由像素坐标经过相机外参变换得到的光线始发点与方向向量以及不同的深度值构成的,而方位的输入也并不是 (θ,φ),而是经过标准化的光线方向。不过从本质上来讲,在一个视角下,给定光线始发点o ,特定的光线方向 d 以及不同的深度值 t ,通过文章中的公式 ray = o + td 便可以得到该光线上所有的3D点的表示 ,然后扩展到不同的光线方向便可以得到相机视锥中所有3D点的表示,再进一步扩展到不同的相机拍摄机位,理论上便可以得到该场景所有3D点的表示

NeRF的代码中其实还存在另外一个坐标系统:Normalized device coordinates(ndc)坐标系,一般对于一些forward facing的场景(景深很大)会选择进一步把世界坐标转换为ndc坐标,因为ndc坐标系中会将光线的边界限制在0-1之间以减少在景深比较大的背景下产生的负面影响。但是需要注意ndc坐标系不能和spherify pose(球状多视图采样时)同时使用,这点nerf-pytorch中并没有设置两者的互斥关系,而nerf-pl则明确设置了。

4、nerf(pytorch)

4、nerf(pytorch)
新原点 o′ 和方向 d′ 的分量必须满足
4、nerf(pytorch)
为了消除自由度,我们决定 t′ = 0 和 t = 0 应映射到同一点,得到 NDC 空间原点 o′

4、nerf(pytorch)
这正是原始射线原点的投影 π ( 0 ) \pi(0)π(0 )。通过将其代入方程 10 以代替任意 t,我们可以确定 t′ 和 d′ 的值
4、nerf(pytorch)
4、nerf(pytorch)
当 t = 0 时,t′ = 0。此外,我们看到 t′ → 1 作为 t → ∞。回到原始投影矩阵,我们的常数是
4、nerf(pytorch)
使用标准的针孔相机模型,我们可以重新参数化为
4、nerf(pytorch)

其中 W 和 H 是图像的宽度和高度(以像素为单位),f c a m f_{cam}f c a m ​ 是相机的焦距。在我们实际的正向捕获中,我们假设远场景边界是无穷大的(这使我们的成本非常低,因为NDC使用z维来表示反深度,即视差)。在此限制中,z 常量简化为

4、nerf(pytorch)
总结
4、nerf(pytorch)

; 位置编码

4、nerf(pytorch)
首先要介绍一下关于”图像”的一个概念:high-frequency (高频)和 low-frequency(低频)。标号1所处的地方是白色类似于桌布的东西,标号2所处的地方是核桃。在前者的小区域内,移动一点位置,颜色并不会变化很大,但是后者就会。前者的区域对应 low-frequency,后者的区域对应 high-frequency。

4、nerf(pytorch)
深度神经网络偏向于学习图像 low-frequency 的部分,而针对 high-frequency 的部分难以学习,因此需要将输入数据映射到高维空间

论文中使用的公式如下,L的值决定了神经网络能学习到的最高频率的大小

4、nerf(pytorch)
L 的值太小,则会导致 high-frequency 区域难以重现的问题,L 的值太大,则会导致重现出来的图像有很多噪声的情况

NeRF 的作者根据实验结果,发现关于三维点坐标 x 和 单位方向向量 d ,L分别取 10 与 4 的情况,实验效果比较好

三维重建

4、nerf(pytorch)
三维重建部分本质上是一个2D到3D的建模过程,利用图片像素点转换到3D点后的位置(X,Y,Z)及方位视角(θ,φ)作为输入,通过多层感知机(MLP)建模该点对应的颜色color(c)及体密度volume density(σ),形成了3D场景的”隐式表示”。

三维重建网络的任务就是要预测从像素点出发的射线上不同深度的采样点的体密度与颜色

4、nerf(pytorch)
论文中网络图像如下图所示
4、nerf(pytorch)
经过数据处理,我们得到了射线 o+dt,将不同的t带入公式即可得到射线上不同采样点的3d坐标(X,Y,Z)和方位视角(θ,φ)

3d坐标(X,Y,Z)经过位置编码后经过5个256维的全连接层后将256维特征向量与60维(X,Y,Z)拼接后再经过3个256维全连接层得到256维特征向量,此时,将256维特征向量经过256维全连接层输出得到体密度,与此同时,256维特征向量拼接24维度方位视角(θ,φ),分别经过283维全连接层、256维全连接层和128维全连接层得到3维特征向量RGB颜色值

总结:体密度由位置坐标预测得到,RGB颜色由位置坐标和方位视角得到

; 光线采样处理

4、nerf(pytorch)
由相机 o 和某个成像点 C 两点确定的射线。以相机为原点o ,射线方向为坐标轴方向建立坐标轴。则坐标轴上任意一点坐标可表示为 o + t d 。其中 t 为该点到原点距离,d为单位方向向量。near,far 实际上是两个垂直于射线,平行于成像平面的平面。在这里,也用near、far表示那两个平面与射线的交点。在射线上采样的范围是从 near 点到 far 点。理论上,如果对于相关的 scene 我们一无所知,near 点应该被设在在原点(相机处),far 点则在 无穷远。但是实际上如果我们要处理的是 synthetic dataset,则会根据已知的物体在 scene 中的范围调整 near 和 far。因为这样可以减少计算量。

首先由于体素渲染需要沿着光线进行积分,而积分在计算机中是以离散的乘积和进行计算的,那么这里就涉及到在光线上进行点的采样。NeRF在光线的点采样过程中的进行了一些设计。首先为了避免大量的点采样导致的计算量的激增,NeRF设计了coarse to fine的采样策略。在coarse采样阶段,采用了带有扰动的均匀采样方法

先在光线的边界之间进行深度空间的均匀采样,然后在规定了下界与上界的范围内将采样点进行扰动,最终coarse采样阶段在每条光线上采样了64个样本点。然后基于coarse采样得到的结果进一步指导fine的点采样,即在对最终颜色贡献更大(权重更大)的点附近进行更加密集的点采样

粗采样中,我们采样64个,细采样中我们根据粗采样结果进行细粒度采样128

在 NeRF中,一共有训练两个 MLP,分别是 Model c ( c o a r s e ) c_{(coarse)}c (c o a r s e )​和 Model f ( f i n e ) f_{(fine)}f (f i n e )​,它们的模型结构一样,只是参数不同。它们分别对应 Hierarchical Volume Sampling 两次抽样的模型

  • c ( c o a r s e ) c_{(coarse)}c (c o a r s e )​
    第一次抽样是先将在相机射线上,由 near 和 far 构成的范围 n 等分,然后在每个小区间内均匀采样得到一个采样点 C i C_i C i ​,最终一共 n 个采样点。将得到的采样点的位置信息和其他信息输入 MLP,MLP 输出每个采样点的 Volume Density
  • f ( f i n e ) f_{(fine)}f (f i n e )​
    第二次采样是基于第一次采样点的结果。a i a_i a i ​ 乘上 T i T_i T i ​ 得到 weight i 为 C i C_i C i ​ 对 C 的权重,那么第二次采样会对第一次采样结果中 w e i g h t i weight_i w e i g h t i ​ 较大的区域多采样,w e i g h t i weight_i w e i g h t i ​ 小的地方采样少一些

首先对 w e i g h t i weight_i w e i g h t i ​ 归一化

4、nerf(pytorch)
其中, w i w_i w i ​ 表示 w e i g h t i weight_i w e i g h t i ​ 表示归一化后的 w i w_i w i ​ ,N c N_c N c ​ 是第一次采样的采样点数目
4、nerf(pytorch)
对 w i w^i w i 构造累计分布函数(CDF)即可

渲染

4、nerf(pytorch)
渲染部分本质上是一个3D到2D的建模过程,渲染部分利用重建部分得到的3D点的颜色及不透明度沿着光线进行整合得到最终的2D图像像素值
4、nerf(pytorch)

根据三维重建得到射线上不同采样点的体密度与颜色后,我们就可以将其带入到体渲染方程,生成像素点颜色

4、nerf(pytorch)

其中函数T(t)表示射线从 t n t_n t n ​ 到 t 沿射线累积透射率,即射线从 t n t_n t n ​ 到 t 不碰到任何粒子的概率

从上述公式中,我们不难看出要对射线上的体密度、颜色等进行积分,那就得使用到离散求积法对这个连续积分进行数值估计,这会极大地限制表示的分辨率,因此可通过分层抽样方法,这样使用离散的样本估计积分,但是能够较好地表示一个连续的场景(类似重要性采样,对整个积分域进行非均匀离散化,较能还原原本的积分分布

4、nerf(pytorch)
C : 渲染的像素点
C i C_i C i ​ : 采样点
c i c_i c i ​ : 采样点颜色值
σ:‎体密度‎
δ :两个采样点之间的距离
α :透明度
  • σ:(体密度)由三维重建得出
  • δ (两个采样点之间的距离定义)

每个采样点C 1 , C 2 , … … , C n C_1,C_2,……,C_n C 1 ​,C 2 ​,……,C n ​,其到相机 0 的距离为 Z 1 , Z 2 , … … , Z n , δ 1 Z_1,Z_2,……,Z_n,δ_1 Z 1 ​,Z 2 ​,……,Z n ​,δ1 ​由Z 2 − Z 1 Z_2-Z_1 Z 2 ​−Z 1 ​得到,δ 2 δ_2 δ2 ​由Z 3 − Z 2 Z_3-Z_2 Z 3 ​−Z 2 ​得到,如此类推

我们可以定义两个list

4、nerf(pytorch)
4、nerf(pytorch)
  • α (透明度)

这里,我们定义a = 1 – exp(-σδ)

4、nerf(pytorch)
经过公式变换,T i T_i T i ​ 可以表示为
4、nerf(pytorch)
T i T_i T i ​ 可以理解为除去当前点 C i C_i C i ​ 前面的点 C 1 , C 2 , … … , C i − 1 C_1,C_2,……,C_{i-1}C 1 ​,C 2 ​,……,C i −1 ​ 对 C 的影响后,C i C_i C i ​ 能对 C 产生的最大影响系数,a i ∗ T i a_i * T_i a i ​∗T i ​ 表示 C i C_i C i ​ 对 C 的权重

不同相机射线的采样点不同,对应的 T i T_i T i ​ 不同,那么同一个三维点,从不同方向去看颜色可能不同

当 σ = 0 时,α = 0 ,这表示当 体密度 为 0 时,透明度为 0,完全透明,从相机发出的光继续向后传播,当前点对成像点的颜色无贡献,也不影响后面点对成像点颜色的贡献

当 σ → + ∞ , α → 1 ,这表示此时 透明度为1,完全不透明,使得从相机发出的光被完全阻挡在当前点,使得当前点后面的点对成像点 C 的颜色无贡献

由上面的公式可以得出离散求积公式可以写成下列形式

4、nerf(pytorch)

; 损失

4、nerf(pytorch)

在训练的时候,利用渲染部分得到的2D图像,通过与原始图片做L2损失函数(L2 Loss)进行网络优化

4、nerf(pytorch)

代码分析

源码地址:https://github.com/yenchenlin/nerf-pytorch

超参数

run_nerf.py config_parser()

  • *基本参数

    parser.add_argument('--config', is_config_file=True,
                        help='config file path')

    parser.add_argument("--expname", type=str,
                        help='experiment name')

    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')

    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

  • *训练相关

    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')

    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')

    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')

    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')

    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000 steps)')

    parser.add_argument("--chunk", type=int, default=1024*32,
                        help='number of rays processed in parallel, decrease if running out of memory')

    parser.add_argument("--netchunk", type=int, default=1024*64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')

    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')

    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')

    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')

  • *渲染相关

    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')

    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')

    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')

    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')

    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')

    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')

    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')

    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')

    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

  • *数据准备参数

    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff / blender / deepvoxels')

    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    parser.add_argument("--i_print",   type=int, default=100,
                        help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img",     type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video",   type=int, default=50000,
                        help='frequency of render_poses video saving')

数据处理(blender)

load_blender.py

4、nerf(pytorch)
transforms_test.json
4、nerf(pytorch)
4、nerf(pytorch)

; 数据加载

def load_blender_data(basedir, half_res=False, testskip=1):
    splits = ['train', 'val', 'test']
    metas = {}

    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s=='train' or testskip==0:
            skip = 1
        else:
            skip = testskip

        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))
            poses.append(np.array(frame['transform_matrix']))

        imgs = (np.array(imgs) / 255.).astype(np.float32)

        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    camera_angle_x = float(meta['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
        for i, img in enumerate(imgs):
            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
        imgs = imgs_half_res

    return imgs, poses, render_poses, [H, W, focal], i_split

坐标变换

位置编码

run_nerf_helpers.py Embedder 位置编码

class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

模型创建

run_nerf.py create_nerf(args) 创建模型

def create_nerf(args):
    """Instantiate NeRF's MLP model.

"""
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]

    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    model_fine = None

    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                embed_fn=embed_fn,
                                                                embeddirs_fn=embeddirs_fn,
                                                                netchunk=args.netchunk)

    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    if args.ft_path is not None and args.ft_path!='None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer

网络结构

NeRF(
  (pts_linears): ModuleList(
    (0): Linear(in_features=63, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): Linear(in_features=319, out_features=256, bias=True)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): Linear(in_features=256, out_features=256, bias=True)
  )
  (views_linears): ModuleList(
    (0): Linear(in_features=283, out_features=128, bias=True)
  )
  (feature_linear): Linear(in_features=256, out_features=256, bias=True)
  (alpha_linear): Linear(in_features=256, out_features=1, bias=True)
  (rgb_linear): Linear(in_features=128, out_features=3, bias=True)
)

图像渲染

run_nerf.py raw2outputs

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model. 光线采样点模型预测结果,包含每个采样点的RGB和体密度
        z_vals: [num_rays, num_samples along ray]. Integration time.

        rays_d: [num_rays, 3]. Direction of each ray.

    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.

        disp_map: [num_rays]. Disparity map. Inverse of depth map.

        acc_map: [num_rays]. Sum of weights along each ray.

        weights: [num_rays, num_samples]. Weights assigned to each sampled color.

        depth_map: [num_rays]. Estimated distance to object.

"""

    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)

    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)
    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)

    rgb = torch.sigmoid(raw[...,:3])
    noise = 0.

    if raw_noise_std > 0.:
        noise = torch.randn(raw[...,3].shape) * raw_noise_std

        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            noise = torch.Tensor(noise)

    alpha = raw2alpha(raw[...,3] + noise, dists)

    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]

    rgb_map = torch.sum(weights[...,None] * rgb, -2)

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, acc_map, weights, depth_map

光线采样处理

run_nerf.py render_rays

def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering.

    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.

      network_fn: function. Model for predicting RGB and density at each point
        in space.

      network_query_fn: function used for passing queries to network_fn.

      N_samples: int. Number of different times to sample along each ray.

      retraw: bool. If True, include model's raw, unprocessed predictions.

      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.

      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.

      N_importance: int. Number of additional times to sample along each ray.

        These samples are only passed to network_fine.

      network_fine: "fine" network with same spec as network_fn.

      white_bkgd: bool. If True, assume a white background.

      raw_noise_std: ...

      verbose: bool. If True, print more debugging info.

    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.

      disp_map: [num_rays]. Disparity map. 1 / depth.

      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.

      raw: [num_rays, num_samples, 4]. Raw predictions from model.

      rgb0: See rgb_map. Output for coarse model.

      disp0: See disp_map. Output for coarse model.

      acc0: See acc_map. Output for coarse model.

      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.

"""
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6]
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1]

    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:

        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)

        t_rand = torch.rand(z_vals.shape)

        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand)

        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    if N_importance > 0:

        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])

        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

        run_fn = network_fn if network_fine is None else network_fine

        raw = network_query_fn(pts, viewdirs, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)

    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret

权重累计分布

run_nerf_helpers sample_pdf

根据粗采样结果对权重较大位置加大采样

def sample_pdf(bins, weights, N_samples, det=False, pytest=False):

    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)

    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

    return samples

Original: https://blog.csdn.net/weixin_50973728/article/details/126048095
Author: C–G
Title: 4、nerf(pytorch)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/624774/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球