From 931d4a60c5304f78f09431bedaa28db3329f6a5f Mon Sep 17 00:00:00 2001
From: teodosin
Date: Tue, 9 Sep 2025 21:13:49 +0300
Subject: [PATCH 1/3] fix binding crash and memory leak

---
 src/render/mod.rs   |  74 +++++++++++++++++++
 src/sort/radix.rs   | 168 +++++++++++++++++++++++++-------------------
 src/sort/radix.wgsl |  36 ++++++++--
 3 files changed, 202 insertions(+), 76 deletions(-)

diff --git a/src/render/mod.rs b/src/render/mod.rs
index 90b8c830..fab8c109 100644
--- a/src/render/mod.rs
+++ b/src/render/mod.rs
@@ -147,6 +147,7 @@ where
             (
                 queue_gaussian_bind_group::<R>.in_set(RenderSet::PrepareBindGroups),
                 queue_gaussian_view_bind_groups::<R>.in_set(RenderSet::PrepareBindGroups),
+                queue_gaussian_compute_view_bind_groups::<R>.in_set(RenderSet::PrepareBindGroups),
                 queue_gaussians::<R>.in_set(RenderSet::Queue),
             ),
         );
@@ -1050,6 +1051,11 @@ pub struct GaussianViewBindGroup {
     pub value: BindGroup,
 }
 
+#[derive(Component)]
+pub struct GaussianComputeViewBindGroup {
+    pub value: BindGroup,
+}
+
 // TODO: move to gaussian camera module
 // TODO: remove cloud pipeline dependency by separating view layout
 #[allow(clippy::too_many_arguments)]
@@ -1124,6 +1130,74 @@ pub fn queue_gaussian_view_bind_groups(
     }
 }
 
+// Prepare the compute view bind group using the compute_view_layout (for compute pipelines)
+pub fn queue_gaussian_compute_view_bind_groups<R>(
+    mut commands: Commands,
+    render_device: Res<RenderDevice>,
+    gaussian_cloud_pipeline: Res<CloudPipeline<R>>,
+    view_uniforms: Res<ViewUniforms>,
+    previous_view_uniforms: Res<PreviousViewUniforms>,
+    views: Query<
+        (
+            Entity,
+            &ExtractedView,
+            Option<&PreviousViewData>,
+        ),
+        With<GaussianCamera>,
+    >,
+    visibility_ranges: Res<RenderVisibilityRanges>,
+    globals_buffer: Res<GlobalsBuffer>,
+)
+where
+    R: PlanarSync,
+    R::GpuPlanarType: GpuPlanarStorage,
+{
+    if let (
+        Some(view_binding),
+        Some(previous_view_binding),
+        Some(globals),
+        Some(visibility_ranges_buffer),
+    ) = (
+        view_uniforms.uniforms.binding(),
+        previous_view_uniforms.uniforms.binding(),
+        globals_buffer.buffer.binding(),
+        visibility_ranges.buffer().buffer(),
+    ) {
+        for (entity, _extracted_view, _maybe_previous_view) in &views {
+            let layout = &gaussian_cloud_pipeline.compute_view_layout;
+
+            let entries = vec![
+                BindGroupEntry {
+                    binding: 0,
+                    resource: view_binding.clone(),
+                },
+                BindGroupEntry {
+                    binding: 1,
+                    resource: globals.clone(),
+                },
+                BindGroupEntry {
+                    binding: 2,
+                    resource: previous_view_binding.clone(),
+                },
+                BindGroupEntry {
+                    binding: 14,
+                    resource: visibility_ranges_buffer.as_entire_binding(),
+                },
+            ];
+
+            let view_bind_group = render_device.create_bind_group(
+                "gaussian_compute_view_bind_group",
+                layout,
+                &entries,
+            );
+
+            commands
+                .entity(entity)
+                .insert(GaussianComputeViewBindGroup { value: view_bind_group });
+        }
+    }
+}
+
 pub struct SetViewBindGroup<const I: usize>;
 impl<P: PhaseItem, const I: usize> RenderCommand<P>

