0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

利用PyTorch实现NeRF代码详解

3D视觉工坊 来源:3DCV 2023-10-21 09:46 次阅读

作者:大森林| 来源:3DCV

1. NeRF定义

神经辐射场(NeRF)是一种利用神经网络来表示和渲染复杂的三维场景的方法。它可以从一组二维图片中学习出一个连续的三维函数,这个函数可以给出空间中任意位置和方向上的颜色和密度。通过体积渲染的技术,NeRF可以从任意视角合成出逼真的图像,包括透明和半透明物体,以及复杂的光线传播效果。

2. NeRF优势

NeRF模型相比于其他新的视图合成和场景表示方法有以下几个优势:

1)NeRF不需要离散化的三维表示,如网格或体素,因此可以避免模型精度和细节程度受到限制。NeRF也可以自适应地处理不同形状和大小的场景,而不需要人工调整参数

2)NeRF使用位置编码的方式将位置和角度信息映射到高频域,使得网络能够更好地捕捉场景的细微结构和变化。NeRF还使用视角相关的颜色预测,能够生成不同视角下不同的光照效果。

3)NeRF使用分段随机采样的方式来近似体积渲染的积分,这样可以保证采样位置的连续性,同时避免网络过拟合于离散点的信息。NeRF还使用多层级体素采样的技巧,以提高渲染效率和质量。

3. NeRF实现步骤

1)定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。

2)使用位置编码的方式将输入映射到高频域,以便网络能够捕捉细微的结构和变化。

3)使用分段随机采样的方式从每条光线上采样一些点,然后用神经网络预测这些点的颜色和密度。

4)使用体积渲染的公式计算每条光线上的颜色和透明度,作为最终的图像输出。

5)使用渲染损失函数来优化神经网络的参数,使得渲染的图像与输入的图像尽可能接近。

importtorch
importtorch.nnasnn
importtorch.nn.functionalasF

