这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_gaussian_splatting"
description = "bevy gaussian splatting render pipeline plugin"
version = "5.2.0"
version = "5.2.1"
edition = "2024"
rust-version = "1.85.0"
authors = ["mosure <mitchell@mosure.me>"]
Expand Down
33 changes: 26 additions & 7 deletions src/render/gaussian.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,24 @@
#endif


// Re-express a world-space direction in the local frame of `transform`.
//
// The first three columns of `transform` (its linear part) are each normalized
// to unit length. The world direction is then projected onto those unit axes,
// and the three projections are renormalized. Because the axes are normalized
// before projecting, any scale in the transform is discarded and only its
// orientation contributes. The projection equals the inverse rotation applied
// to the direction only when the axes are mutually orthogonal, e.g. a
// translate/rotate/scale transform with no shear.
//
// NOTE(review): a zero-length input direction, or a zero-length transform
// axis, makes `normalize` operate on a zero vector. The WGSL result for that
// is not a well-defined unit vector, so callers presumably never pass one.
// Confirm against the call sites.
fn world_to_local_direction(ray_direction_world: vec3<f32>, transform: mat4x4<f32>) -> vec3<f32> {
    // Unit-length local axes, read straight from the matrix columns.
    let axis_x = normalize(transform[0].xyz);
    let axis_y = normalize(transform[1].xyz);
    let axis_z = normalize(transform[2].xyz);

    // Component of the world direction along each local axis.
    let along_x = dot(axis_x, ray_direction_world);
    let along_y = dot(axis_y, ray_direction_world);
    let along_z = dot(axis_z, ray_direction_world);

    return normalize(vec3<f32>(along_x, along_y, along_z));
}
@vertex
fn vs_points(
@builtin(instance_index) instance_index: u32,
Expand Down Expand Up @@ -300,12 +318,13 @@ fn vs_points(

// TODO: RASTERIZE_ACCELERATION
#ifdef RASTERIZE_CLASSIFICATION
let ray_direction = normalize(transformed_position - view.world_position);
let ray_direction_world = normalize(transformed_position - view.world_position);
let ray_direction_local = world_to_local_direction(ray_direction_world, gaussian_uniforms.transform);

#ifdef GAUSSIAN_3D_STRUCTURE
rgb = get_color(splat_index, ray_direction);
rgb = get_color(splat_index, ray_direction_local);
#else ifdef GAUSSIAN_4D
rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction);
rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction_local);
#endif

rgb = class_to_rgb(
Expand Down Expand Up @@ -391,13 +410,13 @@ fn vs_points(
rgb = base_color * scaled_mag;
#else ifdef RASTERIZE_COLOR
// TODO: verify color benefit for ray_direction computed at quad vertices instead of gaussian center (same as current complexity)
// TODO: why doesn't Transform rotation change SH color?
let ray_direction = normalize(transformed_position - view.world_position);
let ray_direction_world = normalize(transformed_position - view.world_position);
let ray_direction_local = world_to_local_direction(ray_direction_world, gaussian_uniforms.transform);

#ifdef GAUSSIAN_3D_STRUCTURE
rgb = get_color(splat_index, ray_direction);
rgb = get_color(splat_index, ray_direction_local);
#else ifdef GAUSSIAN_4D
rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction);
rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction_local);
#endif
#endif

Expand Down
1 change: 1 addition & 0 deletions src/sort/radix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ where
settings,
) in gaussian_clouds.iter() {
if settings.sort_mode != SortMode::Radix {
commands.entity(entity).remove::<RadixBindGroup>();
continue;
}

Expand Down
19 changes: 11 additions & 8 deletions src/sort/radix.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,23 @@ fn radix_sort_c(
workgroupBarrier();

// --- Step 3: Determine deterministic global base for each digit ---
if (tid < #{RADIX_BASE}u) {
let count = local_digit_counts[tid];
let tile_count = max((gaussian_uniforms.count + tile_size - 1u) / tile_size, 1u);
let expected = sorting_pass_index * tile_count + workgroup_id.y;
let tile_count = max((gaussian_uniforms.count + tile_size - 1u) / tile_size, 1u);
let expected_ticket = sorting_pass_index * tile_count + workgroup_id.y;

// Acquire a per-tile ticket so we only serialize once per tile instead of once per digit.
if (tid == 0u) {
loop {
let head = atomicLoad(&sorting.digit_tile_head[tid]);
if (head == expected) {
let exchange = atomicCompareExchangeWeak(&sorting.digit_tile_head[tid], expected, expected + 1u);
let head = atomicLoad(&sorting.assignment_counter);
if (head == expected_ticket) {
let exchange = atomicCompareExchangeWeak(&sorting.assignment_counter, expected_ticket, expected_ticket + 1u);
if (exchange.exchanged) { break; }
}
}
}
workgroupBarrier();

if (tid < #{RADIX_BASE}u) {
let count = local_digit_counts[tid];
let base = atomicAdd(&sorting.digit_histogram[sorting_pass_index][tid], count);
digit_global_base_ws[tid] = base;
}
Expand Down Expand Up @@ -238,7 +242,6 @@ fn radix_sort_c(
}
}
}

if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && tid == 0u) {
atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count);
}
Expand Down
Loading