[Instant NGP Code Digest - A First Principle Perspective] Core Training Pipeline
UPDATE LOG
  • Completed the kernel Generate Training Samples NeRF
NOTE

This draft is based on commit d64e353db28109a81657879fc88025713d8fad53 (Oct 8, 2025)

Instant-NGP Official Repository:

NVlabs/instant-ngp (https://github.com/NVlabs/instant-ngp)

1. Introduction and Motivation#

In this article, we untangle the core training pipeline from first principles, and ultimately rewrite it as a clean, tidy, modern, and easy-to-understand version that achieves better performance.

2. Kernel: Generate Training Samples NeRF#

Location: instant-ngp/src/testbed_nerf.cu -> generate_training_samples_nerf

Parameters List

| Name | Type |
| --- | --- |
| n_rays | uint32_t |
| aabb | BoundingBox (custom struct) |
| max_samples | uint32_t |
| n_rays_total | uint32_t |
| rng | default_rng_t (PCG RNG wrapper) |
| ray_counter | uint32_t* |
| numsteps_counter | uint32_t* |
| ray_indices_out | uint32_t* |
| rays_out_unnormalized | Ray* (custom ray structure pointer) |
| numsteps_out | uint32_t* |
| coords_out | PitchedPtr<NerfCoordinate> (pitched GPU buffer of sampled coords) |
| n_training_images | uint32_t |
| metadata | TrainingImageMetadata* (per-image camera + rays + focal info) |
| training_xforms | TrainingXForm* (start/end view transform, rolling shutter) |
| density_grid | const uint8_t* |
| max_mip | uint32_t |
| max_level_rand_training | bool |
| max_level_ptr | float* |
| snap_to_pixel_centers | bool |
| train_envmap | bool |
| cone_angle_constant | float |
| distortion | Buffer2DView<const vec2> (image-space distortion LUT) |
| cdf_x_cond_y | const float* |
| cdf_y | const float* |
| cdf_img | const float* |
| cdf_res | ivec2 (2D integer vector) |
| extra_dims_gpu | const float* |
| n_extra_dims | uint32_t |
Complete kernel code:
__global__ void generate_training_samples_nerf(
    const uint32_t n_rays, BoundingBox aabb, const uint32_t max_samples, const uint32_t n_rays_total,
    default_rng_t rng, uint32_t* __restrict__ ray_counter, uint32_t* __restrict__ numsteps_counter,
    uint32_t* __restrict__ ray_indices_out, Ray* __restrict__ rays_out_unnormalized,
    uint32_t* __restrict__ numsteps_out, PitchedPtr<NerfCoordinate> coords_out, const uint32_t n_training_images,
    const TrainingImageMetadata* __restrict__ metadata, const TrainingXForm* training_xforms,
    const uint8_t* __restrict__ density_grid, uint32_t max_mip, bool max_level_rand_training,
    float* __restrict__ max_level_ptr, bool snap_to_pixel_centers, bool train_envmap, float cone_angle_constant,
    Buffer2DView<const vec2> distortion, const float* __restrict__ cdf_x_cond_y, const float* __restrict__ cdf_y,
    const float* __restrict__ cdf_img, const ivec2 cdf_res, const float* __restrict__ extra_dims_gpu,
    uint32_t n_extra_dims)
{
    const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= n_rays)
    {
        return;
    }

    uint32_t img = image_idx(i, n_rays, n_rays_total, n_training_images, cdf_img);
    ivec2 resolution = metadata[img].resolution;

    rng.advance(i * N_MAX_RANDOM_SAMPLES_PER_RAY());
    vec2 uv =
        nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, cdf_res, img);

    // Negative values indicate masked-away regions
    size_t pix_idx = pixel_idx(uv, resolution, 0);
    if (read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type).x < 0.0f)
    {
        return;
    }

    float max_level = max_level_rand_training ? (random_val(rng) * 2.0f)
                                              : 1.0f; // Multiply by 2 to ensure 50% of training is at max level

    float motionblur_time = random_val(rng);

    const vec2 focal_length = metadata[img].focal_length;
    const vec2 principal_point = metadata[img].principal_point;
    const float* extra_dims = extra_dims_gpu + img * n_extra_dims;
    const Lens lens = metadata[img].lens;

    const mat4x3 xform =
        get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, uv, motionblur_time);

    Ray ray_unnormalized;
    const Ray* rays_in_unnormalized = metadata[img].rays;
    if (rays_in_unnormalized)
    {
        // Rays have been explicitly supplied. Read them.
        ray_unnormalized = rays_in_unnormalized[pix_idx];

        /* DEBUG - compare the stored rays to the computed ones
        const mat4x3 xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter,
        uv, 0.f); Ray ray2; ray2.o = xform[3]; ray2.d = f_theta_distortion(uv, principal_point, lens); ray2.d =
        (xform.block<3, 3>(0, 0) * ray2.d).normalized(); if (i==1000) { printf("\n%d uv %0.3f,%0.3f pixel
        %0.2f,%0.2f transform from [%0.5f %0.5f %0.5f] to [%0.5f %0.5f %0.5f]\n" " origin [%0.5f %0.5f %0.5f] vs
        [%0.5f %0.5f %0.5f]\n" " direction [%0.5f %0.5f %0.5f] vs [%0.5f %0.5f %0.5f]\n" , img,uv.x, uv.y,
        uv.x*resolution.x, uv.y*resolution.y,
        training_xforms[img].start[3].x,training_xforms[img].start[3].y,training_xforms[img].start[3].z,
        training_xforms[img].end[3].x,training_xforms[img].end[3].y,training_xforms[img].end[3].z,
        ray_unnormalized.o.x,ray_unnormalized.o.y,ray_unnormalized.o.z,
        ray2.o.x,ray2.o.y,ray2.o.z,
        ray_unnormalized.d.x,ray_unnormalized.d.y,ray_unnormalized.d.z,
        ray2.d.x,ray2.d.y,ray2.d.z);
        }
        */
    }
    else
    {
        ray_unnormalized = uv_to_ray(0, uv, resolution, focal_length, xform, principal_point, vec3(0.0f), 0.0f,
                                     1.0f, 0.0f, {}, {}, lens, distortion);
        if (!ray_unnormalized.is_valid())
        {
            ray_unnormalized = {xform[3], xform[2]};
        }
    }

    vec3 ray_d_normalized = normalize(ray_unnormalized.d);

    vec2 tminmax = aabb.ray_intersect(ray_unnormalized.o, ray_d_normalized);
    float cone_angle = calc_cone_angle(dot(ray_d_normalized, xform[2]), focal_length, cone_angle_constant);

    // The near distance prevents learning of camera-specific fudge right in front of the camera
    tminmax.x = fmaxf(tminmax.x, 0.0f);

    float startt = advance_n_steps(tminmax.x, cone_angle, random_val(rng));
    vec3 idir = vec3(1.0f) / ray_d_normalized;

    // first pass to compute an accurate number of steps
    uint32_t j = 0;
    float t = startt;
    vec3 pos;

    while (aabb.contains(pos = ray_unnormalized.o + t * ray_d_normalized) && j < NERF_STEPS())
    {
        float dt = calc_dt(t, cone_angle);
        uint32_t mip = mip_from_dt(dt, pos, max_mip);
        if (density_grid_occupied_at(pos, density_grid, mip))
        {
            ++j;
            t += dt;
        }
        else
        {
            t = advance_to_next_voxel(t, cone_angle, pos, ray_d_normalized, idir, mip);
        }
    }

    if (j == 0 && !train_envmap)
    {
        return;
    }

    uint32_t numsteps = j;
    uint32_t base = atomicAdd(numsteps_counter, numsteps); // first entry in the array is a counter
    if (base + numsteps > max_samples)
    {
        return;
    }

    coords_out += base;

    uint32_t ray_idx = atomicAdd(ray_counter, 1);

    ray_indices_out[ray_idx] = i;
    rays_out_unnormalized[ray_idx] = ray_unnormalized;
    numsteps_out[ray_idx * 2 + 0] = numsteps;
    numsteps_out[ray_idx * 2 + 1] = base;

    vec3 warped_dir = warp_direction(ray_d_normalized);
    t = startt;
    j = 0;
    while (aabb.contains(pos = ray_unnormalized.o + t * ray_d_normalized) && j < numsteps)
    {
        float dt = calc_dt(t, cone_angle);
        uint32_t mip = mip_from_dt(dt, pos, max_mip);
        if (density_grid_occupied_at(pos, density_grid, mip))
        {
            coords_out(j)->set_with_optional_extra_dims(warp_position(pos, aabb), warped_dir, warp_dt(dt),
                                                        extra_dims, coords_out.stride_in_bytes);
            ++j;
            t += dt;
        }
        else
        {
            t = advance_to_next_voxel(t, cone_angle, pos, ray_d_normalized, idir, mip);
        }
    }

    if (max_level_rand_training)
    {
        max_level_ptr += base;
        for (j = 0; j < numsteps; ++j)
        {
            max_level_ptr[j] = max_level;
        }
    }
}

2.1 CUDA indexing formula#

const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= n_elements)
{
    return;
}

2.1.1 Global Thread Index [1]#

$$\boxed{ i = \text{threadIdx}_x + \text{blockIdx}_x \cdot \text{blockDim}_x }$$

CUDA Indexing

| Term | Meaning |
| --- | --- |
| $\text{threadIdx}_x$ | index of the thread inside its block |
| $\text{blockIdx}_x$ | index of the block inside the grid |
| $\text{blockDim}_x$ | number of threads per block |
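
For context, here is a minimal sketch of how such a one-thread-per-ray kernel is typically launched from the host. The block size of 128, the helper div_round_up, and the wrapper function name are illustrative assumptions, not taken from the instant-ngp source (which uses its own launch helpers):

// Hypothetical host-side launch configuration (illustrative only). One thread is
// scheduled per ray; threads whose index i exceeds n_rays return early, as shown above.
constexpr uint32_t BLOCK_SIZE = 128; // illustrative block size

inline uint32_t div_round_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

void launch_generate_training_samples(uint32_t n_rays /*, kernel arguments elided */) {
    const dim3 blocks(div_round_up(n_rays, BLOCK_SIZE));
    const dim3 threads(BLOCK_SIZE);
    // generate_training_samples_nerf<<<blocks, threads>>>(n_rays, /* ... */);
}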

2.2 Determine image index for a given ray#

uint32_t img = image_idx(i, n_rays, n_rays_total, n_training_images, cdf_img);

2.2.1 CUDA Function image_idx#

inline NGP_HOST_DEVICE uint32_t image_idx(
    uint32_t base_idx, uint32_t n_rays, uint32_t n_rays_total, uint32_t n_training_images,
    const float* __restrict__ cdf = nullptr, float* __restrict__ pdf = nullptr
) {
    if (cdf) {
        float sample = ld_random_val(base_idx/* + n_rays_total*/, 0xdeadbeef);
        // float sample = random_val(base_idx/* + n_rays_total*/);
        uint32_t img = binary_search(sample, cdf, n_training_images);
        if (pdf) {
            float prev = img > 0 ? cdf[img-1] : 0.0f;
            *pdf = (cdf[img] - prev) * n_training_images;
        }
        return img;
    }

    // return ((base_idx/* + n_rays_total*/) * 56924617 + 96925573) % n_training_images;

    // Neighboring threads in the warp process the same image. Increases locality.
    if (pdf) {
        *pdf = 1.0f;
    }
    return (((base_idx/* + n_rays_total*/) * n_training_images) / n_rays) % n_training_images;
}
| Parameter | Type | Note |
| --- | --- | --- |
| base_idx | uint32_t | Unique ray/thread index used to select the image |
| n_rays | uint32_t | Total rays scheduled in the current iteration (controls the uniform mapping) |
| n_rays_total | uint32_t | Unused here (commented out in the index computation) |
| n_training_images | uint32_t | Number of images available for sampling (upper bound of the output index) |
| cdf | const float* | Optional CDF for importance sampling; nullptr in default NeRF training |
| pdf | float* | Output probability weight, only used when cdf != nullptr (never touched in default training) |

2.2.2 Base Version#

__device__ uint32_t image_idx(
    const uint32_t base_idx,
    const uint32_t n_rays,
    const uint32_t n_training_images
) {
    return base_idx * n_training_images / n_rays % n_training_images;
}

Intuitive interpretation: each image receives approximately $\frac{N_R}{N_I}$ rays; rays are distributed proportionally among the images.

$$f(i) = \left\lfloor \frac{i \cdot N_I}{N_R} \right\rfloor \bmod N_I$$

Where:

| Symbol | Corresponding variable |
| --- | --- |
| $i$ | base_idx |
| $N_R$ | n_rays |
| $N_I$ | n_training_images |
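
As a small sanity check of the formula: with $N_R = 8$ rays and $N_I = 4$ images, $f(i) = \lfloor i \cdot 4 / 8 \rfloor \bmod 4$, so rays 0–1 map to image 0, rays 2–3 to image 1, rays 4–5 to image 2, and rays 6–7 to image 3. Each image receives $N_R / N_I = 2$ rays, and consecutive threads hit the same image, which is exactly the locality the code comment mentions.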

2.2.3 CDF and PDF#

For more details about CDF and PDF, please refer to Appendix CDF: Cumulative Distribution Function.

TODO: explain the CDF & PDF branch


2.3 Get Image Resolution#

ivec2 resolution = metadata[img].resolution;

2.3.1 TrainingImageMetadata Struct#

struct TrainingImageMetadata {
    // Camera intrinsics and additional data associated with a NeRF training image
    // the memory to back the pixels and rays is held by GPUMemory objects in the NerfDataset and copied here.
    const void* pixels = nullptr;
    EImageDataType image_data_type = EImageDataType::Half;

    const float* depth = nullptr;
    const Ray* rays = nullptr;

    Lens lens = {};
    ivec2 resolution = ivec2(0);
    vec2 principal_point = vec2(0.5f);
    vec2 focal_length = vec2(1000.f);
    vec4 rolling_shutter = vec4(0.0f);
    vec3 light_dir = vec3(0.f); // TODO: replace this with more generic float[] of task-specific metadata.
};
| Field | Type | Meaning |
| --- | --- | --- |
| pixels | const void* | Pointer to the pixel buffer in GPU memory |
| image_data_type | EImageDataType | Pixel storage format (Byte/Half/Float, etc.) |
| depth | const float* | Optional per-pixel depth values (nullable) |
| rays | const Ray* | Optional precomputed rays (nullable) |
| lens | Lens | Lens configuration (distortion and optical parameters) |
| resolution | ivec2 | Image width & height |
| principal_point | vec2 | Camera optical center offset |
| focal_length | vec2 | Focal length fx, fy |
| rolling_shutter | vec4 | Rolling shutter timing & motion model |
| light_dir | vec3 | View lighting direction (non-general metadata placeholder) |

2.3.2 How to compute resolution#

NOTE

To avoid getting trapped in the endless details of image loading and preprocessing, we assume for now that the image resolution is precomputed and stored in the TrainingImageMetadata struct. Image loading and preprocessing will be covered in a future article.

For the NeRF Synthetic dataset, we can simply treat the resolution as a constant (800 × 800); it is safe to assume that all images in the dataset share the same resolution and that it never changes during training.
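
To make the per-image setup concrete, here is a hedged host-side sketch of filling in the metadata for one NeRF Synthetic image. The function name is hypothetical and this is not instant-ngp's actual loader; the focal-length formula is the standard Blender/NeRF-Synthetic convention f = 0.5 · W / tan(camera_angle_x / 2), with camera_angle_x coming from transforms.json:

// Hypothetical setup of TrainingImageMetadata for an 800 x 800 NeRF Synthetic image
// (illustrative sketch, not the actual instant-ngp data loader).
TrainingImageMetadata make_synthetic_metadata_sketch(const void* gpu_pixels, float camera_angle_x) {
    TrainingImageMetadata m;
    m.pixels = gpu_pixels;
    m.image_data_type = EImageDataType::Half;   // matches the struct's default
    m.resolution = ivec2(800, 800);

    // Standard Blender/NeRF convention: focal length from the horizontal field of view.
    float focal = 0.5f * 800.0f / tanf(0.5f * camera_angle_x);
    m.focal_length = vec2(focal);               // fx == fy for this dataset
    m.principal_point = vec2(0.5f);             // optical center at the image center
    return m;
}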


2.4 Advance RNG State#

rng.advance(i * N_MAX_RANDOM_SAMPLES_PER_RAY());

2.4.1 default_rng_t (tcnn::pcg32) Struct#

default_rng_t (tcnn::pcg32) is a wrapper around the PCG Random Number Generator. For more details about PCG, please refer to Appendix PCG: Permuted Congruential Generator.

2.4.2 Why advance RNG state?#

In tiny-cuda-nn / instant-ngp, each CUDA thread generates one ray:

const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; // unique per thread
rng.advance(i * N_MAX_RANDOM_SAMPLES_PER_RAY());

This line is not arbitrary; it is a design requirement.

PCG produces a sequence of numbers. A PCG generator is deterministic:

$$x_{n+1} = f(x_n)$$

If all threads start with the same RNG state, then:

| Thread | RNG values |
| --- | --- |
| Thread 0 | 0.83, 0.21, 0.55, … |
| Thread 1 | 0.83, 0.21, 0.55, … |
| Thread 2 | 0.83, 0.21, 0.55, … |
  • every pixel ray gets the same random samples
  • training collapses (all rays identical → no learning)

2.4.3 How advance(k) jumps ahead in the PCG sequence#

The function:

rng.advance(K);

mathematically means:

$$x_{n+K} = f^{(K)}(x_n)$$

It fast-forwards the PCG stream without generating intermediate numbers.

Each thread should get different random numbers, so the RNG state is offset by the thread ID:

That means:

| Thread i | RNG begins at position |
| --- | --- |
| 0 | base + 0 × stride |
| 1 | base + 1 × stride |
| 2 | base + 2 × stride |

Therefore:

  • no collision
  • parallel-safe randomness
  • deterministic reproducibility

2.4.4 Why multiply by N_MAX_RANDOM_SAMPLES_PER_RAY()?#

Because each ray may consume up to that many random numbers, threads are spaced far enough apart in the sequence that their streams never overlap.

If worst case = 64 random samples per ray, then:

| Ray index i | RNG range reserved |
| --- | --- |
| 0 | 0–63 |
| 1 | 64–127 |
| 2 | 128–191 |

Each ray lives in its own section of the RNG sequence.
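
Putting 2.4.2–2.4.4 together, the per-thread pattern looks like the sketch below. It assumes the pcg32-style API used in the code above (advance() and next_float()); the constant 64 stands in for N_MAX_RANDOM_SAMPLES_PER_RAY() and the kernel name is made up for illustration:

// Minimal sketch: each thread takes a private copy of the generator, jumps to its
// own reserved segment of the stream, and then draws values independently.
__global__ void per_thread_rng_sketch(uint32_t n_rays, default_rng_t rng, float* __restrict__ out) {
    const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= n_rays) {
        return;
    }

    // Reserve [i * 64, (i + 1) * 64) of the sequence for this thread, mirroring
    // rng.advance(i * N_MAX_RANDOM_SAMPLES_PER_RAY()) in the training kernel.
    rng.advance(i * 64);

    // The draws below consume at most 64 values, so neighbouring threads never
    // overlap even though they all started from the same seed.
    out[i] = rng.next_float();
}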


2.5 Sample Image Position#

vec2 uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, cdf_res, img);

2.5.1 CUDA Function nerf_random_image_pos_training#

inline __device__ vec2 nerf_random_image_pos_training(
    default_rng_t& rng, const ivec2& resolution, bool snap_to_pixel_centers,
    const float* __restrict__ cdf_x_cond_y, const float* __restrict__ cdf_y,
    const ivec2& cdf_res, uint32_t img, float* __restrict__ pdf = nullptr
) {
    vec2 uv = random_val_2d(rng);

    if (cdf_x_cond_y) {
        uv = sample_cdf_2d(uv, img, cdf_res, cdf_x_cond_y, cdf_y, pdf);
    } else {
        // // Warp-coherent tile
        // uv.x = __shfl_sync(0xFFFFFFFF, uv.x, 0);
        // uv.y = __shfl_sync(0xFFFFFFFF, uv.y, 0);
        // const ivec2 TILE_SIZE = {8, 4};
        // uv = (uv * vec2(resolution - TILE_SIZE) + vec2(tcnn::lane_id() % TILE_SIZE.x, tcnn::lane_id() / threadIdx.x)) / vec2(resolution);

        if (pdf) {
            *pdf = 1.0f;
        }
    }

    if (snap_to_pixel_centers) {
        uv = (vec2(clamp(ivec2(uv * vec2(resolution)), 0, resolution - 1)) + 0.5f) / vec2(resolution);
    }

    return uv;
}
| Parameter | Type | Note |
| --- | --- | --- |
| rng | default_rng_t& | Random number generator reference, mutated on each call |
| resolution | ivec2 | Image width/height used to scale UV coordinates |
| snap_to_pixel_centers | bool | If true, UV is snapped to a pixel center rather than sampled continuously |
| cdf_x_cond_y | const float* (optional) | X-conditioned CDF table for importance sampling; if non-null, enables 2D CDF sampling |
| cdf_y | const float* (optional) | Marginal distribution along the Y axis for CDF sampling |
| cdf_res | ivec2 | Resolution of the CDF grid (width, height) corresponding to cdf_x_cond_y/cdf_y |
| img | uint32_t | Image index; determines which image's CDF to sample from |
| pdf | float* (optional) | Output probability density; written only when CDF sampling is used or a PDF is requested |

2.5.2 Base Version#

__device__ tcnn::vec2 nerf_random_image_pos_training(
    tcnn::pcg32& rng,
    const tcnn::ivec2& resolution,
    const bool snap_to_pixel_centers
) {
    tcnn::vec2 uv = {rng.next_float(), rng.next_float()};
    if (snap_to_pixel_centers) {
        uv = (tcnn::vec2(tcnn::clamp(tcnn::ivec2(uv * tcnn::vec2(resolution)), 0, resolution - 1)) + 0.5f) / tcnn::vec2(resolution);
    }
    return uv;
}

It generates a random UV coordinate inside a training image. UV is normalized to [0,1] × [0,1]. This UV is later turned into a ray shooting into the NeRF scene.
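
For example, with an 800 × 800 image and snap_to_pixel_centers enabled, a sampled uv.x = 0.2003 maps to pixel column ⌊0.2003 × 800⌋ = 160 and is then snapped back to the pixel-center coordinate (160 + 0.5) / 800 = 0.200625.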

2.5.3 The Key Takeaway#

For the NeRF Synthetic dataset, training almost always takes the else branch.

Meaning:

  • UV is uniformly random
  • PDF defaults to 1.0
  • CDF importance sampling is disabled by default

2.6 Get Pixel Index#

size_t pix_idx = pixel_idx(uv, resolution, 0);

2.6.1 CUDA Function pixel_idx#

inline NGP_HOST_DEVICE ivec2 image_pos(const vec2& pos, const ivec2& resolution)
{
    return clamp(ivec2(pos * vec2(resolution)), 0, resolution - 1);
}

inline NGP_HOST_DEVICE uint64_t pixel_idx(const ivec2& px, const ivec2& resolution, uint32_t img)
{
    return px.x + px.y * resolution.x + img * (uint64_t)resolution.x * resolution.y;
}

inline NGP_HOST_DEVICE uint64_t pixel_idx(const vec2& uv, const ivec2& resolution, uint32_t img)
{
    return pixel_idx(image_pos(uv, resolution), resolution, img);
}

2.6.2 Base Version#

inline __device__ uint64_t pixel_idx(const tcnn::vec2& uv, const tcnn::ivec2& resolution, uint32_t img) {
    tcnn::ivec2 px = tcnn::clamp(tcnn::ivec2(uv * tcnn::vec2(resolution)), 0, resolution - 1);
    return px.x + px.y * resolution.x + img * (uint64_t) resolution.x * resolution.y;
}

They map uv (normalized float coordinates) → pixel (x, y) → flat pixel index across the entire dataset.

CUDA Function image_pos()#

inline NGP_HOST_DEVICE ivec2 image_pos(const vec2& pos, const ivec2& resolution)
{
    return clamp(ivec2(pos * vec2(resolution)), 0, resolution - 1);
}

Input

  • pos = uv ∈ [0,1] (normalized image space)
  • resolution = (W,H)

What it does

  1. pos * resolution converts normalized UV → pixel space. Example: (0.2, 0.5) * (800, 800) → (160, 400)

  2. Convert to integer ivec2(...) (drop decimals)

  3. clamp(..., 0, resolution-1) ensures pixel cannot go outside image

Output

A valid pixel coordinate (x,y) inside the image:

0 ≤ x < width
0 ≤ y < height

CUDA Function pixel_idx(px)#

inline NGP_HOST_DEVICE uint64_t pixel_idx(const ivec2& px, const ivec2& resolution, uint32_t img)
{
    return px.x + px.y * resolution.x + img * (uint64_t)resolution.x * resolution.y;
}

Meaning

This converts a pixel coordinate (x, y) plus an image number (img) into a 1D index for flattened dataset storage.

Breakdown:

$$\text{pixel offset in image} = x + y \cdot \text{width}$$
$$\text{image offset} = \text{img} \cdot (\text{width} \cdot \text{height})$$

So total index = index inside image + offset to image block
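
Quick numeric check: for uv = (0.2, 0.5), an 800 × 800 image, and img = 3, image_pos gives px = (160, 400), so pixel_idx returns 160 + 400 · 800 + 3 · 800 · 800 = 160 + 320000 + 1920000 = 2,240,160.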


CUDA Function pixel_idx(uv) — UV version#

inline NGP_HOST_DEVICE uint64_t pixel_idx(const vec2& uv, const ivec2& resolution, uint32_t img)
{
    return pixel_idx(image_pos(uv, resolution), resolution, img);
}

This is just a convenience overload:

Steps internally:

uv → pixel(x,y) using image_pos()
(x,y,img) → 1D index using pixel_idx()

So this lets you write:

pixel_idx(uv, resolution, img);

instead of:

ivec2 px = image_pos(uv, resolution);
pixel_idx(px, resolution, img);

2.7 Check Pixel Validity#

if (read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type).x < 0.0f)
{
    return;
}

Given a pixel coordinate (either a normalized uv or an integer px), read_rgba looks up the GPU image data and returns a vec4(R, G, B, A) in linear RGB space.

2.7.1 CUDA Function read_rgba#

inline NGP_HOST_DEVICE vec4 read_rgba(ivec2 px, const ivec2& resolution, const void* pixels,
                                      EImageDataType image_data_type, uint32_t img = 0)
{
    switch (image_data_type)
    {
        default:
            // This should never happen. Bright red to indicate this.
            return vec4{5.0f, 0.0f, 0.0f, 1.0f};
        case EImageDataType::Byte:
        {
            uint32_t val = ((uint32_t*)pixels)[pixel_idx(px, resolution, img)];
            if (val == 0x00FF00FF)
            {
                return vec4(-1.0f);
            }

            vec4 result = rgba32_to_rgba(val);
            result.rgb() = srgb_to_linear(result.rgb()) * result.a;
            return result;
        }
        case EImageDataType::Half:
        {
            __half val[4];
            *(uint64_t*)&val[0] = ((uint64_t*)pixels)[pixel_idx(px, resolution, img)];
            return vec4{(float)val[0], (float)val[1], (float)val[2], (float)val[3]};
        }
        case EImageDataType::Float:
            return ((vec4*)pixels)[pixel_idx(px, resolution, img)];
    }
}

inline NGP_HOST_DEVICE vec4 read_rgba(vec2 pos, const ivec2& resolution, const void* pixels,
                                      EImageDataType image_data_type, uint32_t img = 0)
{
    return read_rgba(image_pos(pos, resolution), resolution, pixels, image_data_type, img);
}

It supports three image formats:

| Format | Stored As | Per Channel | Explanation |
| --- | --- | --- | --- |
| Byte | uint32_t | 8-bit RGBA | sRGB → linear conversion with alpha premultiply |
| Half | __half[4] packed in a uint64_t | 16-bit float | No conversion, read directly |
| Float | vec4* | 32-bit float | Pure float, no transformation |

Case 1: Byte — 8-bit texture stored as uint32_t#

uint32_t val = ((uint32_t*)pixels)[pixel_idx(px, resolution, img)];
if (val == 0x00FF00FF) {
    return vec4(-1.0f);
}
vec4 result = rgba32_to_rgba(val);
result.rgb() = srgb_to_linear(result.rgb()) * result.a;
return result;

Explanation:

  1. Load 4 × 8-bit channels packed in one uint32_t
  2. If the value is 0x00FF00FF, treat the pixel as masked → return vec4(-1), meaning an invalid pixel
  3. Unpack the packed 32-bit value into a float vec4 in [0, 1] (rgba32_to_rgba)
  4. Convert sRGB → linear and premultiply by alpha

This format comes from NeRF synthetic datasets.

Case 2: Half — 16-bit floating point (stored compact)#

__half val[4];
*(uint64_t*)&val[0] = ((uint64_t*)pixels)[pixel_idx(...)]
return vec4{(float)val[0], ... }

Breakdown:

| Stored as | Read as | Why? |
| --- | --- | --- |
| 4 × half (16-bit each) = 8 bytes | single uint64_t load | faster & coalesced |
| Reinterpreted as __half[4] | converted to float | for computation |

Used for lighter GPU memory footprint with HDR capability.

Case 3: Float — direct vec4#

return ((vec4*)pixels)[pixel_idx(px, resolution, img)];

Fastest — no conversion. Used when training directly with float images.

Second overload — UV input#

inline vec4 read_rgba(vec2 pos, ...)
{
return read_rgba(image_pos(pos, resolution), ...);
}

Meaning:

  1. Convert $(u, v) \in [0,1)$ → pixel coordinate via image_pos() (uv * resolution, clamped to the image bounds)
  2. Call integer version

2.7.2 Base Version#

inline __device__ float srgb_to_linear(float x) {
    return (x <= 0.04045f) ? (x * (1.f / 12.92f)) : powf((x + 0.055f) * (1.f / 1.055f), 2.4f);
}

inline __device__ tcnn::vec4 read_rgba(
    const tcnn::vec2& uv,
    const tcnn::ivec2& resolution,
    const void* pixels,
    const uint32_t img = 0 // optional, default works same as before
) {
    // ---------------------------------------------
    // 1. Get pixel address from uv + resolution
    // ---------------------------------------------
    const uint64_t idx = pixel_idx(uv, resolution, img);
    const uint32_t rgba = static_cast<const uint32_t*>(pixels)[idx]; // packed RGBA, R in the low byte

    // ---------------------------------------------
    // 2. Masked pixel → skip (-1 = INVALID)
    // ---------------------------------------------
    if (rgba == 0x00FF00FFu) return {-1.f, -1.f, -1.f, -1.f};

    // ---------------------------------------------
    // 3. Extract channels [0–255] → float [0–1]
    // ---------------------------------------------
    const float r = static_cast<float>((rgba >> 0) & 0xFF) * (1.f / 255.f);
    const float g = static_cast<float>((rgba >> 8) & 0xFF) * (1.f / 255.f);
    const float b = static_cast<float>((rgba >> 16) & 0xFF) * (1.f / 255.f);
    const float a = static_cast<float>((rgba >> 24) & 0xFF) * (1.f / 255.f);

    return {srgb_to_linear(r) * a,
            srgb_to_linear(g) * a,
            srgb_to_linear(b) * a,
            a};
}
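
Note that the training kernel only inspects the .x component of the returned value (read_rgba(...).x < 0.0f), so returning -1 in every channel is more than enough to signal a masked pixel.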

2.8 Determine Maximum Mip Level for Training#

float max_level = max_level_rand_training ? (random_val(rng) * 2.0f) : 1.0f; // Multiply by 2 to ensure 50% of training is at max level

NeRF Synthetic training effectively always uses max_level = 1.0f, because max_level_rand_training is false by default.

2.8.1 Base Version#

float max_level = 1.0f; // default

As mentioned, NeRF Synthetic training does not use random mip levels. It’s safe to assume max_level = 1.0f always.

2.9 Get Transform with Rolling Shutter and Motion Blur#

float motionblur_time = random_val(rng);
...
const mat4x3 xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, uv, motionblur_time);

Samples a random time in [0,1] for motion blur simulation during training, then computes the camera-to-world transform at that time, accounting for rolling shutter effects.

2.9.1 Why Motion Blur?#

In Instant-NGP, motionblur_time controls where along the temporal exposure of a rolling-shutter (or moving) camera the sample is taken. Think of it as simulating a camera whose shutter isn't instantaneous: different rays observe the world at slightly different times during the frame capture.
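
To make the idea concrete, here is a hedged sketch of what a time-dependent transform could look like. It is a paraphrase, not the actual get_xform_given_rolling_shutter implementation: the exact meaning of the rolling_shutter components and the interpolation scheme (the real code interpolates poses more carefully than a plain lerp) should be checked against the instant-ngp source, and the function name below is made up:

// Sketch (assumed semantics, not the verbatim implementation): derive a per-pixel
// capture time from the rolling-shutter parameters and the sampled motion-blur
// time, then blend between the start and end camera poses at that time.
__device__ mat4x3 xform_at_time_sketch(const TrainingXForm& xform, const vec4& rolling_shutter,
                                       const vec2& uv, float motionblur_time) {
    // Hypothetical convention: x = constant offset, y/z = per-uv shutter slope,
    // w = exposure duration scaling the random motion-blur time.
    float t = rolling_shutter.x + rolling_shutter.y * uv.x
            + rolling_shutter.z * uv.y + rolling_shutter.w * motionblur_time;

    // Plain per-column lerp between the two bracketing poses, for brevity.
    mat4x3 out;
    for (int c = 0; c < 4; ++c) {
        out[c] = xform.start[c] * (1.0f - t) + xform.end[c] * t;
    }
    return out;
}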

2.10 Read or Compute the Unnormalized Ray#

Ray ray_unnormalized;
const Ray* rays_in_unnormalized = metadata[img].rays;
if (rays_in_unnormalized)
{
    // Rays have been explicitly supplied. Read them.
    ray_unnormalized = rays_in_unnormalized[pix_idx];

    /* DEBUG - compare the stored rays to the computed ones
       (commented-out printf block; see the full kernel listing in Section 2) */
}
else
{
    ray_unnormalized = uv_to_ray(0, uv, resolution, focal_length, xform, principal_point, vec3(0.0f), 0.0f,
                                 1.0f, 0.0f, {}, {}, lens, distortion);
    if (!ray_unnormalized.is_valid())
    {
        ray_unnormalized = {xform[3], xform[2]};
    }
}
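
The else branch builds the ray from the camera model. The full uv_to_ray handles lens distortion, depth of field, and the distortion lookup table; as a mental model only (an assumption-laden sketch with a made-up name, not the real function), a plain pinhole version would look roughly like this:

// Minimal pinhole sketch (hypothetical helper; ignores lens distortion, depth of
// field, and the distortion LUT that the real uv_to_ray supports).
__device__ Ray pinhole_uv_to_ray_sketch(const vec2& uv, const ivec2& resolution,
                                        const vec2& focal_length, const vec2& principal_point,
                                        const mat4x3& xform) {
    // Camera-space direction: offset from the principal point, scaled by
    // resolution / focal length, with z = 1 along the view axis.
    vec3 dir = {
        (uv.x - principal_point.x) * (float)resolution.x / focal_length.x,
        (uv.y - principal_point.y) * (float)resolution.y / focal_length.y,
        1.0f,
    };

    // Rotate into world space using the 3x3 part of the camera-to-world
    // transform (columns 0..2); column 3 is the camera origin.
    dir = xform[0] * dir.x + xform[1] * dir.y + xform[2] * dir.z;
    return {xform[3], dir}; // unnormalized, like ray_unnormalized in the kernel
}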
