Patch: bevy_gaussian_splatting 5.2.0 -> 5.2.1

Text above the first "diff --git" line is commentary for reviewers, not patch content.

Summary
  * Cargo.toml / Cargo.lock: version bump to 5.2.1.
  * src/render/gaussian.wgsl: add world_to_local_direction() and pass the
    object-local ray direction to get_color() in the RASTERIZE_CLASSIFICATION
    and RASTERIZE_COLOR paths.
  * src/sort/radix.rs: remove a component from entities whose sort mode is not
    SortMode::Radix before skipping them.
  * src/sort/radix.wgsl: replace the per-digit digit_tile_head wait with one
    per-tile ticket on sorting.assignment_counter, released only after the
    tile's digit_histogram adds.

Review notes
  * Line breaks, indentation and blank context lines were reconstructed (blank
    lines are placed so that each hunk matches its @@ line counts); verify
    against the tree before applying.
  * NOTE(review): `remove::()` in radix.rs and `authors = ["mosure "]` in
    Cargo.toml look like they lost angle-bracketed text; restore it from the
    original change.
  * The "index" lines still carry the blob ids of the original change; the
    .wgsl post-images in this revision differ from it.

diff --git a/Cargo.lock b/Cargo.lock
index ca88699e..ea254551 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1341,7 +1341,7 @@ dependencies = [
 
 [[package]]
 name = "bevy_gaussian_splatting"
-version = "5.2.0"
+version = "5.2.1"
 dependencies = [
  "base64 0.22.1",
  "bevy 0.16.1",
diff --git a/Cargo.toml b/Cargo.toml
index d9dfb7d7..197ad135 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "bevy_gaussian_splatting"
 description = "bevy gaussian splatting render pipeline plugin"
-version = "5.2.0"
+version = "5.2.1"
 edition = "2024"
 rust-version = "1.85.0"
 authors = ["mosure "]
diff --git a/src/render/gaussian.wgsl b/src/render/gaussian.wgsl
index dee3ca0e..e1c05d40 100644
--- a/src/render/gaussian.wgsl
+++ b/src/render/gaussian.wgsl
@@ -165,6 +165,34 @@
 #endif
 
 
+// Rotates a world-space direction into the object-local frame of `transform`.
+//
+// Only the upper-left 3x3 block of `transform` is used. Each basis column is normalized
+// before the dot products, so the magnitude of any per-axis scale drops out; the result
+// is the inverse rotation applied to the input when the three columns are mutually
+// orthogonal.
+// NOTE(review): a zero-length column (zero scale on an axis) has no defined normalization;
+// confirm callers never pass such a transform.
+//
+// The result is re-normalized before it is returned.
+fn world_to_local_direction(ray_direction_world: vec3<f32>, transform: mat4x4<f32>) -> vec3<f32> {
+    let basis = mat3x3<f32>(
+        transform[0].xyz,
+        transform[1].xyz,
+        transform[2].xyz,
+    );
+    let basis_x = normalize(basis[0]);
+    let basis_y = normalize(basis[1]);
+    let basis_z = normalize(basis[2]);
+
+    let local = vec3<f32>(
+        dot(basis_x, ray_direction_world),
+        dot(basis_y, ray_direction_world),
+        dot(basis_z, ray_direction_world),
+    );
+
+    return normalize(local);
+}
 @vertex
 fn vs_points(
     @builtin(instance_index) instance_index: u32,
@@ -300,12 +328,13 @@ fn vs_points(
 
     // TODO: RASTERIZE_ACCELERATION
 #ifdef RASTERIZE_CLASSIFICATION
-    let ray_direction = normalize(transformed_position - view.world_position);
+    let ray_direction_world = normalize(transformed_position - view.world_position);
+    let ray_direction_local = world_to_local_direction(ray_direction_world, gaussian_uniforms.transform);
 
     #ifdef GAUSSIAN_3D_STRUCTURE
-        rgb = get_color(splat_index, ray_direction);
+        rgb = get_color(splat_index, ray_direction_local);
     #else ifdef GAUSSIAN_4D
-        rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction);
+        rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction_local);
     #endif
 
     rgb = class_to_rgb(
@@ -391,13 +420,13 @@ fn vs_points(
     rgb = base_color * scaled_mag;
 #else ifdef RASTERIZE_COLOR
     // TODO: verify color benefit for ray_direction computed at quad verticies instead of gaussian center (same as current complexity)
-    // TODO: why doesn't Transform rotation change SH color?
-    let ray_direction = normalize(transformed_position - view.world_position);
+    let ray_direction_world = normalize(transformed_position - view.world_position);
+    let ray_direction_local = world_to_local_direction(ray_direction_world, gaussian_uniforms.transform);
 
     #ifdef GAUSSIAN_3D_STRUCTURE
-        rgb = get_color(splat_index, ray_direction);
+        rgb = get_color(splat_index, ray_direction_local);
     #else ifdef GAUSSIAN_4D
-        rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction);
+        rgb = get_color(splat_index, gaussian_4d.dir_t, ray_direction_local);
     #endif
 #endif
 
diff --git a/src/sort/radix.rs b/src/sort/radix.rs
index 20b2f437..5774cc7f 100644
--- a/src/sort/radix.rs
+++ b/src/sort/radix.rs
@@ -425,6 +425,7 @@ where
         settings,
     ) in gaussian_clouds.iter() {
         if settings.sort_mode != SortMode::Radix {
+            commands.entity(entity).remove::();
             continue;
         }
 
diff --git a/src/sort/radix.wgsl b/src/sort/radix.wgsl
index f0dca0bd..60dbc08f 100644
--- a/src/sort/radix.wgsl
+++ b/src/sort/radix.wgsl
@@ -195,19 +195,32 @@ fn radix_sort_c(
     workgroupBarrier();
 
     // --- Step 3: Determine deterministic global base for each digit ---
-    if (tid < #{RADIX_BASE}u) {
-        let count = local_digit_counts[tid];
-        let tile_count = max((gaussian_uniforms.count + tile_size - 1u) / tile_size, 1u);
-        let expected = sorting_pass_index * tile_count + workgroup_id.y;
+    let tile_count = max((gaussian_uniforms.count + tile_size - 1u) / tile_size, 1u);
+    let expected_ticket = sorting_pass_index * tile_count + workgroup_id.y;
+    // Tiles take turns on a single ticket so we only serialize once per tile instead of once
+    // per digit. Thread 0 spins until the shared counter reaches this tile's ticket; the
+    // ticket is passed on only after this tile's histogram adds below, so bases are handed
+    // out in ticket order.
+    if (tid == 0u) {
         loop {
-            let head = atomicLoad(&sorting.digit_tile_head[tid]);
-            if (head == expected) {
-                let exchange = atomicCompareExchangeWeak(&sorting.digit_tile_head[tid], expected, expected + 1u);
-                if (exchange.exchanged) {
-                    break;
-                }
+            let head = atomicLoad(&sorting.assignment_counter);
+            if (head == expected_ticket) {
+                break;
             }
         }
+    }
+    workgroupBarrier();
+    if (tid < #{RADIX_BASE}u) {
+        let count = local_digit_counts[tid];
         let base = atomicAdd(&sorting.digit_histogram[sorting_pass_index][tid], count);
         digit_global_base_ws[tid] = base;
+    }
+    // Hand the ticket to the next tile only after every digit's atomicAdd above has been
+    // issued; bumping the counter at acquire time would let the next tile add first.
+    // NOTE(review): storageBarrier() orders storage accesses within this workgroup only;
+    // cross-workgroup visibility of the adds relies on the implementation - confirm on
+    // target backends.
+    storageBarrier();
+    if (tid == 0u) {
+        atomicStore(&sorting.assignment_counter, expected_ticket + 1u);
     }
@@ -238,7 +251,6 @@ fn radix_sort_c(
             }
         }
     }
-
     if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && tid == 0u) {
         atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count);
     }