#定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。
classNeRF(nn.Module):
def__init__(self,D=8,W=256,input_ch=3,input_ch_views=3,output_ch=4,skips=[4]):
super().__init__()
#定义位置编码后的位置信息的线性层,如果层数在skips列表中,则将原始位置信息与隐藏层拼接
self.pts_linears=nn.ModuleList(
[nn.Linear(input_ch,W)]+[nn.Linear(W,W)ifinotinskipselsenn.Linear(W+input_ch,W)foriinrange(D-1)])
#定义位置编码后的视角方向信息的线性层
self.views_linears=nn.ModuleList([nn.Linear(W+input_ch_views,W//2)]+[nn.Linear(W//2,W//2)foriinrange(1)])
#定义特征向量的线性层
self.feature_linear=nn.Linear(W//2,W)
#定义透明度(alpha)值的线性层
self.alpha_linear=nn.Linear(W,1)
#定义RGB颜色的线性层
self.rgb_linear=nn.Linear(W+input_ch_views,3)

defforward(self,x):
#x:(B,input_ch+input_ch_views)
#提取位置和视角方向信息
p=x[:,:3]#(B,3)
d=x[:,3:]#(B,3)

#对输入进行位置编码,将低频信号映射到高频域
p=positional_encoding(p)#(B,input_ch)
d=positional_encoding(d)#(B,input_ch_views)

#将位置信息输入网络
h=p
fori,linenumerate(self.pts_linears):
h=l(h)
h=F.relu(h)
ifiinskips:
h=torch.cat([h,p],-1)#如果层数在skips列表中,则将原始位置信息与隐藏层拼接

#将视角方向信息与隐藏层拼接,并输入网络
h=torch.cat([h,d],-1)
fori,linenumerate(self.views_linears):
h=l(h)
h=F.relu(h)

#预测特征向量和透明度(alpha)值
feature=self.feature_linear(h)#(B,W)
alpha=self.alpha_linear(feature)#(B,1)

#使用特征向量和视角方向信息预测RGB颜色
rgb=torch.cat([feature,d],-1)
rgb=self.rgb_linear(rgb)#(B,3)

returntorch.cat([rgb,alpha],-1)#(B,4)

#定义位置编码函数
defpositional_encoding(x):
#x:(B,C)
B,C=x.shape
L=int(C//2)#计算位置编码的长度
freqs=torch.logspace(0.,L-1,steps=L).to(x.device)*math.pi#计算频率系数,呈指数增长
freqs=freqs[None].repeat(B,1)#(B,L)
x_pos_enc_low=torch.sin(x[:,:L]*freqs)#对前一半的输入进行正弦变换,得到低频部分(B,L)
x_pos_enc_high=torch.cos(x[:,:L]*freqs)#对前一半的输入进行余弦变换,得到高频部分(B,L)
x_pos_enc=torch.cat([x_pos_enc_low,x_pos_enc_high],dim=-1)#将低频和高频部分拼接,得到位置编码后的输入(B,C)
returnx_pos_enc

#定义体积渲染函数
defvolume_rendering(rays_o,rays_d,model):
#rays_o:(B,3),每条光线的起点
#rays_d:(B,3),每条光线的方向
B=rays_o.shape[0]

#在每条光线上采样一些点
near,far=0.,1.#近平面和远平面
N_samples=64#每条光线的采样数
t_vals=torch.linspace(near,far,N_samples).to(rays_o.device)#(N_samples,)
t_vals=t_vals.expand(B,N_samples)#(B,N_samples)
z_vals=near*(1.-t_vals)+far*t_vals#计算每个采样点的深度值(B,N_samples)
z_vals=z_vals.unsqueeze(-1)#(B,N_samples,1)
pts=rays_o.unsqueeze(1)+rays_d.unsqueeze(1)*z_vals#计算每个采样点的空间位置(B,N_samples,3)

#将采样点和视角方向输入网络
pts_flat=pts.reshape(-1,3)#(B*N_samples,3)
rays_d_flat=rays_d.unsqueeze(1).expand(-1,N_samples,-1).reshape(-1,3)#(B*N_samples,3)
x_flat=torch.cat([pts_flat,rays_d_flat],-1)#(B*N_samples,6)
y_flat=model(x_flat)#(B*N_samples,4)
y=y_flat.reshape(B,N_samples,4)#(B,N_samples,4)

#提取RGB颜色和透明度(alpha)值
rgb=y[...,:3]#(B,N_samples,3)
alpha=y[...,3]#(B,N_samples)

#计算每个采样点的权重
dists=torch.cat([z_vals[...,1:]-z_vals[...,:-1],torch.tensor([1e10]).to(z_vals.device).expand(B,1)],-1)#计算相邻采样点之间的距离,最后一个距离设为很大的值(B,N_samples)
alpha=1.-torch.exp(-alpha*dists)#计算每个采样点的不透明度,即1减去透明度的指数衰减(B,N_samples)
weights=alpha*torch.cumprod(torch.cat([torch.ones((B,1)).to(alpha.device),1.-alpha+1e-10],-1),-1)[:,:-1]#计算每个采样点的权重,即不透明度乘以之前所有采样点的透明度累积积,最后一个权重设为0(B,N_samples)

#计算每条光线的最终颜色和透明度
rgb_map=torch.sum(weights.unsqueeze(-1)*rgb,-2)#加权平均每个采样点的RGB颜色,得到每条光线的颜色(B,3)
depth_map=torch.sum(weights*z_vals.squeeze(-1),-1)#加权平均每个采样点的深度值,得到每条光线的深度(B,)
acc_map=torch.sum(weights,-1)#累加每个采样点的权重,得到每条光线的不透明度(B,)

returnrgb_map,depth_map,acc_map

#定义渲染损失函数
defrendering_loss(rgb_map_pred,rgb_map_gt):
return((rgb_map_pred-rgb_map_gt)**2).mean()#计算预测的颜色与真实颜色之间的均方误差

综上所述,本代码实现了NeRF的核心结构,具体实现内容包括以下四个部分。

1)定义了NeRF网络结构,包含位置编码和多层全连接网络,输入是位置和视角,输出是颜色和密度。

2)实现了位置编码函数,通过正弦和余弦变换引入高频信息。

3)实现了体积渲染函数,在光线上采样点,查询NeRF网络预测颜色和密度,然后通过加权平均实现整体渲染。

4)定义了渲染损失函数,计算预测颜色和真实颜色的均方误差。

当然,本方案只是实现NeRF的一个基础方案,更多的细节还需要进行优化。

当然,为了方便下载,我们已经将上述两个源代码打包好了。

审核编辑:汤梓红

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4572

    浏览量

    98746
  • 函数
    +关注

    关注

    3

    文章

    3868

    浏览量

    61309
  • 代码
    +关注

    关注

    30

    文章

    4555

    浏览量

    66767
  • pytorch
    +关注

    关注

    2

    文章

    761

    浏览量

    12831

原文标题:一文带你入门NeRF:利用PyTorch实现NeRF代码详解(附代码)

文章出处:【微信号:3D视觉工坊,微信公众号:3D视觉工坊】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    Image Style Transfer pytorch方式实现的主要思路

    深度学总结:Image Style Transfer pytorch方式实现,这个是非基于autoencoder和domain adversrial方式
    发表于 06-20 10:58

    PyTorch如何入门

    PyTorch 入门实战(一)——Tensor
    发表于 06-01 09:58

    Pytorch代码移植嵌入式开发笔记,错过绝对后悔

    @[TOC]Pytorch 代码移植嵌入式开发笔记目前在做开发完成后的AI模型移植到前端的工作。 由于硬件设施简陋,需要把代码和算法翻译成基础加乘算法并输出每个环节参数。记录几点实用技巧以及项目
    发表于 11-08 08:24

    单片机点灯的基本语法代码详解

    【单片机】点灯基本语法代码详解代码详解#include #include //功能:实现P1口左移#define uchar unsigne
    发表于 02-16 06:34

    Caffe2 和 PyTorch 代码层合并旨为提高开发效率

    按照贾扬清的说法,Facebook 去年启动 ONNX 项目并组建团队时,就已经开始推动 Caffe2 和 PyTorch代码层的合并。
    的头像 发表于 04-30 09:16 3349次阅读

    Pytorch 1.1.0,来了!

    许多用户已经转向使用标准PyTorch运算符编写自定义实现,但是这样的代码遭受高开销:大多数PyTorch操作在GPU上启动至少一个内核,并且RNN由于其重复性质通常运行许多操作。但是
    的头像 发表于 05-05 10:02 5679次阅读
    <b class='flag-5'>Pytorch</b> 1.1.0,来了!

    Pytorch实现MNIST手写数字识别

    Pytorch 实现MNIST手写数字识别
    发表于 06-16 14:47 6次下载

    一种全新的数据蒸馏方法来加速NeRF

    , 2021]. 尽管加速很可观 (如 [Yu et al., ICCV, 2021] 实现了 3000x 的渲染加速), 但这种数据结构也破坏了 NeRF 作为场景表征存储小的优点。
    的头像 发表于 08-08 10:53 1107次阅读

    pytorch实现断电继续训练时需要注意的要点

    本文整理了pytorch实现断电继续训练时需要注意的要点,附有代码详解
    的头像 发表于 08-22 09:50 1080次阅读

    介绍一种神经场成对配准的技术NeRF2NeRF

    我们介绍了一种神经场成对配准的技术,它扩展了基于优化的经典局部配准(即ICP)以操作神经辐射场(NeRF)。
    的头像 发表于 02-20 10:29 421次阅读

    那些年在pytorch上踩过的坑

    今天又发现了一个pytorch的小坑,给大家分享一下。手上两份同一模型的代码,一份用tensorflow写的,另一份是我拿pytorch写的,模型架构一模一样,预处理数据的逻辑也一模一样,测试发现模型推理的速度也差不多。一份预处
    的头像 发表于 02-22 14:18 823次阅读
    那些年在<b class='flag-5'>pytorch</b>上踩过的坑

    NeRF2NeRF神经辐射场的配对配准介绍

    我们介绍了一种神经场成对配准的技术,它扩展了基于优化的经典局部配准(即ICP)以操作神经辐射场(NeRF)。
    的头像 发表于 03-31 16:49 627次阅读

    [源代码]Python算法详解

    [源代码]Python算法详解[源代码]Python算法详解
    发表于 06-06 17:50 0次下载

    基于NeRF的隐式GAN架构

    一小部分2D图像合成复杂3D场景的新视图方面提供了最先进的质量。 作者提出了一个生成模型HyperNeRFGAN,它使用超网络范式来生成由NeRF表示的三维物体。超网络被定义为为解决特定任务的单独目标网络生成权值的神经模型。基于GAN的模型,利用超网络范式将高斯噪
    的头像 发表于 06-14 10:16 740次阅读
    基于<b class='flag-5'>NeRF</b>的隐式GAN架构

    TorchFix:基于PyTorch代码静态分析

    TorchFix是我们最近开发的一个新工具,旨在帮助PyTorch用户维护健康的代码库并遵循PyTorch的最佳实践。首先,我想要展示一些我们努力解决的问题的示例。
    的头像 发表于 12-18 15:20 740次阅读