for SetViewBindGroup { type Param = (); diff --git a/src/sort/radix.rs b/src/sort/radix.rs index 738e7cf0..b49aad78 100644 --- a/src/sort/radix.rs +++ b/src/sort/radix.rs @@ -63,7 +63,6 @@ use crate::{ CloudPipeline, CloudPipelineKey, GaussianUniformBindGroups, - GaussianViewBindGroup, ShaderDefines, shader_defs, }, @@ -76,7 +75,6 @@ use crate::{ }, }; - assert_cfg!( not(all( feature = "sort_radix", @@ -253,7 +251,7 @@ fn update_sort_buffers( #[derive(Resource)] pub struct RadixSortPipeline { pub radix_sort_layout: BindGroupLayout, - pub radix_sort_pipelines: [CachedComputePipelineId; 3], + pub radix_sort_pipelines: [CachedComputePipelineId; 4], phantom: std::marker::PhantomData, } @@ -343,6 +341,16 @@ impl FromWorld for RadixSortPipeline { let shader_defs = shader_defs(CloudPipelineKey::default()); let pipeline_cache = render_world.resource::(); + let radix_reset = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("radix_sort_reset".into()), + layout: sorting_layout.clone(), + push_constant_ranges: vec![], + shader: RADIX_SHADER_HANDLE, + shader_defs: shader_defs.clone(), + entry_point: "radix_reset".into(), + zero_initialize_workgroup_memory: true, + }); + let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { label: Some("radix_sort_a".into()), layout: sorting_layout.clone(), @@ -375,11 +383,7 @@ impl FromWorld for RadixSortPipeline { RadixSortPipeline { radix_sort_layout, - radix_sort_pipelines: [ - radix_sort_a, - radix_sort_b, - radix_sort_c, - ], + radix_sort_pipelines: [radix_reset, radix_sort_a, radix_sort_b, radix_sort_c], phantom: std::marker::PhantomData, } } @@ -389,7 +393,9 @@ impl FromWorld for RadixSortPipeline { #[derive(Component)] pub struct RadixBindGroup { - pub radix_sort_bind_groups: [BindGroup; 4], + // For each digit pass idx in 0..RADIX_DIGIT_PLACES, we create 2 bind groups (parity 0/1): + // index = pass_idx * 2 + parity (parity 0: input=sorted_entries, output=entry_buffer_b; parity 1: input=entry_buffer_b, output=sorted_entries) + pub radix_sort_bind_groups: [BindGroup; 8], } #[allow(clippy::too_many_arguments)] @@ -477,53 +483,57 @@ where }), }; - let radix_sort_bind_groups: [BindGroup; 4] = (0..4) - .map(|idx| { - render_device.create_bind_group( - format!("radix_sort_bind_group {idx}").as_str(), - &radix_pipeline.radix_sort_layout, - &[ - BindGroupEntry { - binding: 0, - resource: BindingResource::Buffer(BufferBinding { - buffer: &sorting_assets.sorting_pass_buffers[idx], - offset: 0, - size: BufferSize::new(std::mem::size_of::() as u64), - }), - }, - sorting_global_entry.clone(), - sorting_status_counters_entry.clone(), - draw_indirect_entry.clone(), - BindGroupEntry { - binding: 4, - resource: BindingResource::Buffer(BufferBinding { - buffer: if idx % 2 == 0 { - &sorted_entries.sorted_entry_buffer - } else { - &sorting_assets.entry_buffer_b - }, - offset: 0, - size: BufferSize::new((cloud.len() * std::mem::size_of::()) as u64), - }), - }, - BindGroupEntry { - binding: 5, - resource: BindingResource::Buffer(BufferBinding { - buffer: if idx % 2 == 0 { - &sorting_assets.entry_buffer_b - } else { - &sorted_entries.sorted_entry_buffer - }, - offset: 0, - size: BufferSize::new((cloud.len() * std::mem::size_of::()) as u64), - }), - }, - ], - ) - }) - .collect::>() - .try_into() - .unwrap(); + let radix_sort_bind_groups: [BindGroup; 8] = { + let mut groups: Vec = Vec::with_capacity(8); + for pass_idx in 0..4 { + for parity in 0..=1 { + let (input_buf, output_buf) = if parity == 0 { + 
(&sorted_entries.sorted_entry_buffer, &sorting_assets.entry_buffer_b) + } else { + (&sorting_assets.entry_buffer_b, &sorted_entries.sorted_entry_buffer) + }; + + let group = render_device.create_bind_group( + format!("radix_sort_bind_group pass={} parity={}", pass_idx, parity).as_str(), + &radix_pipeline.radix_sort_layout, + &[ + // sorting_pass_index (u32) == pass_idx regardless of parity + BindGroupEntry { + binding: 0, + resource: BindingResource::Buffer(BufferBinding { + buffer: &sorting_assets.sorting_pass_buffers[pass_idx], + offset: 0, + size: BufferSize::new(std::mem::size_of::() as u64), + }), + }, + sorting_global_entry.clone(), + sorting_status_counters_entry.clone(), + draw_indirect_entry.clone(), + // input_entries + BindGroupEntry { + binding: 4, + resource: BindingResource::Buffer(BufferBinding { + buffer: input_buf, + offset: 0, + size: BufferSize::new((cloud.len() * std::mem::size_of::()) as u64), + }), + }, + // output_entries + BindGroupEntry { + binding: 5, + resource: BindingResource::Buffer(BufferBinding { + buffer: output_buf, + offset: 0, + size: BufferSize::new((cloud.len() * std::mem::size_of::()) as u64), + }), + }, + ], + ); + groups.push(group); + } + } + groups.try_into().unwrap() + }; commands.entity(entity).insert(RadixBindGroup { radix_sort_bind_groups, @@ -541,7 +551,7 @@ pub struct RadixSortNode { initialized: bool, view_bind_group: QueryState<( &'static GaussianCamera, - &'static GaussianViewBindGroup, + &'static crate::render::GaussianComputeViewBindGroup, &'static ViewUniformOffset, &'static PreviousViewUniformOffset, )>, @@ -650,6 +660,22 @@ where { let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + // Reset per-frame counters/histograms + let radix_reset = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap(); + pass.set_pipeline(radix_reset); + pass.set_bind_group( + 0, + &view_bind_group.value, + &[ + view_uniform_offset.offset, + previous_view_uniform_offset.offset, + ], + ); + pass.set_bind_group(1, gaussian_uniforms.base_bind_group.as_ref().unwrap(), &[0]); + pass.set_bind_group(2, &cloud_bind_group.bind_group, &[]); + pass.set_bind_group(3, &radix_bind_group.radix_sort_bind_groups[0], &[]); + pass.dispatch_workgroups(1, 1, 1); + pass.set_bind_group( 0, &view_bind_group.value, @@ -674,14 +700,14 @@ where &[], ); - let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap(); + let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap(); pass.set_pipeline(radix_sort_a); let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a; pass.dispatch_workgroups((cloud.len() as u32).div_ceil(workgroup_entries_a), 1, 1); - let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap(); + let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap(); pass.set_pipeline(radix_sort_b); pass.dispatch_workgroups(1, radix_digit_places, 1); @@ -699,29 +725,29 @@ where let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); - let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap(); + let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[3]).unwrap(); pass.set_pipeline(radix_sort_c); + // Set common bind groups for view/uniforms and cloud storage pass.set_bind_group( 0, &view_bind_group.value, - &[view_uniform_offset.offset], + &[ + 
                    view_uniform_offset.offset,
+                    previous_view_uniform_offset.offset,
+                ],
             );
             pass.set_bind_group(
                 1,
                 gaussian_uniforms.base_bind_group.as_ref().unwrap(),
                 &[0],
             );
-            pass.set_bind_group(
-                2,
-                &cloud_bind_group.bind_group,
-                &[]
-            );
-            pass.set_bind_group(
-                3,
-                &radix_bind_group.radix_sort_bind_groups[pass_idx as usize],
-                &[],
-            );
+            pass.set_bind_group(2, &cloud_bind_group.bind_group, &[]);
+
+            // For pass C, choose bind group based on digit place and parity
+            let parity = ((pass_idx + 1) % 2) as usize;
+            let bg_index = (pass_idx as usize) * 2 + parity;
+            pass.set_bind_group(3, &radix_bind_group.radix_sort_bind_groups[bg_index], &[]);
 
             let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
             pass.dispatch_workgroups(1, (cloud.len() as u32).div_ceil(workgroup_entries_c), 1);
diff --git a/src/sort/radix.wgsl b/src/sort/radix.wgsl
index cb62714f..b0f1be0b 100644
--- a/src/sort/radix.wgsl
+++ b/src/sort/radix.wgsl
@@ -55,6 +55,11 @@ fn radix_sort_a(
     @builtin(local_invocation_id) gl_LocalInvocationID: vec3<u32>,
     @builtin(global_invocation_id) gl_GlobalInvocationID: vec3<u32>,
 ) {
+    if (gl_LocalInvocationID.x == 0u && gl_LocalInvocationID.y == 0u && gl_GlobalInvocationID.x == 0u) {
+        // Initialize draw counts early so the draw call doesn't get zeroed if later passes stall
+        draw_indirect.vertex_count = 4u;
+        atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count);
+    }
     sorting_shared_a.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x] = 0u;
     workgroupBarrier();
 
@@ -118,10 +123,29 @@ struct SortingSharedC {
 }
 var<workgroup> sorting_shared_c: SortingSharedC;
+// Reset pass to clear per-frame counters and histograms
+@compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES})
+fn radix_reset(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+){
+    let b = local_id.x;
+    let p = local_id.y;
+
+    atomicStore(&sorting.digit_histogram[p][b], 0u);
+    atomicStore(&status_counters[p][b], 0u);
+
+    if (global_id.x == 0u && global_id.y == 0u) {
+        atomicStore(&sorting.assignment_counter, 0u);
+        draw_indirect.instance_count = 0u;
+    }
+}
+
 const NUM_BANKS: u32 = 16u;
 const LOG_NUM_BANKS: u32 = 4u;
 fn conflict_free_offset(n: u32) -> u32 {
-    return 0u;//n >> NUM_BANKS + n >> (2u * LOG_NUM_BANKS);
+    // Simple bank-conflict padding to reduce contention
+    return n >> LOG_NUM_BANKS;
 }
 
 fn exclusive_scan(local_invocation_index: u32, value: u32) -> u32 {
@@ -185,7 +209,7 @@ fn radix_sort_c(
     // TODO: Specialize end shader
     if(gl_LocalInvocationID.x == 0u && assignment * #{WORKGROUP_ENTRIES_C}u + #{WORKGROUP_ENTRIES_C}u >= gaussian_uniforms.count) {
         // Last workgroup resets the assignment number for the next pass
-        sorting.assignment_counter = 0u;
+        atomicStore(&sorting.assignment_counter, 0u);
     }
 
     // Load keys from global memory into registers and rank them
@@ -224,9 +248,11 @@ fn radix_sort_c(
         }
     }
     atomicStore(&status_counters[assignment][gl_LocalInvocationID.x], 0x80000000u | (global_digit_count + local_digit_count));
-    if(sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && gl_LocalInvocationID.x == #{WORKGROUP_INVOCATIONS_C}u - 2u && global_entry_offset + #{WORKGROUP_ENTRIES_C}u >= gaussian_uniforms.count) {
+    // On the final digit pass, set indirect draw counts once (robust gate)
+    if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && gl_LocalInvocationID.x == 0u) {
         draw_indirect.vertex_count = 4u;
-        draw_indirect.instance_count = global_digit_count + local_digit_count;
+        // Use the total gaussian count to avoid edge-dependent undercounting
+        atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count);
     }
 
     // Scatter keys inside shared memory
