74 changes: 74 additions & 0 deletions src/render/mod.rs
@@ -147,6 +147,7 @@ where
(
queue_gaussian_bind_group::<R>.in_set(RenderSet::PrepareBindGroups),
queue_gaussian_view_bind_groups::<R>.in_set(RenderSet::PrepareBindGroups),
queue_gaussian_compute_view_bind_groups::<R>.in_set(RenderSet::PrepareBindGroups),
queue_gaussians::<R>.in_set(RenderSet::Queue),
),
);
@@ -1050,6 +1051,11 @@ pub struct GaussianViewBindGroup {
pub value: BindGroup,
}

#[derive(Component)]
pub struct GaussianComputeViewBindGroup {
pub value: BindGroup,
}

// TODO: move to gaussian camera module
// TODO: remove cloud pipeline dependency by separating view layout
#[allow(clippy::too_many_arguments)]
@@ -1124,6 +1130,74 @@ pub fn queue_gaussian_view_bind_groups<R: PlanarSync>(
}
}

// Prepare the compute view bind group using the compute_view_layout (for compute pipelines)
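// This mirrors queue_gaussian_view_bind_groups above, but builds the bind group against
// CloudPipeline::compute_view_layout, which is what compute passes such as the radix sort node bind.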
pub fn queue_gaussian_compute_view_bind_groups<R: PlanarSync>(
mut commands: Commands,
render_device: Res<RenderDevice>,
gaussian_cloud_pipeline: Res<CloudPipeline<R>>,
view_uniforms: Res<ViewUniforms>,
previous_view_uniforms: Res<PreviousViewUniforms>,
views: Query<
(
Entity,
&ExtractedView,
Option<&PreviousViewData>,
),
With<GaussianCamera>,
>,
visibility_ranges: Res<RenderVisibilityRanges>,
globals_buffer: Res<GlobalsBuffer>,
)
where
R: PlanarSync,
R::GpuPlanarType: GpuPlanarStorage,
{
if let (
Some(view_binding),
Some(previous_view_binding),
Some(globals),
Some(visibility_ranges_buffer),
) = (
view_uniforms.uniforms.binding(),
previous_view_uniforms.uniforms.binding(),
globals_buffer.buffer.binding(),
visibility_ranges.buffer().buffer(),
) {
for (entity, _extracted_view, _maybe_previous_view) in &views {
let layout = &gaussian_cloud_pipeline.compute_view_layout;

let entries = vec![
BindGroupEntry {
binding: 0,
resource: view_binding.clone(),
},
BindGroupEntry {
binding: 1,
resource: globals.clone(),
},
BindGroupEntry {
binding: 2,
resource: previous_view_binding.clone(),
},
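// Binding indices must match compute_view_layout; visibility ranges are bound at 14.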
BindGroupEntry {
binding: 14,
resource: visibility_ranges_buffer.as_entire_binding(),
},
];

let view_bind_group = render_device.create_bind_group(
"gaussian_compute_view_bind_group",
layout,
&entries,
);

commands
.entity(entity)
.insert(GaussianComputeViewBindGroup { value: view_bind_group });
}
}
}

pub struct SetViewBindGroup<const I: usize>;
impl<P: PhaseItem, const I: usize> RenderCommand<P> for SetViewBindGroup<I> {
type Param = ();
170 changes: 99 additions & 71 deletions src/sort/radix.rs
@@ -63,7 +63,6 @@ use crate::{
CloudPipeline,
CloudPipelineKey,
GaussianUniformBindGroups,
GaussianViewBindGroup,
ShaderDefines,
shader_defs,
},
@@ -76,7 +75,6 @@ use crate::{
},
};


assert_cfg!(
not(all(
feature = "sort_radix",
@@ -185,6 +183,7 @@ pub struct GpuRadixBuffers {
pub sorting_pass_buffers: [Buffer; 4],
pub entry_buffer_b: Buffer,
}

impl GpuRadixBuffers {
pub fn new(
count: usize,
@@ -253,7 +252,7 @@ fn update_sort_buffers<R: PlanarSync>(
#[derive(Resource)]
pub struct RadixSortPipeline<R: PlanarSync> {
pub radix_sort_layout: BindGroupLayout,
pub radix_sort_pipelines: [CachedComputePipelineId; 3],
pub radix_sort_pipelines: [CachedComputePipelineId; 4],
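// index 0 = radix_reset, 1..=3 = radix_sort_a / radix_sort_b / radix_sort_c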
phantom: std::marker::PhantomData<R>,
}

@@ -343,6 +342,16 @@ impl<R: PlanarSync> FromWorld for RadixSortPipeline<R> {
let shader_defs = shader_defs(CloudPipelineKey::default());

let pipeline_cache = render_world.resource::<PipelineCache>();
let radix_reset = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_reset".into()),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
entry_point: "radix_reset".into(),
zero_initialize_workgroup_memory: true,
});

let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_a".into()),
layout: sorting_layout.clone(),
@@ -375,11 +384,7 @@ impl<R: PlanarSync> FromWorld for RadixSortPipeline<R> {

RadixSortPipeline {
radix_sort_layout,
radix_sort_pipelines: [
radix_sort_a,
radix_sort_b,
radix_sort_c,
],
radix_sort_pipelines: [radix_reset, radix_sort_a, radix_sort_b, radix_sort_c],
phantom: std::marker::PhantomData,
}
}
@@ -389,7 +394,9 @@ impl<R: PlanarSync> FromWorld for RadixSortPipeline<R> {

#[derive(Component)]
pub struct RadixBindGroup {
pub radix_sort_bind_groups: [BindGroup; 4],
// For each digit pass idx in 0..RADIX_DIGIT_PLACES we create two bind groups (parity 0/1),
// indexed as pass_idx * 2 + parity:
//   parity 0: input = sorted_entries, output = entry_buffer_b
//   parity 1: input = entry_buffer_b,  output = sorted_entries
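// e.g. pass_idx = 2, parity = 1 -> bind group index 2 * 2 + 1 == 5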
pub radix_sort_bind_groups: [BindGroup; 8],
}

#[allow(clippy::too_many_arguments)]
@@ -477,53 +484,57 @@ where
}),
};

let radix_sort_bind_groups: [BindGroup; 4] = (0..4)
.map(|idx| {
render_device.create_bind_group(
format!("radix_sort_bind_group {idx}").as_str(),
&radix_pipeline.radix_sort_layout,
&[
BindGroupEntry {
binding: 0,
resource: BindingResource::Buffer(BufferBinding {
buffer: &sorting_assets.sorting_pass_buffers[idx],
offset: 0,
size: BufferSize::new(std::mem::size_of::<u32>() as u64),
}),
},
sorting_global_entry.clone(),
sorting_status_counters_entry.clone(),
draw_indirect_entry.clone(),
BindGroupEntry {
binding: 4,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&sorted_entries.sorted_entry_buffer
} else {
&sorting_assets.entry_buffer_b
},
offset: 0,
size: BufferSize::new((cloud.len() * std::mem::size_of::<SortEntry>()) as u64),
}),
},
BindGroupEntry {
binding: 5,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&sorting_assets.entry_buffer_b
} else {
&sorted_entries.sorted_entry_buffer
},
offset: 0,
size: BufferSize::new((cloud.len() * std::mem::size_of::<SortEntry>()) as u64),
}),
},
],
)
})
.collect::<Vec<BindGroup>>()
.try_into()
.unwrap();
let radix_sort_bind_groups: [BindGroup; 8] = {
let mut groups: Vec<BindGroup> = Vec::with_capacity(8);
for pass_idx in 0..4 {
for parity in 0..=1 {
let (input_buf, output_buf) = if parity == 0 {
(&sorted_entries.sorted_entry_buffer, &sorting_assets.entry_buffer_b)
} else {
(&sorting_assets.entry_buffer_b, &sorted_entries.sorted_entry_buffer)
};

let group = render_device.create_bind_group(
format!("radix_sort_bind_group pass={} parity={}", pass_idx, parity).as_str(),
&radix_pipeline.radix_sort_layout,
&[
// sorting_pass_index (u32) == pass_idx regardless of parity
BindGroupEntry {
binding: 0,
resource: BindingResource::Buffer(BufferBinding {
buffer: &sorting_assets.sorting_pass_buffers[pass_idx],
offset: 0,
size: BufferSize::new(std::mem::size_of::<u32>() as u64),
}),
},
sorting_global_entry.clone(),
sorting_status_counters_entry.clone(),
draw_indirect_entry.clone(),
// input_entries
BindGroupEntry {
binding: 4,
resource: BindingResource::Buffer(BufferBinding {
buffer: input_buf,
offset: 0,
size: BufferSize::new((cloud.len() * std::mem::size_of::<SortEntry>()) as u64),
}),
},
// output_entries
BindGroupEntry {
binding: 5,
resource: BindingResource::Buffer(BufferBinding {
buffer: output_buf,
offset: 0,
size: BufferSize::new((cloud.len() * std::mem::size_of::<SortEntry>()) as u64),
}),
},
],
);
groups.push(group);
}
}
groups.try_into().unwrap()
};

commands.entity(entity).insert(RadixBindGroup {
radix_sort_bind_groups,
Expand All @@ -541,7 +552,7 @@ pub struct RadixSortNode<R: PlanarSync> {
initialized: bool,
view_bind_group: QueryState<(
&'static GaussianCamera,
&'static GaussianViewBindGroup,
&'static crate::render::GaussianComputeViewBindGroup,
&'static ViewUniformOffset,
&'static PreviousViewUniformOffset,
)>,
@@ -650,6 +661,22 @@ where
{
let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

// Reset per-frame counters/histograms
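// A single workgroup suffices; bind group 0 is reused here on the assumption that radix_reset
// only touches the shared sorting buffers, not the entry buffers.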
let radix_reset = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
pass.set_pipeline(radix_reset);
pass.set_bind_group(
0,
&view_bind_group.value,
&[
view_uniform_offset.offset,
previous_view_uniform_offset.offset,
],
);
pass.set_bind_group(1, gaussian_uniforms.base_bind_group.as_ref().unwrap(), &[0]);
pass.set_bind_group(2, &cloud_bind_group.bind_group, &[]);
pass.set_bind_group(3, &radix_bind_group.radix_sort_bind_groups[0], &[]);
pass.dispatch_workgroups(1, 1, 1);

pass.set_bind_group(
0,
&view_bind_group.value,
@@ -674,14 +701,14 @@ where
&[],
);

let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap();
let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
pass.set_pipeline(radix_sort_a);

let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
pass.dispatch_workgroups((cloud.len() as u32).div_ceil(workgroup_entries_a), 1, 1);


let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
pass.set_pipeline(radix_sort_b);

pass.dispatch_workgroups(1, radix_digit_places, 1);
@@ -699,29 +726,30 @@ where

let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap();
let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[3]).unwrap();
pass.set_pipeline(radix_sort_c);

// Set common bind groups for view/uniforms and cloud storage
pass.set_bind_group(
0,
&view_bind_group.value,
&[view_uniform_offset.offset],
&[
view_uniform_offset.offset,
previous_view_uniform_offset.offset,
],
);
pass.set_bind_group(
1,
gaussian_uniforms.base_bind_group.as_ref().unwrap(),
&[0],
);
pass.set_bind_group(
2,
&cloud_bind_group.bind_group,
&[]
);
pass.set_bind_group(
3,
&radix_bind_group.radix_sort_bind_groups[pass_idx as usize],
&[],
);
pass.set_bind_group(2, &cloud_bind_group.bind_group, &[]);

// For pass C, choose the bind group based on digit place and ping-pong parity
// THIS IS THE FIX: index = pass_idx * 2 + parity, so consecutive digit places alternate input/output buffers
let parity = (pass_idx % 2) as usize;
let bg_index = (pass_idx as usize) * 2 + parity;
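// e.g. pass_idx == 1 -> parity 1 -> bind group 3 (input: entry_buffer_b, output: sorted_entries)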
pass.set_bind_group(3, &radix_bind_group.radix_sort_bind_groups[bg_index], &[]);

let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
pass.dispatch_workgroups(1, (cloud.len() as u32).div_ceil(workgroup_entries_c), 1);