@@ -262,6 +288,6 @@ fn radix_sort_c(
     for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) {
         let value = sorting_shared_c.entries[#{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x];
         let digit = keys[entry_index];
-        output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][1] = value;
+        output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].value = value;
     }
 }

From 16da36959ecbdafd242e17d8ca312e40d2413f71 Mon Sep 17 00:00:00 2001
From: teodosin
Date: Wed, 10 Sep 2025 07:59:58 +0300
Subject: [PATCH 2/3] still unstable when moving camera

---
 src/sort/radix.wgsl | 248 +++++++++++++++++++++++++++++---------------
 1 file changed, 163 insertions(+), 85 deletions(-)

diff --git a/src/sort/radix.wgsl b/src/sort/radix.wgsl
index b0f1be0b..3e1b3de6 100644
--- a/src/sort/radix.wgsl
+++ b/src/sort/radix.wgsl
@@ -77,17 +77,18 @@ fn radix_sort_a(
         let transformed_position = (gaussian_uniforms.transform * position).xyz;
         let clip_space_pos = world_to_clip(transformed_position);
 
-        let distance = distance(transformed_position, view.world_position);
-        let distance_wide = 0x0FFFFFFF - u32(distance * 1.0e4);
-
-        // TODO: use 4d transformed position, from gaussian node
+        // Use full-precision squared distance (monotonic with true distance for positive values)
+        // to avoid quantization artifacts. We invert the float bit pattern so that an ascending
+        // integer radix sort produces farthest-first ordering. For positive finite f32 values the
+        // bit pattern ordering matches numeric ordering, so inverting achieves the desired sort.
+        // (We deliberately avoid sqrt to save cycles and keep higher relative precision.)
+ let diff = transformed_position - view.world_position; + let dist2 = dot(diff, diff); // squared distance + let dist_bits = bitcast(dist2); + let key_distance = 0xFFFFFFFFu - dist_bits; if (in_frustum(clip_space_pos.xyz)) { - // key = bitcast(1.0 - clip_space_pos.z); - // key = u32(clip_space_pos.z * 0xFFFF.0) << 16u; - // key |= u32((clip_space_pos.x * 0.5 + 0.5) * 0xFF.0) << 8u; - // key |= u32((clip_space_pos.y * 0.5 + 0.5) * 0xFF.0); - key = distance_wide; + key = key_distance; } output_entries[entry_index].key = key; @@ -107,22 +108,36 @@ fn radix_sort_a( fn radix_sort_b( @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, ) { + // Exclusive scan of per-digit counts for each digit place var sum = 0u; for(var digit = 0u; digit < #{RADIX_BASE}u; digit += 1u) { - let tmp = sorting.digit_histogram[gl_GlobalInvocationID.y][digit]; - sorting.digit_histogram[gl_GlobalInvocationID.y][digit] = sum; + let tmp = atomicLoad(&sorting.digit_histogram[gl_GlobalInvocationID.y][digit]); + atomicStore(&sorting.digit_histogram[gl_GlobalInvocationID.y][digit], sum); sum += tmp; } } struct SortingSharedC { + // Legacy fields (not relied on for algorithmic correctness) entries: array, #{WORKGROUP_ENTRIES_C}>, gather_sources: array, #{WORKGROUP_ENTRIES_C}>, - scan: array, #{WORKGROUP_INVOCATIONS_C}>, + // Pad scan array to avoid bank-conflict offset running out-of-bounds + scan: array, #{WORKGROUP_INVOCATIONS_C} + (#{WORKGROUP_INVOCATIONS_C} >> LOG_NUM_BANKS)>, total: u32, } var sorting_shared_c: SortingSharedC; +// Additional shared arrays for stable multi-split within a tile +var tile_entries: array; +var counts_ws: array; +var digit_totals_ws: array; +var digit_offsets_ws: array; +var digit_global_base_ws: array; +// New: per-iteration per-digit totals and prefixes to ensure stability +var digit_iter_totals_ws: array; +var iter_prefix_ws: array; +const INVALID_DIGIT: u32 = #{RADIX_BASE}u; + // Reset pass to clear per-frame counters and histograms @compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES}) fn radix_reset( @@ -148,6 +163,7 @@ fn conflict_free_offset(n: u32) -> u32 { return n >> LOG_NUM_BANKS; } +// Note: kept here for completeness; the stable multi-split below does not rely on this scan. 
fn exclusive_scan(local_invocation_index: u32, value: u32) -> u32 { sorting_shared_c.scan[local_invocation_index + conflict_free_offset(local_invocation_index)] = value; @@ -194,100 +210,162 @@ fn exclusive_scan(local_invocation_index: u32, value: u32) -> u32 { fn radix_sort_c( @builtin(local_invocation_id) gl_LocalInvocationID: vec3, @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, + @builtin(workgroup_id) gl_WorkgroupID: vec3, ) { - // Draw an assignment number - if(gl_LocalInvocationID.x == 0u) { - sorting_shared_c.entries[0] = atomicAdd(&sorting.assignment_counter, 1u); - } + let tid = gl_LocalInvocationID.x; + let tile_index = gl_WorkgroupID.y; - // Reset histogram - sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = 0u; - workgroupBarrier(); - - let assignment = sorting_shared_c.entries[0]; - let global_entry_offset = assignment * #{WORKGROUP_ENTRIES_C}u; - // TODO: Specialize end shader - if(gl_LocalInvocationID.x == 0u && assignment * #{WORKGROUP_ENTRIES_C}u + #{WORKGROUP_ENTRIES_C}u >= gaussian_uniforms.count) { - // Last workgroup resets the assignment number for the next pass - atomicStore(&sorting.assignment_counter, 0u); - } + // Compute global offset for this tile + let global_entry_offset = tile_index * #{WORKGROUP_ENTRIES_C}u; + if (global_entry_offset >= gaussian_uniforms.count) { return; } - // Load keys from global memory into registers and rank them + // Load input and compute deterministic local ranks var keys: array; - var ranks: array; - for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { - keys[entry_index] = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].key; - let digit = (keys[entry_index] >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); - // TODO: Implement warp-level multi-split (WLMS) once WebGPU supports subgroup operations - ranks[entry_index] = atomicAdd(&sorting_shared_c.scan[digit + conflict_free_offset(digit)], 1u); + var values: array; + var digit_of: array; + var local_rank_in_thread: array; + + // Zero per-thread per-digit counts + for (var d = 0u; d < #{RADIX_BASE}u; d += 1u) { + counts_ws[d * #{WORKGROUP_INVOCATIONS_C}u + tid] = 0u; } workgroupBarrier(); - // Cumulate histogram - let local_digit_count = sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)]; - let local_digit_offset = exclusive_scan(gl_LocalInvocationID.x, local_digit_count); - sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = local_digit_offset; - - // Chained decoupling lookback - atomicStore(&status_counters[assignment][gl_LocalInvocationID.x], 0x40000000u | local_digit_count); - var global_digit_count = 0u; - var previous_tile = assignment; - while true { - if(previous_tile == 0u) { - global_digit_count += sorting.digit_histogram[sorting_pass_index][gl_LocalInvocationID.x]; - break; - } - previous_tile -= 1u; - var status_counter = 0u; - while((status_counter & 0xC0000000u) == 0u) { - status_counter = atomicLoad(&status_counters[previous_tile][gl_LocalInvocationID.x]); - } - global_digit_count += status_counter & 0x3FFFFFFFu; - if((status_counter & 0x80000000u) != 0u) { - break; + // Load & compute local ranks in input order; also record per-iteration digits for stability + for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { + let idx = global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * i + tid; + if (idx < 
gaussian_uniforms.count) { + let k = input_entries[idx].key; + let v = input_entries[idx].value; + keys[i] = k; + values[i] = v; + + let d = (k >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + digit_of[i] = d; + + let off = d * #{WORKGROUP_INVOCATIONS_C}u + tid; + let lr = counts_ws[off]; + local_rank_in_thread[i] = lr; + counts_ws[off] = lr + 1u; + + // Record digit in tile input order for stable placement + tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid] = d; + } else { + keys[i] = 0xFFFFFFFFu; + values[i] = 0xFFFFFFFFu; + digit_of[i] = 0u; + local_rank_in_thread[i] = 0u; + tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid] = INVALID_DIGIT; } } - atomicStore(&status_counters[assignment][gl_LocalInvocationID.x], 0x80000000u | (global_digit_count + local_digit_count)); - // On the final digit pass, set indirect draw counts once (robust gate) - if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && gl_LocalInvocationID.x == 0u) { - draw_indirect.vertex_count = 4u; - // Use the total gaussian count to avoid edge-dependent undercounting - atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count); + workgroupBarrier(); + + // Per-digit totals for this tile (across all iterations) + if (tid < #{RADIX_BASE}u) { + var total = 0u; + for (var t = 0u; t < #{WORKGROUP_INVOCATIONS_C}u; t += 1u) { + total += counts_ws[tid * #{WORKGROUP_INVOCATIONS_C}u + t]; + } + digit_totals_ws[tid] = total; } + workgroupBarrier(); - // Scatter keys inside shared memory - for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { - let key = keys[entry_index]; - let digit = (key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); - ranks[entry_index] += sorting_shared_c.scan[digit + conflict_free_offset(digit)]; - sorting_shared_c.entries[ranks[entry_index]] = key; + // Compute per-iteration per-digit totals: digit_iter_totals_ws[d][i] + if (tid < #{RADIX_BASE}u) { + for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { + var tcount = 0u; + let base_index = i * #{WORKGROUP_INVOCATIONS_C}u; + for (var t = 0u; t < #{WORKGROUP_INVOCATIONS_C}u; t += 1u) { + let dd = tile_entries[base_index + t]; + if (dd == tid) { tcount += 1u; } + } + digit_iter_totals_ws[tid * #{ENTRIES_PER_INVOCATION_C}u + i] = tcount; + } } workgroupBarrier(); - // Add global offset - sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = global_digit_count - local_digit_offset; + // Compute per-iteration exclusive prefix across iterations for each digit: iter_prefix_ws[d][i] + if (tid < #{RADIX_BASE}u) { + var acc = 0u; + for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { + let idxp = tid * #{ENTRIES_PER_INVOCATION_C}u + i; + let t = digit_iter_totals_ws[idxp]; + iter_prefix_ws[idxp] = acc; + acc += t; + } + } workgroupBarrier(); - // Store keys from shared memory into global memory - for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { - let key = sorting_shared_c.entries[#{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x]; - let digit = (key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); - keys[entry_index] = digit; - output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].key = key; + // Publish per-digit global base via lookback; also set draw indirect if final pass + if (tid < #{RADIX_BASE}u) { + let local_total = 
digit_totals_ws[tid]; + atomicStore(&status_counters[tile_index][tid], 0x40000000u | local_total); + storageBarrier(); + + var global_digit_count = 0u; + var prev = tile_index; + loop { + if (prev == 0u) { + // Add global base (exclusive) for this digit across the whole array + global_digit_count += atomicLoad(&sorting.digit_histogram[sorting_pass_index][tid]); + break; + } + prev -= 1u; + + // Spin until prior tile publishes its local total for this digit + var word = 0u; + loop { + word = atomicLoad(&status_counters[prev][tid]); + if ((word & 0xC0000000u) != 0u) { break; } + } + global_digit_count += word & 0x3FFFFFFFu; + if ((word & 0x80000000u) != 0u) { break; } + } + + digit_global_base_ws[tid] = global_digit_count; + storageBarrier(); + atomicStore(&status_counters[tile_index][tid], 0x80000000u | (global_digit_count + local_total)); + + if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && tid == 0u) { + draw_indirect.vertex_count = 4u; + atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count); + } } workgroupBarrier(); - // Load values from global memory and scatter them inside shared memory - for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { - let value = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].value; - sorting_shared_c.entries[ranks[entry_index]] = value; + // Write keys to global memory at final stable positions (stable within tile) + for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { + let k = keys[i]; + let d = tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid]; + if (d == INVALID_DIGIT) { continue; } + // Count threads before me in this iteration with the same digit + var thread_prefix = 0u; + let base_index = i * #{WORKGROUP_INVOCATIONS_C}u; + for (var t = 0u; t < tid; t += 1u) { + if (tile_entries[base_index + t] == d) { thread_prefix += 1u; } + } + let pos_in_tile_for_digit = iter_prefix_ws[d * #{ENTRIES_PER_INVOCATION_C}u + i] + thread_prefix; + let dst = digit_global_base_ws[d] + pos_in_tile_for_digit; + if (dst < gaussian_uniforms.count) { + output_entries[dst].key = k; + } } workgroupBarrier(); - // Store values from shared memory into global memory - for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { - let value = sorting_shared_c.entries[#{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x]; - let digit = keys[entry_index]; - output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x].value = value; + // Write values to global memory to match keys + for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { + let v = values[i]; + let d = tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid]; + if (d == INVALID_DIGIT) { continue; } + var thread_prefix = 0u; + let base_index = i * #{WORKGROUP_INVOCATIONS_C}u; + for (var t = 0u; t < tid; t += 1u) { + if (tile_entries[base_index + t] == d) { thread_prefix += 1u; } + } + let pos_in_tile_for_digit = iter_prefix_ws[d * #{ENTRIES_PER_INVOCATION_C}u + i] + thread_prefix; + let dst = digit_global_base_ws[d] + pos_in_tile_for_digit; + if (dst < gaussian_uniforms.count) { + output_entries[dst].value = v; + } } } From ca84642b88524b1396369be8142cb1cf8b7d4ec3 Mon Sep 17 00:00:00 2001 From: teodosin Date: Wed, 10 Sep 2025 11:05:41 +0300 Subject: [PATCH 3/3] mostly stable --- src/sort/radix.rs | 4 +- src/sort/radix.wgsl | 347 
++++++++++++++-----------------------------
 2 files changed, 104 insertions(+), 247 deletions(-)

diff --git a/src/sort/radix.rs b/src/sort/radix.rs
index b49aad78..dc5e4115 100644
--- a/src/sort/radix.rs
+++ b/src/sort/radix.rs
@@ -183,6 +183,7 @@ pub struct GpuRadixBuffers {
     pub sorting_pass_buffers: [Buffer; 4],
     pub entry_buffer_b: Buffer,
 }
+
 impl GpuRadixBuffers {
     pub fn new(
         count: usize,
@@ -745,7 +746,8 @@ where
             pass.set_bind_group(2, &cloud_bind_group.bind_group, &[]);
 
             // For pass C, choose bind group based on digit place and parity
-            let parity = ((pass_idx + 1) % 2) as usize;
+            // THIS IS THE FIX:
+            let parity = (pass_idx % 2) as usize;
             let bg_index = (pass_idx as usize) * 2 + parity;
             pass.set_bind_group(3, &radix_bind_group.radix_sort_bind_groups[bg_index], &[]);
 
diff --git a/src/sort/radix.wgsl b/src/sort/radix.wgsl
index 3e1b3de6..3b75a1a9 100644
--- a/src/sort/radix.wgsl
+++ b/src/sort/radix.wgsl
@@ -39,16 +39,33 @@ struct SortingGlobal {
 @group(3) @binding(0) var<uniform> sorting_pass_index: u32;
 @group(3) @binding(1) var<storage, read_write> sorting: SortingGlobal;
+// NOTE: status_counters at binding(2) is NO LONGER USED by the corrected shader.
+// It can be removed from the Rust host code.
 @group(3) @binding(2) var<storage, read_write> status_counters: array<array<atomic<u32>, #{RADIX_BASE}>>;
 @group(3) @binding(3) var<storage, read_write> draw_indirect: DrawIndirect;
 @group(3) @binding(4) var<storage, read_write> input_entries: array<Entry>;
 @group(3) @binding(5) var<storage, read_write> output_entries: array<Entry>;
 
-struct SortingSharedA {
-    digit_histogram: array<array<atomic<u32>, #{RADIX_BASE}>, #{RADIX_DIGIT_PLACES}>,
+//
+// The following three functions (`radix_reset`, `radix_sort_a`, `radix_sort_b`)
+// form a standard three-phase GPU sort setup and were already correct.
+// They are included here without changes.
+//
+
+@compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES})
+fn radix_reset(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+){
+    let b = local_id.x;
+    let p = local_id.y;
+    atomicStore(&sorting.digit_histogram[p][b], 0u);
+    if (global_id.x == 0u && global_id.y == 0u) {
+        atomicStore(&sorting.assignment_counter, 0u);
+        draw_indirect.instance_count = 0u;
+    }
 }
-var<workgroup> sorting_shared_a: SortingSharedA;
 
 @compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES})
 fn radix_sort_a(
@@ -56,11 +73,9 @@ fn radix_sort_a(
     @builtin(global_invocation_id) gl_GlobalInvocationID: vec3<u32>,
 ) {
     if (gl_LocalInvocationID.x == 0u && gl_LocalInvocationID.y == 0u && gl_GlobalInvocationID.x == 0u) {
-        // Initialize draw counts early so the draw call doesn't get zeroed if later passes stall
         draw_indirect.vertex_count = 4u;
         atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count);
     }
-    sorting_shared_a.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x] = 0u;
     workgroupBarrier();
 
     let thread_index = gl_GlobalInvocationID.x * #{RADIX_DIGIT_PLACES}u + gl_GlobalInvocationID.y;
@@ -68,47 +83,31 @@ fn radix_sort_a(
     let end_entry_index = start_entry_index + #{ENTRIES_PER_INVOCATION_A}u;
 
     for (var entry_index = start_entry_index; entry_index < end_entry_index; entry_index += 1u) {
-        if (entry_index >= gaussian_uniforms.count) {
-            continue;
-        }
-
+        if (entry_index >= gaussian_uniforms.count) { continue; }
         var key: u32 = 0xFFFFFFFFu;
         let position = vec4<f32>(get_position(entry_index), 1.0);
         let transformed_position = (gaussian_uniforms.transform * position).xyz;
         let clip_space_pos = world_to_clip(transformed_position);
-
-        // Use full-precision squared distance (monotonic with true distance for positive values)
-        // to avoid quantization artifacts.
We invert the float bit pattern so that an ascending - // integer radix sort produces farthest-first ordering. For positive finite f32 values the - // bit pattern ordering matches numeric ordering, so inverting achieves the desired sort. - // (We deliberately avoid sqrt to save cycles and keep higher relative precision.) let diff = transformed_position - view.world_position; - let dist2 = dot(diff, diff); // squared distance + let dist2 = dot(diff, diff); let dist_bits = bitcast(dist2); let key_distance = 0xFFFFFFFFu - dist_bits; - if (in_frustum(clip_space_pos.xyz)) { key = key_distance; } - - output_entries[entry_index].key = key; - output_entries[entry_index].value = entry_index; - + input_entries[entry_index].key = key; + input_entries[entry_index].value = entry_index; for(var shift = 0u; shift < #{RADIX_DIGIT_PLACES}u; shift += 1u) { let digit = (key >> (shift * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); - atomicAdd(&sorting_shared_a.digit_histogram[shift][digit], 1u); + atomicAdd(&sorting.digit_histogram[shift][digit], 1u); } } - workgroupBarrier(); - - atomicAdd(&sorting.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x], sorting_shared_a.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x]); } @compute @workgroup_size(1) fn radix_sort_b( @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, ) { - // Exclusive scan of per-digit counts for each digit place var sum = 0u; for(var digit = 0u; digit < #{RADIX_BASE}u; digit += 1u) { let tmp = atomicLoad(&sorting.digit_histogram[gl_GlobalInvocationID.y][digit]); @@ -117,255 +116,111 @@ fn radix_sort_b( } } -struct SortingSharedC { - // Legacy fields (not relied on for algorithmic correctness) - entries: array, #{WORKGROUP_ENTRIES_C}>, - gather_sources: array, #{WORKGROUP_ENTRIES_C}>, - // Pad scan array to avoid bank-conflict offset running out-of-bounds - scan: array, #{WORKGROUP_INVOCATIONS_C} + (#{WORKGROUP_INVOCATIONS_C} >> LOG_NUM_BANKS)>, - total: u32, -} -var sorting_shared_c: SortingSharedC; - -// Additional shared arrays for stable multi-split within a tile -var tile_entries: array; -var counts_ws: array; -var digit_totals_ws: array; -var digit_offsets_ws: array; -var digit_global_base_ws: array; -// New: per-iteration per-digit totals and prefixes to ensure stability -var digit_iter_totals_ws: array; -var iter_prefix_ws: array; -const INVALID_DIGIT: u32 = #{RADIX_BASE}u; - -// Reset pass to clear per-frame counters and histograms -@compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES}) -fn radix_reset( - @builtin(local_invocation_id) local_id: vec3, - @builtin(global_invocation_id) global_id: vec3, -){ - let b = local_id.x; - let p = local_id.y; - - atomicStore(&sorting.digit_histogram[p][b], 0u); - atomicStore(&status_counters[p][b], 0u); - if (global_id.x == 0u && global_id.y == 0u) { - atomicStore(&sorting.assignment_counter, 0u); - draw_indirect.instance_count = 0u; - } -} +// --- SHARED MEMORY for the final, stable `radix_sort_c` --- +var tile_input_entries: array; +var sorted_tile_entries: array; +var local_digit_counts: array; +var local_digit_offsets: array; +var digit_global_base_ws: array; +var total_valid_in_tile_ws: u32; +const INVALID_KEY: u32 = 0xFFFFFFFFu; -const NUM_BANKS: u32 = 16u; -const LOG_NUM_BANKS: u32 = 4u; -fn conflict_free_offset(n: u32) -> u32 { - // Simple bank-conflict padding to reduce contention - return n >> LOG_NUM_BANKS; -} - -// Note: kept here for completeness; the stable multi-split below does not rely on this scan. 
-fn exclusive_scan(local_invocation_index: u32, value: u32) -> u32 { - sorting_shared_c.scan[local_invocation_index + conflict_free_offset(local_invocation_index)] = value; - - var offset = 1u; - for (var d = #{WORKGROUP_INVOCATIONS_C}u >> 1u; d > 0u; d >>= 1u) { - workgroupBarrier(); - if(local_invocation_index < d) { - var ai = offset * (2u * local_invocation_index + 1u) - 1u; - var bi = offset * (2u * local_invocation_index + 2u) - 1u; - ai += conflict_free_offset(ai); - bi += conflict_free_offset(bi); - sorting_shared_c.scan[bi] += sorting_shared_c.scan[ai]; - } - - offset <<= 1u; - } - - if (local_invocation_index == 0u) { - var i = #{WORKGROUP_INVOCATIONS_C}u - 1u; - i += conflict_free_offset(i); - sorting_shared_c.total = sorting_shared_c.scan[i]; - sorting_shared_c.scan[i] = 0u; - } - - for (var d = 1u; d < #{WORKGROUP_INVOCATIONS_C}u; d <<= 1u) { - workgroupBarrier(); - offset >>= 1u; - if(local_invocation_index < d) { - var ai = offset * (2u * local_invocation_index + 1u) - 1u; - var bi = offset * (2u * local_invocation_index + 2u) - 1u; - ai += conflict_free_offset(ai); - bi += conflict_free_offset(bi); - let t = sorting_shared_c.scan[ai]; - sorting_shared_c.scan[ai] = sorting_shared_c.scan[bi]; - sorting_shared_c.scan[bi] += t; - } - } - - workgroupBarrier(); - return sorting_shared_c.scan[local_invocation_index + conflict_free_offset(local_invocation_index)]; -} +// +// Pass C (REWRITTEN): A fully stable implementation that discards the faulty spin-lock. +// @compute @workgroup_size(#{WORKGROUP_INVOCATIONS_C}) fn radix_sort_c( - @builtin(local_invocation_id) gl_LocalInvocationID: vec3, - @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, - @builtin(workgroup_id) gl_WorkgroupID: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) workgroup_id: vec3, ) { - let tid = gl_LocalInvocationID.x; - let tile_index = gl_WorkgroupID.y; - - // Compute global offset for this tile - let global_entry_offset = tile_index * #{WORKGROUP_ENTRIES_C}u; - if (global_entry_offset >= gaussian_uniforms.count) { return; } - - // Load input and compute deterministic local ranks - var keys: array; - var values: array; - var digit_of: array; - var local_rank_in_thread: array; - - // Zero per-thread per-digit counts - for (var d = 0u; d < #{RADIX_BASE}u; d += 1u) { - counts_ws[d * #{WORKGROUP_INVOCATIONS_C}u + tid] = 0u; - } - workgroupBarrier(); - - // Load & compute local ranks in input order; also record per-iteration digits for stability - for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { - let idx = global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * i + tid; + let tid = local_id.x; + let tile_size = #{WORKGROUP_ENTRIES_C}u; + let threads = #{WORKGROUP_INVOCATIONS_C}u; + let global_entry_offset = workgroup_id.y * tile_size; + + // --- Step 1: Parallel load --- + for (var i = tid; i < tile_size; i += threads) { + let idx = global_entry_offset + i; if (idx < gaussian_uniforms.count) { - let k = input_entries[idx].key; - let v = input_entries[idx].value; - keys[i] = k; - values[i] = v; - - let d = (k >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); - digit_of[i] = d; - - let off = d * #{WORKGROUP_INVOCATIONS_C}u + tid; - let lr = counts_ws[off]; - local_rank_in_thread[i] = lr; - counts_ws[off] = lr + 1u; - - // Record digit in tile input order for stable placement - tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid] = d; + tile_input_entries[i] = input_entries[idx]; } else { - keys[i] = 0xFFFFFFFFu; - values[i] = 0xFFFFFFFFu; - 
digit_of[i] = 0u; - local_rank_in_thread[i] = 0u; - tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid] = INVALID_DIGIT; + tile_input_entries[i].key = INVALID_KEY; } } workgroupBarrier(); - // Per-digit totals for this tile (across all iterations) - if (tid < #{RADIX_BASE}u) { - var total = 0u; - for (var t = 0u; t < #{WORKGROUP_INVOCATIONS_C}u; t += 1u) { - total += counts_ws[tid * #{WORKGROUP_INVOCATIONS_C}u + t]; + // --- Step 2: Serial, stable sort within the tile by a single thread --- + // This is the key change that guarantees stability by eliminating all race conditions. + if (tid == 0u) { + for (var i = 0u; i < #{RADIX_BASE}u; i+=1u) { local_digit_counts[i] = 0u; } + + var valid_count = 0u; + for (var i = 0u; i < tile_size; i+=1u) { + if (tile_input_entries[i].key != INVALID_KEY) { + let digit = (tile_input_entries[i].key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + local_digit_counts[digit] += 1u; + valid_count += 1u; + } } - digit_totals_ws[tid] = total; - } - workgroupBarrier(); + total_valid_in_tile_ws = valid_count; - // Compute per-iteration per-digit totals: digit_iter_totals_ws[d][i] - if (tid < #{RADIX_BASE}u) { - for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { - var tcount = 0u; - let base_index = i * #{WORKGROUP_INVOCATIONS_C}u; - for (var t = 0u; t < #{WORKGROUP_INVOCATIONS_C}u; t += 1u) { - let dd = tile_entries[base_index + t]; - if (dd == tid) { tcount += 1u; } + var sum = 0u; + for (var i = 0u; i < #{RADIX_BASE}u; i+=1u) { + local_digit_offsets[i] = sum; + sum += local_digit_counts[i]; + } + + for (var i = 0u; i < tile_size; i+=1u) { + if (tile_input_entries[i].key != INVALID_KEY) { + let entry = tile_input_entries[i]; + let digit = (entry.key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + let dest_idx = local_digit_offsets[digit]; + local_digit_offsets[digit] = dest_idx + 1u; + sorted_tile_entries[dest_idx] = entry; } - digit_iter_totals_ws[tid * #{ENTRIES_PER_INVOCATION_C}u + i] = tcount; } } workgroupBarrier(); - // Compute per-iteration exclusive prefix across iterations for each digit: iter_prefix_ws[d][i] + // --- Step 3: Atomically determine the global base address for this tile --- + // This replaces the fragile spin-lock with a single, robust atomic operation per digit. 
if (tid < #{RADIX_BASE}u) { - var acc = 0u; - for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { - let idxp = tid * #{ENTRIES_PER_INVOCATION_C}u + i; - let t = digit_iter_totals_ws[idxp]; - iter_prefix_ws[idxp] = acc; - acc += t; + let count = local_digit_counts[tid]; + if (count > 0u) { + digit_global_base_ws[tid] = atomicAdd(&sorting.digit_histogram[sorting_pass_index][tid], count); } } workgroupBarrier(); - // Publish per-digit global base via lookback; also set draw indirect if final pass - if (tid < #{RADIX_BASE}u) { - let local_total = digit_totals_ws[tid]; - atomicStore(&status_counters[tile_index][tid], 0x40000000u | local_total); - storageBarrier(); - - var global_digit_count = 0u; - var prev = tile_index; - loop { - if (prev == 0u) { - // Add global base (exclusive) for this digit across the whole array - global_digit_count += atomicLoad(&sorting.digit_histogram[sorting_pass_index][tid]); - break; - } - prev -= 1u; - - // Spin until prior tile publishes its local total for this digit - var word = 0u; - loop { - word = atomicLoad(&status_counters[prev][tid]); - if ((word & 0xC0000000u) != 0u) { break; } - } - global_digit_count += word & 0x3FFFFFFFu; - if ((word & 0x80000000u) != 0u) { break; } - } - - digit_global_base_ws[tid] = global_digit_count; - storageBarrier(); - atomicStore(&status_counters[tile_index][tid], 0x80000000u | (global_digit_count + local_total)); - - if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && tid == 0u) { - draw_indirect.vertex_count = 4u; - atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count); + // --- Step 4: Parallel write from the locally-sorted tile to global memory --- + if (tid == 0u) { + var sum = 0u; + for (var i = 0u; i < #{RADIX_BASE}u; i += 1u) { + local_digit_offsets[i] = sum; + sum += local_digit_counts[i]; } } workgroupBarrier(); - // Write keys to global memory at final stable positions (stable within tile) - for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { - let k = keys[i]; - let d = tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid]; - if (d == INVALID_DIGIT) { continue; } - // Count threads before me in this iteration with the same digit - var thread_prefix = 0u; - let base_index = i * #{WORKGROUP_INVOCATIONS_C}u; - for (var t = 0u; t < tid; t += 1u) { - if (tile_entries[base_index + t] == d) { thread_prefix += 1u; } - } - let pos_in_tile_for_digit = iter_prefix_ws[d * #{ENTRIES_PER_INVOCATION_C}u + i] + thread_prefix; - let dst = digit_global_base_ws[d] + pos_in_tile_for_digit; - if (dst < gaussian_uniforms.count) { - output_entries[dst].key = k; + for (var i = tid; i < tile_size; i += threads) { + if (i < total_valid_in_tile_ws) { + let entry = sorted_tile_entries[i]; + let digit = (entry.key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + + let bin_start_offset = local_digit_offsets[digit]; + let rank_in_bin = i - bin_start_offset; + let global_base = digit_global_base_ws[digit]; + let dst = global_base + rank_in_bin; + + if (dst < gaussian_uniforms.count) { + output_entries[dst] = entry; + } } } - workgroupBarrier(); - // Write values to global memory to match keys - for (var i = 0u; i < #{ENTRIES_PER_INVOCATION_C}u; i += 1u) { - let v = values[i]; - let d = tile_entries[i * #{WORKGROUP_INVOCATIONS_C}u + tid]; - if (d == INVALID_DIGIT) { continue; } - var thread_prefix = 0u; - let base_index = i * #{WORKGROUP_INVOCATIONS_C}u; - for (var t = 0u; t < tid; t += 1u) { - if (tile_entries[base_index + t] == d) { thread_prefix += 1u; } - } - let 
pos_in_tile_for_digit = iter_prefix_ws[d * #{ENTRIES_PER_INVOCATION_C}u + i] + thread_prefix;
-        let dst = digit_global_base_ws[d] + pos_in_tile_for_digit;
-        if (dst < gaussian_uniforms.count) {
-            output_entries[dst].value = v;
-        }
+    if (sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && tid == 0u) {
+        atomicStore(&draw_indirect.instance_count, gaussian_uniforms.count);
     }
-}
+}
\ No newline at end of file