diff --git a/.cargo/config.toml b/.cargo/config.toml index 63310040..805f8358 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -2,6 +2,7 @@ [target.wasm32-unknown-unknown] runner = "wasm-server-runner" +rustflags = ["--cfg=web_sys_unstable_apis"] # fix spurious network error on windows diff --git a/.gitignore b/.gitignore index a64aac48..45d282a6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ Cargo.lock **/*.rs.bk *.pdb +*.py *.ply *.gcloud diff --git a/.vscode/launch.json b/.vscode/launch.json index 841cab71..e58836ff 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -29,15 +29,15 @@ { "type": "lldb", "request": "launch", - "name": "Debug executable 'bevy_gaussian_splatting'", + "name": "Debug executable 'viewer'", "cargo": { "args": [ "build", - "--bin=bevy_gaussian_splatting", + "--bin=viewer", "--package=bevy_gaussian_splatting" ], "filter": { - "name": "bevy_gaussian_splatting", + "name": "viewer", "kind": "bin" } }, @@ -50,16 +50,16 @@ { "type": "lldb", "request": "launch", - "name": "Debug unit tests in executable 'bevy_gaussian_splatting'", + "name": "Debug unit tests in executable 'viewer'", "cargo": { "args": [ "test", "--no-run", - "--bin=bevy_gaussian_splatting", + "--bin=viewer", "--package=bevy_gaussian_splatting" ], "filter": { - "name": "bevy_gaussian_splatting", + "name": "viewer", "kind": "bin" } }, diff --git a/Cargo.toml b/Cargo.toml index 3f968c83..b0da72a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_gaussian_splatting" description = "bevy gaussian splatting render pipeline plugin" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["mosure "] license = "MIT" @@ -10,26 +10,32 @@ categories = ["computer-vision", "graphics", "rendering", "rendering::data-forma homepage = "https://github.com/mosure/bevy_gaussian_splatting" repository = "https://github.com/mosure/bevy_gaussian_splatting" readme = "README.md" -include = ["tools"] exclude = [".devcontainer", ".github", 
"docs", "dist", "build", "assets", "credits"] -default-run = "bevy_gaussian_splatting" +default-run = "viewer" -# TODO: use minimal bevy features [dependencies] -bevy = "0.11.3" bevy-inspector-egui = "0.20.0" bevy_panorbit_camera = "0.8.0" bincode2 = "2.0.1" bytemuck = "1.14.0" flate2 = "1.0.28" ply-rs = "0.1.3" +rand = "0.8.5" serde = "1.0.189" +wgpu = "0.16.0" [target.'cfg(target_arch = "wasm32")'.dependencies] console_error_panic_hook = "0.1.7" wasm-bindgen = "0.2.87" + +# TODO: use minimal bevy features +[dependencies.bevy] +version = "0.11.3" +default-features = true + + [dependencies.web-sys] version = "0.3.4" features = [ @@ -63,8 +69,8 @@ codegen-units = 1 path = "src/lib.rs" [[bin]] -name = "bevy_gaussian_splatting" -path = "src/main.rs" +name = "viewer" +path = "viewer/viewer.rs" [[bin]] name = "ply_to_gcloud" diff --git a/README.md b/README.md index 2f423a4e..dbb6d291 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,15 @@ bevy gaussian splatting render pipeline plugin -`cargo run -- {path to ply or gcloud.gz file}` +`cargo run -- {path to ply or gcloud file}` ## capabilities - [X] ply to gcloud converter - [X] gcloud and ply asset loaders - [X] bevy gaussian cloud render pipeline +- [ ] wasm support w/ [live demo](https://mosure.github.io/bevy_gaussian_splatting) +- [ ] temporal depth sorting - [ ] f16 and f32 gcloud support - [ ] 4D gaussian clouds via morph targets - [ ] bevy_openxr support @@ -43,7 +45,7 @@ fn setup_gaussian_cloud( asset_server: Res, ) { commands.spawn(GaussianSplattingBundle { - cloud: asset_server.load("scenes/icecream.ply"), + cloud: asset_server.load("scenes/icecream.gcloud"), ..Default::default() }); @@ -51,6 +53,17 @@ fn setup_gaussian_cloud( } ``` +## tools + +- [ply to gcloud converter](tools/README.md#ply-to-gcloud-converter) +- [ ] gaussian cloud training tool + +## wasm support + +to build wasm run: +- `cargo build --target wasm32-unknown-unknown --release` +- `wasm-bindgen --out-dir ./out/ --target web ./target/` + ## 
compatible bevy versions @@ -64,11 +77,13 @@ fn setup_gaussian_cloud( - [4d gaussians](https://github.com/hustvl/4DGaussians) - [bevy](https://github.com/bevyengine/bevy) - [bevy-hanabi](https://github.com/djeedai/bevy_hanabi) +- [deformable-3d-gaussians](https://github.com/ingra14m/Deformable-3D-Gaussians) - [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization) - [dreamgaussian](https://github.com/dreamgaussian/dreamgaussian) - [dynamic-3d-gaussians](https://github.com/JonathonLuiten/Dynamic3DGaussians) - [ewa splatting](https://www.cs.umd.edu/~zwicker/publications/EWASplatting-TVCG02.pdf) - [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) +- [gaussian-splatting-viewer](https://github.com/limacv/GaussianSplattingViewer/tree/main) - [gaussian-splatting-web](https://github.com/cvlab-epfl/gaussian-splatting-web) - [making gaussian splats smaller](https://aras-p.info/blog/2023/09/13/Making-Gaussian-Splats-smaller/) - [onesweep](https://arxiv.org/ftp/arxiv/papers/2206/2206.01784.pdf) diff --git a/assets/scenes/icecream.gcloud b/assets/scenes/icecream.gcloud index 6828e088..884a0c86 100644 Binary files a/assets/scenes/icecream.gcloud and b/assets/scenes/icecream.gcloud differ diff --git a/src/gaussian.rs b/src/gaussian.rs index fcce3701..986d4be7 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -1,3 +1,4 @@ +use rand::seq::SliceRandom; use std::{ io::{ BufReader, @@ -41,7 +42,9 @@ const fn num_sh_coefficients(degree: usize) -> usize { } } const SH_DEGREE: usize = 3; -pub const MAX_SH_COEFF_COUNT: usize = num_sh_coefficients(SH_DEGREE) * 3; +pub const SH_CHANNELS: usize = 3; +pub const MAX_SH_COEFF_COUNT_PER_CHANNEL: usize = num_sh_coefficients(SH_DEGREE); +pub const MAX_SH_COEFF_COUNT: usize = MAX_SH_COEFF_COUNT_PER_CHANNEL * SH_CHANNELS; #[derive( Clone, Copy, @@ -95,8 +98,9 @@ where A: serde::de::SeqAccess<'de>, { let mut coefficients = [0.0; MAX_SH_COEFF_COUNT]; - for i in 0..MAX_SH_COEFF_COUNT { 
- coefficients[i] = seq + + for (i, coefficient) in coefficients.iter_mut().enumerate().take(MAX_SH_COEFF_COUNT) { + *coefficient = seq .next_element()? .ok_or_else(|| serde::de::Error::invalid_length(i, &self))?; } @@ -126,11 +130,9 @@ pub const MAX_SIZE_VARIANCE: f32 = 5.0; // TODO: support f16 gaussian clouds (shader and asset loader) pub struct Gaussian { pub rotation: [f32; 4], - pub position: Vec3, - pub scale: Vec3, - pub opacity: f32, + pub position: [f32; 4], + pub scale_opacity: [f32; 4], pub spherical_harmonic: SphericalHarmonicCoefficients, - padding: f32, } #[derive( @@ -153,9 +155,18 @@ impl GaussianCloud { 0.0, 0.0, ], - position: Vec3::new(0.0, 0.0, 0.0), - scale: Vec3::new(0.5, 0.5, 0.5), - opacity: 0.8, + position: [ + 0.0, + 0.0, + 0.0, + 1.0, + ], + scale_opacity: [ + 0.5, + 0.5, + 0.5, + 0.5, + ], spherical_harmonic: SphericalHarmonicCoefficients{ coefficients: [ 1.0, 0.0, 1.0, @@ -176,16 +187,18 @@ impl GaussianCloud { 0.6, 0.1, 0.2, ], }, - padding: 0.0, }; let mut cloud = GaussianCloud(Vec::new()); for &x in [-0.5, 0.5].iter() { for &y in [-0.5, 0.5].iter() { for &z in [-0.5, 0.5].iter() { - let mut g = origin.clone(); - g.position = Vec3::new(x, y, z); + let mut g = origin; + g.position = [x, y, z, 1.0]; cloud.0.push(g); + + let mut rng = rand::thread_rng(); + cloud.0.last_mut().unwrap().spherical_harmonic.coefficients.shuffle(&mut rng); } } } @@ -235,14 +248,16 @@ impl AssetLoader for GaussianCloudLoader { let cloud = GaussianCloud(ply_cloud); load_context.set_default_asset(LoadedAsset::new(cloud)); - return Ok(()); + + Ok(()) }, Some(ext) if ext == "gcloud" => { let decompressed = GzDecoder::new(bytes); let cloud: GaussianCloud = deserialize_from(decompressed).expect("failed to decode cloud"); load_context.set_default_asset(LoadedAsset::new(cloud)); - return Ok(()); + + Ok(()) }, _ => Ok(()), } diff --git a/src/lib.rs b/src/lib.rs index 45c19b23..6d9bfbc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,7 @@ pub struct 
GaussianSplattingPlugin; impl Plugin for GaussianSplattingPlugin { fn build(&self, app: &mut App) { // TODO: allow hot reloading of GaussianCloud handle through inspector UI + app.register_type::(); app.register_type::(); app.add_asset::(); app.register_asset_reflect::(); diff --git a/src/ply.rs b/src/ply.rs index 2a64d61a..aba6a883 100644 --- a/src/ply.rs +++ b/src/ply.rs @@ -1,9 +1,6 @@ use std::io::BufRead; -use bevy::{ - asset::Error, - math::Vec3, -}; +use bevy::asset::Error; use ply_rs::{ ply::{ Property, @@ -14,7 +11,9 @@ use ply_rs::{ use crate::gaussian::{ Gaussian, + MAX_SH_COEFF_COUNT_PER_CHANNEL, MAX_SIZE_VARIANCE, + SH_CHANNELS, }; @@ -25,19 +24,16 @@ impl PropertyAccess for Gaussian { fn set_property(&mut self, key: String, property: Property) { match (key.as_ref(), property) { - ("x", Property::Float(v)) => self.position.x = v, - ("y", Property::Float(v)) => self.position.y = v, - ("z", Property::Float(v)) => self.position.z = v, - // ("nx", Property::Float(v)) => self.normal.x = v, - // ("ny", Property::Float(v)) => self.normal.y = v, - // ("nz", Property::Float(v)) => self.normal.z = v, + ("x", Property::Float(v)) => self.position[0] = v, + ("y", Property::Float(v)) => self.position[1] = v, + ("z", Property::Float(v)) => self.position[2] = v, ("f_dc_0", Property::Float(v)) => self.spherical_harmonic.coefficients[0] = v, ("f_dc_1", Property::Float(v)) => self.spherical_harmonic.coefficients[1] = v, ("f_dc_2", Property::Float(v)) => self.spherical_harmonic.coefficients[2] = v, - ("opacity", Property::Float(v)) => self.opacity = 1.0 / (1.0 + (-v).exp()), - ("scale_0", Property::Float(v)) => self.scale.x = v, - ("scale_1", Property::Float(v)) => self.scale.y = v, - ("scale_2", Property::Float(v)) => self.scale.z = v, + ("scale_0", Property::Float(v)) => self.scale_opacity[0] = v, + ("scale_1", Property::Float(v)) => self.scale_opacity[1] = v, + ("scale_2", Property::Float(v)) => self.scale_opacity[2] = v, + ("opacity", Property::Float(v)) => 
self.scale_opacity[3] = 1.0 / (1.0 + (-v).exp()), ("rot_0", Property::Float(v)) => self.rotation[0] = v, ("rot_1", Property::Float(v)) => self.rotation[1] = v, ("rot_2", Property::Float(v)) => self.rotation[2] = v, @@ -47,7 +43,6 @@ impl PropertyAccess for Gaussian { match i { _ if i + 3 < self.spherical_harmonic.coefficients.len() => { - // TODO: verify this is the correct sh order (packed not planar) self.spherical_harmonic.coefficients[i + 3] = v; }, _ => { }, @@ -65,18 +60,36 @@ pub fn parse_ply(mut reader: &mut dyn BufRead) -> Result, Error> { let mut cloud = Vec::new(); for (_ignore_key, element) in &header.elements { - match element.name.as_ref() { - "vertex" => { cloud = gaussian_parser.read_payload_for_element(&mut reader, &element, &header)?; }, - _ => {}, + if element.name == "vertex" { + cloud = gaussian_parser.read_payload_for_element(&mut reader, element, &header)?; } } for gaussian in &mut cloud { - let mean_scale = (gaussian.scale.x + gaussian.scale.y + gaussian.scale.z) / 3.0; - gaussian.scale = gaussian.scale - .max(Vec3::splat(mean_scale - MAX_SIZE_VARIANCE)) - .min(Vec3::splat(mean_scale + MAX_SIZE_VARIANCE)) - .exp(); + gaussian.position[3] = 1.0; + + let mean_scale = (gaussian.scale_opacity[0] + gaussian.scale_opacity[1] + gaussian.scale_opacity[2]) / 3.0; + for i in 0..3 { + gaussian.scale_opacity[i] = gaussian.scale_opacity[i] + .max(mean_scale - MAX_SIZE_VARIANCE) + .min(mean_scale + MAX_SIZE_VARIANCE) + .exp(); + } + + let sh_src = gaussian.spherical_harmonic.coefficients; + let sh = &mut gaussian.spherical_harmonic.coefficients; + + for (i, sh_src) in sh_src.iter().enumerate().skip(SH_CHANNELS) { + let j = i - SH_CHANNELS; + + let channel = j / (MAX_SH_COEFF_COUNT_PER_CHANNEL - 1); + let coefficient = (j % (MAX_SH_COEFF_COUNT_PER_CHANNEL - 1)) + 1; + + let interleaved_idx = coefficient * SH_CHANNELS + channel; + assert!(interleaved_idx >= SH_CHANNELS); + + sh[interleaved_idx] = *sh_src; + } } Ok(cloud) diff --git 
a/src/render/gaussian.wgsl b/src/render/gaussian.wgsl index 709e3041..ce434a7d 100644 --- a/src/render/gaussian.wgsl +++ b/src/render/gaussian.wgsl @@ -6,9 +6,8 @@ struct GaussianInput { @location(0) rotation: vec4, - @location(1) position: vec3, - @location(2) scale: vec3, - @location(3) opacity: f32, + @location(1) position: vec4, + @location(2) scale_opacity: vec4, sh: array, }; @@ -25,6 +24,22 @@ struct GaussianUniforms { global_scale: f32, }; +struct DrawIndirect { + vertex_count: u32, + instance_count: atomic, + base_vertex: u32, + base_instance: u32, +} +struct SortingGlobal { + status_counters: array, #{RADIX_BASE}>, #{MAX_TILE_COUNT_C}>, + digit_histogram: array, #{RADIX_BASE}>, #{RADIX_DIGIT_PLACES}>, + assignment_counter: atomic, +} +struct Entry { + key: u32, + value: u32, +} + @group(0) @binding(0) var view: View; @group(0) @binding(1) var globals: Globals; @@ -33,16 +48,260 @@ struct GaussianUniforms { @group(2) @binding(0) var points: array; +@group(3) @binding(0) var sorting_pass_index: u32; +@group(3) @binding(1) var sorting: SortingGlobal; +@group(3) @binding(2) var draw_indirect: DrawIndirect; +@group(3) @binding(3) var input_entries: array; +@group(3) @binding(4) var output_entries: array; +@group(3) @binding(5) var sorted_entries: array; + +struct SortingSharedA { + digit_histogram: array, #{RADIX_BASE}>, #{RADIX_DIGIT_PLACES}>, +} +var sorting_shared_a: SortingSharedA; + +// TODO: resolve flickering (maybe more radix passes?) 
+@compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES}) +fn radix_sort_a( + @builtin(local_invocation_id) gl_LocalInvocationID: vec3, + @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, +) { + sorting_shared_a.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x] = 0u; + workgroupBarrier(); + + let thread_index = gl_GlobalInvocationID.x * #{RADIX_DIGIT_PLACES}u + gl_GlobalInvocationID.y; + let start_entry_index = thread_index * #{ENTRIES_PER_INVOCATION_A}u; + let end_entry_index = start_entry_index + #{ENTRIES_PER_INVOCATION_A}u; + for(var entry_index = start_entry_index; entry_index < end_entry_index; entry_index += 1u) { + if(entry_index >= arrayLength(&points)) { + continue; + } + var key: u32 = 0xFFFFFFFFu; // Stream compaction for frustum culling + let transformed_position = (uniforms.global_transform * points[entry_index].position).xyz; + let clip_space_pos = world_to_clip(transformed_position); + if(in_frustum(clip_space_pos.xyz)) { + // key = bitcast(1.0 - clip_space_pos.z); + key = u32((1.0 - clip_space_pos.z) * 0xFFFF.0) << 16u; + key |= u32((clip_space_pos.x * 0.5 + 0.5) * 0xFF.0) << 8u; + key |= u32((clip_space_pos.y * 0.5 + 0.5) * 0xFF.0); + } + output_entries[entry_index].key = key; + output_entries[entry_index].value = entry_index; + for(var shift = 0u; shift < #{RADIX_DIGIT_PLACES}u; shift += 1u) { + let digit = (key >> (shift * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + atomicAdd(&sorting_shared_a.digit_histogram[shift][digit], 1u); + } + } + workgroupBarrier(); + + atomicAdd(&sorting.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x], sorting_shared_a.digit_histogram[gl_LocalInvocationID.y][gl_LocalInvocationID.x]); +} + +@compute @workgroup_size(1) +fn radix_sort_b( + @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, +) { + var sum = 0u; + for(var digit = 0u; digit < #{RADIX_BASE}u; digit += 1u) { + let tmp = sorting.digit_histogram[gl_GlobalInvocationID.y][digit]; + 
sorting.digit_histogram[gl_GlobalInvocationID.y][digit] = sum; + sum += tmp; + } +} + +struct SortingSharedC { + entries: array, #{WORKGROUP_ENTRIES_C}>, + gather_sources: array, #{WORKGROUP_ENTRIES_C}>, + scan: array, #{WORKGROUP_INVOCATIONS_C}>, + total: u32, +} +var sorting_shared_c: SortingSharedC; + +const NUM_BANKS: u32 = 16u; +const LOG_NUM_BANKS: u32 = 4u; +fn conflict_free_offset(n: u32) -> u32 { + return n >> NUM_BANKS + n >> (2u * LOG_NUM_BANKS); +} + +fn exclusive_scan(local_invocation_index: u32, value: u32) -> u32 { + sorting_shared_c.scan[local_invocation_index + conflict_free_offset(local_invocation_index)] = value; + + var offset = 1u; + for (var d = #{WORKGROUP_INVOCATIONS_C}u >> 1u; d > 0u; d >>= 1u) { + workgroupBarrier(); + if(local_invocation_index < d) { + var ai = offset * (2u * local_invocation_index + 1u) - 1u; + var bi = offset * (2u * local_invocation_index + 2u) - 1u; + ai += conflict_free_offset(ai); + bi += conflict_free_offset(bi); + sorting_shared_c.scan[bi] += sorting_shared_c.scan[ai]; + } + + offset <<= 1u; + } + + if (local_invocation_index == 0u) { + var i = #{WORKGROUP_INVOCATIONS_C}u - 1u; + i += conflict_free_offset(i); + sorting_shared_c.total = sorting_shared_c.scan[i]; + sorting_shared_c.scan[i] = 0u; + } + + for (var d = 1u; d < #{WORKGROUP_INVOCATIONS_C}u; d <<= 1u) { + workgroupBarrier(); + offset >>= 1u; + if(local_invocation_index < d) { + var ai = offset * (2u * local_invocation_index + 1u) - 1u; + var bi = offset * (2u * local_invocation_index + 2u) - 1u; + ai += conflict_free_offset(ai); + bi += conflict_free_offset(bi); + let t = sorting_shared_c.scan[ai]; + sorting_shared_c.scan[ai] = sorting_shared_c.scan[bi]; + sorting_shared_c.scan[bi] += t; + } + } + + workgroupBarrier(); + return sorting_shared_c.scan[local_invocation_index + conflict_free_offset(local_invocation_index)]; +} + +@compute @workgroup_size(#{WORKGROUP_INVOCATIONS_C}) +fn radix_sort_c( + @builtin(local_invocation_id) gl_LocalInvocationID: vec3, 
+ @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, +) { + // Draw an assignment number + if(gl_LocalInvocationID.x == 0u) { + sorting_shared_c.entries[0] = atomicAdd(&sorting.assignment_counter, 1u); + } + + // Reset histogram + sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = 0u; + workgroupBarrier(); + + let assignment = sorting_shared_c.entries[0]; + let global_entry_offset = assignment * #{WORKGROUP_ENTRIES_C}u; + // TODO: Specialize end shader + if(gl_LocalInvocationID.x == 0u && assignment * #{WORKGROUP_ENTRIES_C}u + #{WORKGROUP_ENTRIES_C}u >= arrayLength(&points)) { + // Last workgroup resets the assignment number for the next pass + sorting.assignment_counter = 0u; + } + + // Load keys from global memory into registers and rank them + var keys: array; + var ranks: array; + for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { + keys[entry_index] = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][0]; + let digit = (keys[entry_index] >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + // TODO: Implement warp-level multi-split (WLMS) once WebGPU supports subgroup operations + ranks[entry_index] = atomicAdd(&sorting_shared_c.scan[digit + conflict_free_offset(digit)], 1u); + } + workgroupBarrier(); + + // Cumulate histogram + let local_digit_count = sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)]; + let local_digit_offset = exclusive_scan(gl_LocalInvocationID.x, local_digit_count); + sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = local_digit_offset; + + // Chained decoupling lookback + atomicStore(&sorting.status_counters[assignment][gl_LocalInvocationID.x], 0x40000000u | local_digit_count); + var global_digit_count = 0u; + var previous_tile = assignment; + while true { + if(previous_tile == 0u) { 
+ global_digit_count += sorting.digit_histogram[sorting_pass_index][gl_LocalInvocationID.x]; + break; + } + previous_tile -= 1u; + var status_counter = 0u; + while((status_counter & 0xC0000000u) == 0u) { + status_counter = atomicLoad(&sorting.status_counters[previous_tile][gl_LocalInvocationID.x]); + } + global_digit_count += status_counter & 0x3FFFFFFFu; + if((status_counter & 0x80000000u) != 0u) { + break; + } + } + atomicStore(&sorting.status_counters[assignment][gl_LocalInvocationID.x], 0x80000000u | (global_digit_count + local_digit_count)); + if(sorting_pass_index == #{RADIX_DIGIT_PLACES}u - 1u && gl_LocalInvocationID.x == #{WORKGROUP_INVOCATIONS_C}u - 2u && global_entry_offset + #{WORKGROUP_ENTRIES_C}u >= arrayLength(&points)) { + draw_indirect.vertex_count = 4u; + draw_indirect.instance_count = global_digit_count + local_digit_count; + } + + // Scatter keys inside shared memory + for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { + let key = keys[entry_index]; + let digit = (key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + ranks[entry_index] += sorting_shared_c.scan[digit + conflict_free_offset(digit)]; + sorting_shared_c.entries[ranks[entry_index]] = key; + } + workgroupBarrier(); + + // Add global offset + sorting_shared_c.scan[gl_LocalInvocationID.x + conflict_free_offset(gl_LocalInvocationID.x)] = global_digit_count - local_digit_offset; + workgroupBarrier(); + + // Store keys from shared memory into global memory + for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { + let key = sorting_shared_c.entries[#{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x]; + let digit = (key >> (sorting_pass_index * #{RADIX_BITS_PER_DIGIT}u)) & (#{RADIX_BASE}u - 1u); + keys[entry_index] = digit; + output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][0] = 
key; + } + workgroupBarrier(); + + // Load values from global memory and scatter them inside shared memory + for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { + let value = input_entries[global_entry_offset + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][1]; + sorting_shared_c.entries[ranks[entry_index]] = value; + } + workgroupBarrier(); + + // Store values from shared memory into global memory + for(var entry_index = 0u; entry_index < #{ENTRIES_PER_INVOCATION_C}u; entry_index += 1u) { + let value = sorting_shared_c.entries[#{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x]; + let digit = keys[entry_index]; + output_entries[sorting_shared_c.scan[digit + conflict_free_offset(digit)] + #{WORKGROUP_INVOCATIONS_C}u * entry_index + gl_LocalInvocationID.x][1] = value; + } +} + +@compute @workgroup_size(#{TEMPORAL_SORT_WINDOW_SIZE}) +fn temporal_sort_flip( + @builtin(local_invocation_id) gl_LocalInvocationID: vec3, + @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, +) { + // let start_index = gl_GlobalInvocationID.x * #{TEMPORAL_SORT_WINDOW_SIZE}u; + // let end_index = start_index + #{TEMPORAL_SORT_WINDOW_SIZE}u; +} + +@compute @workgroup_size(#{TEMPORAL_SORT_WINDOW_SIZE}) +fn temporal_sort_flop( + @builtin(local_invocation_id) gl_LocalInvocationID: vec3, + @builtin(global_invocation_id) gl_GlobalInvocationID: vec3, +) { + // // TODO: pad sorting buffers to 1.5 window size + // let start_index = gl_GlobalInvocationID.x * #{TEMPORAL_SORT_WINDOW_SIZE}u + #{TEMPORAL_SORT_WINDOW_SIZE}u / 2u; + // let end_index = start_index + #{TEMPORAL_SORT_WINDOW_SIZE}u; + + // // pair sort entries in window size + // for (var i = start_index; i < end_index; i += 2u) { + // let pos_a = points[input_entries[i][0]].position.xyz; + // let depth_a = world_to_clip(pos_a).z; + // } +} + + + // 
https://github.com/cvlab-epfl/gaussian-splatting-web/blob/905b3c0fb8961e42c79ef97e64609e82383ca1c2/src/shaders.ts#L185 // TODO: precompute -fn compute_cov3d(scale: vec3, rot: vec4) -> array { +fn compute_cov3d(scale: vec3, rotation: vec4) -> array { let S = scale * uniforms.global_scale; - let r = rot.x; - let x = rot.y; - let y = rot.z; - let z = rot.w; + let r = rotation.x; + let x = rotation.y; + let y = rotation.z; + let z = rotation.w; let R = mat3x3( 1.0 - 2.0 * (y * y + z * z), @@ -76,25 +335,30 @@ fn compute_cov3d(scale: vec3, rot: vec4) -> array { ); } -fn compute_cov2d(position: vec3, scale: vec3, rot: vec4) -> vec3 { - let cov3d = compute_cov3d(scale, rot); +fn compute_cov2d(position: vec3, scale: vec3, rotation: vec4) -> vec3 { + let cov3d = compute_cov3d(scale, rotation); let Vrk = mat3x3( cov3d[0], cov3d[1], cov3d[2], cov3d[1], cov3d[3], cov3d[4], cov3d[2], cov3d[4], cov3d[5], ); - // TODO: resolve metal vs directx differences var t = view.inverse_view * vec4(position, 1.0); - let focal_x = 500.0; - let focal_y = 500.0; +#ifdef USE_AABB + let focal_x = view.viewport.z / (2.0 * view.projection[0][0]); + let focal_y = view.viewport.w / (2.0 * view.projection[1][1]); +#endif + +#ifdef USE_OBB + let focal_x = view.viewport.z / (2.0 * view.inverse_projection[0][0]); + let focal_y = view.viewport.w / (2.0 * view.inverse_projection[1][1]); +#endif let limx = 1.3 * 0.5 * view.viewport.z / focal_x; let limy = 1.3 * 0.5 * view.viewport.w / focal_y; let txtz = t.x / t.z; let tytz = t.y / t.z; - t.x = min(limx, max(-limx, txtz)) * t.z; t.y = min(limy, max(-limy, tytz)) * t.z; @@ -104,12 +368,23 @@ fn compute_cov2d(position: vec3, scale: vec3, rot: vec4) -> vec3< -(focal_x * t.x) / (t.z * t.z), 0.0, - -focal_y / t.z, - (focal_y * t.y) / (t.z * t.z), + focal_y / t.z, + -(focal_y * t.y) / (t.z * t.z), 0.0, 0.0, 0.0, ); +#ifdef USE_AABB + let W = transpose( + mat3x3( + view.inverse_view.x.xyz, + view.inverse_view.y.xyz, + view.inverse_view.z.xyz, + ) + ); +#endif + 
+#ifdef USE_OBB let W = transpose( mat3x3( view.inverse_view.x.xyz, @@ -117,6 +392,7 @@ fn compute_cov2d(position: vec3, scale: vec3, rot: vec4) -> vec3< view.inverse_view.z.xyz, ) ); +#endif let T = W * J; @@ -129,8 +405,8 @@ fn compute_cov2d(position: vec3, scale: vec3, rot: vec4) -> vec3< fn world_to_clip(world_pos: vec3) -> vec4 { - let homogenous_pos = view.view_proj * vec4(world_pos, 1.0); - return homogenous_pos / homogenous_pos.w; + let homogenous_pos = view.projection * view.inverse_view * vec4(world_pos, 1.0); + return homogenous_pos / (homogenous_pos.w + 0.000000001); } fn in_frustum(clip_space_pos: vec3) -> bool { @@ -147,13 +423,22 @@ fn get_bounding_box_corner( // return vec4(offset, uv); let det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; - let mid = 0.5 * (cov2d.x + cov2d.z); let lambda1 = mid + sqrt(max(0.1, mid * mid - det)); let lambda2 = mid - sqrt(max(0.1, mid * mid - det)); let x_axis_length = sqrt(lambda1); let y_axis_length = sqrt(lambda2); + // let det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; + // let mid = 0.5 * (cov2d.x + cov2d.z); + // var discriminant = max(0.0, mid * mid - det); + + // let lambda1 = mid + sqrt(discriminant); + // let lambda2 = mid - sqrt(discriminant); + // let x_axis_length = sqrt(lambda1); + // let y_axis_length = sqrt(lambda2); + + #ifdef USE_AABB // creates a square AABB (inefficient fragment usage) let radius_px = 3.5 * max(x_axis_length, y_axis_length); @@ -169,16 +454,20 @@ fn get_bounding_box_corner( #endif #ifdef USE_OBB + // TODO: shouldn't require 3.5 stdevs + let bounds = vec2( + x_axis_length, + y_axis_length, + ); + // bounding box is aligned to the eigenvectors with proper width/height // collapse unstable eigenvectors to circle let threshold = 0.1; if (abs(lambda1 - lambda2) < threshold) { + let circle = direction * max(x_axis_length, y_axis_length); return vec4( - vec2( - direction.x * (x_axis_length + y_axis_length) * 0.5, - direction.y * x_axis_length - ) / view.viewport.zw, - direction * x_axis_length 
+ circle / view.viewport.zw, + circle ); } @@ -186,19 +475,19 @@ fn get_bounding_box_corner( cov2d.y, lambda1 - cov2d.x )); - let eigvec2 = vec2(-eigvec1.y, eigvec1.x); + let eigvec2 = vec2( + -eigvec1.y, + eigvec1.x + ); let rotation_matrix = mat2x2( eigvec1.x, eigvec2.x, eigvec1.y, eigvec2.y ); - let scaled_vertex = vec2( - direction.x * x_axis_length, - direction.y * y_axis_length - ); + let scaled_vertex = direction * bounds; return vec4( - rotation_matrix * (scaled_vertex / view.viewport.zw), + (scaled_vertex / view.viewport.z) * rotation_matrix, scaled_vertex ); #endif @@ -211,12 +500,21 @@ fn vs_points( @builtin(vertex_index) vertex_index: u32, ) -> GaussianOutput { var output: GaussianOutput; - let point = points[instance_index]; - let transformed_position = (uniforms.global_transform * vec4(point.position, 1.0)).xyz; + let splat_index = sorted_entries[instance_index][1]; + + let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu; + if (discard_quad) { + output.color = vec4(0.0, 0.0, 0.0, 0.0); + return output; + } + + let point = points[splat_index]; + let transformed_position = (uniforms.global_transform * point.position).xyz; let projected_position = world_to_clip(transformed_position); if (!in_frustum(projected_position.xyz)) { output.color = vec4(0.0, 0.0, 0.0, 0.0); + output.position = vec4(0.0, 0.0, 0.0, 0.0); return output; } @@ -233,10 +531,10 @@ fn vs_points( let ray_direction = normalize(transformed_position - view.world_position); output.color = vec4( spherical_harmonics_lookup(ray_direction, point.sh), - point.opacity + point.scale_opacity.a ); - let cov2d = compute_cov2d(transformed_position, point.scale, point.rotation); + let cov2d = compute_cov2d(transformed_position, point.scale_opacity.rgb, point.rotation); // TODO: remove conic when OBB is used let det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; @@ -266,6 +564,8 @@ fn vs_points( @fragment fn fs_main(input: GaussianOutput) -> @location(0) vec4 { // TODO: draw gaussian without 
conic (OBB) + +#ifdef USE_AABB let d = -input.major_minor; let conic = input.conic; let power = -0.5 * (conic.x * d.x * d.x + conic.z * d.y * d.y) + conic.y * d.x * d.y; @@ -273,6 +573,20 @@ fn fs_main(input: GaussianOutput) -> @location(0) vec4 { if (power > 0.0) { discard; } +#endif + +#ifdef USE_OBB + let norm_uv = input.uv * 2.0 - 1.0; + let sigma = 1.0 / 3.5; + let sigma_squared = sigma * sigma; + let distance_squared = dot(norm_uv, norm_uv); + + let power = -distance_squared / (2.0 * sigma_squared); + + if (distance_squared > 3.5 * 3.5) { + discard; + } +#endif #ifdef VISUALIZE_BOUNDING_BOX let uv = input.uv; @@ -285,9 +599,10 @@ fn fs_main(input: GaussianOutput) -> @location(0) vec4 { } #endif - let alpha = min(0.99, input.color.a * exp(power)); + let alpha = exp(power); + let final_alpha = alpha * input.color.a; return vec4( - input.color.rgb * alpha, - alpha, + input.color.rgb * final_alpha, + final_alpha, ); } diff --git a/src/render/mod.rs b/src/render/mod.rs index 4cdf8e32..4febd734 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -7,7 +7,10 @@ use bevy::{ HandleUntyped, LoadState, }, - core_pipeline::core_3d::Transparent3d, + core_pipeline::core_3d::{ + Transparent3d, + CORE_3D, + }, ecs::{ system::{ lifetimeless::*, @@ -27,7 +30,6 @@ use bevy::{ GlobalsUniform, GlobalsBuffer, }, - mesh::GpuBufferInfo, render_asset::{ PrepareAssetError, RenderAsset, @@ -45,7 +47,10 @@ use bevy::{ TrackedRenderPass, }, render_resource::*, - renderer::RenderDevice, + renderer::{ + RenderDevice, + RenderContext, + }, Render, RenderApp, RenderSet, @@ -55,6 +60,10 @@ use bevy::{ ViewUniforms, ViewUniformOffset, }, + render_graph::{ + self, + RenderGraphApp, + }, }, }; @@ -69,6 +78,10 @@ use crate::gaussian::{ const GAUSSIAN_SHADER_HANDLE: HandleUntyped = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 68294581); const SPHERICAL_HARMONICS_SHADER_HANDLE: HandleUntyped = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 834667312); +pub mod node { + pub const 
RADIX_SORT: &str = "radix_sort"; +} + #[derive(Default)] pub struct RenderPipelinePlugin; @@ -93,6 +106,17 @@ impl Plugin for RenderPipelinePlugin { app.add_plugins(UniformComponentPlugin::::default()); if let Ok(render_app) = app.get_sub_app_mut(RenderApp) { + render_app + .add_render_graph_node::( + CORE_3D, + node::RADIX_SORT, + ) + .add_render_graph_edge( + CORE_3D, + node::RADIX_SORT, + bevy::core_pipeline::core_3d::graph::node::PREPASS, + ); + render_app .add_render_command::() .init_resource::() @@ -127,9 +151,14 @@ pub struct GpuGaussianSplattingBundle { #[derive(Debug, Clone)] pub struct GpuGaussianCloud { - pub buffer: Buffer, + pub gaussian_buffer: Buffer, pub count: u32, - pub buffer_info: GpuBufferInfo, + + pub draw_indirect_buffer: Buffer, + pub sorting_global_buffer: Buffer, + pub sorting_pass_buffers: [Buffer; 4], + pub entry_buffer_a: Buffer, + pub entry_buffer_b: Buffer, } impl RenderAsset for GaussianCloud { type ExtractedAsset = GaussianCloud; @@ -144,16 +173,63 @@ impl RenderAsset for GaussianCloud { gaussian_cloud: Self::ExtractedAsset, render_device: &mut SystemParamItem, ) -> Result> { - let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { + let gaussian_buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { label: Some("gaussian cloud buffer"), contents: bytemuck::cast_slice(gaussian_cloud.0.as_slice()), usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE, }); + let count = gaussian_cloud.0.len() as u32; + + // TODO: derive sorting_buffer_size from cloud count (with possible rounding to next power of 2) + let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor { + label: Some("sorting global buffer"), + size: ShaderDefines::default().sorting_buffer_size as u64, + usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor { + 
label: Some("draw indirect buffer"), + size: std::mem::size_of::() as u64, + usage: BufferUsages::INDIRECT | BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let sorting_pass_buffers = (0..4) + .map(|idx| { + render_device.create_buffer_with_data(&BufferInitDescriptor { + label: format!("sorting pass buffer {}", idx).as_str().into(), + contents: &[idx as u8, 0, 0, 0], + usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, + }) + }) + .collect::>() + .try_into() + .unwrap(); + + let entry_buffer_a = render_device.create_buffer(&BufferDescriptor { + label: Some("entry buffer a"), + size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64, + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let entry_buffer_b = render_device.create_buffer(&BufferDescriptor { + label: Some("entry buffer b"), + size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64, + usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + Ok(GpuGaussianCloud { - buffer, - count: gaussian_cloud.0.len() as u32, - buffer_info: GpuBufferInfo::NonIndexed, + gaussian_buffer, + count, + draw_indirect_buffer, + sorting_global_buffer, + sorting_pass_buffers, + entry_buffer_a, + entry_buffer_b, }) } } @@ -171,7 +247,10 @@ fn queue_gaussians( &Handle, &GaussianCloudSettings, )>, - mut views: Query<(&ExtractedView, &mut RenderPhase)>, + mut views: Query<( + &ExtractedView, + &mut RenderPhase, + )>, ) { let draw_custom = transparent_3d_draw_functions.read().id::(); @@ -197,12 +276,18 @@ fn queue_gaussians( } + + #[derive(Resource)] pub struct GaussianCloudPipeline { shader: Handle, pub gaussian_cloud_layout: BindGroupLayout, pub gaussian_uniform_layout: BindGroupLayout, pub view_layout: BindGroupLayout, + pub radix_sort_layout: BindGroupLayout, + pub radix_sort_pipelines: [CachedComputePipelineId; 3], + pub temporal_sort_pipelines: 
[CachedComputePipelineId; 2], + pub sorted_layout: BindGroupLayout, } impl FromWorld for GaussianCloudPipeline { @@ -212,7 +297,7 @@ impl FromWorld for GaussianCloudPipeline { let view_layout_entries = vec![ BindGroupLayoutEntry { binding: 0, - visibility: ShaderStages::VERTEX_FRAGMENT, + visibility: ShaderStages::all(), ty: BindingType::Buffer { ty: BufferBindingType::Uniform, has_dynamic_offset: true, @@ -222,7 +307,7 @@ impl FromWorld for GaussianCloudPipeline { }, BindGroupLayoutEntry { binding: 1, - visibility: ShaderStages::VERTEX_FRAGMENT, + visibility: ShaderStages::all(), ty: BindingType::Buffer { ty: BufferBindingType::Uniform, has_dynamic_offset: false, @@ -239,10 +324,10 @@ impl FromWorld for GaussianCloudPipeline { let gaussian_uniform_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { label: Some("gaussian_uniform_layout"), - entries: &vec![ + entries: &[ BindGroupLayoutEntry { binding: 0, - visibility: ShaderStages::VERTEX_FRAGMENT, + visibility: ShaderStages::all(), ty: BindingType::Buffer { ty: BufferBindingType::Uniform, has_dynamic_offset: true, @@ -255,10 +340,10 @@ impl FromWorld for GaussianCloudPipeline { let gaussian_cloud_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { label: Some("gaussian_cloud_layout"), - entries: &vec![ + entries: &[ BindGroupLayoutEntry { binding: 0, - visibility: ShaderStages::VERTEX_FRAGMENT, + visibility: ShaderStages::all(), ty: BindingType::Buffer { ty: BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, @@ -269,15 +354,246 @@ impl FromWorld for GaussianCloudPipeline { ], }); + let sorting_buffer_entry = BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(ShaderDefines::default().sorting_buffer_size as u64), + }, + count: None, + }; + + let draw_indirect_buffer_entry = 
BindGroupLayoutEntry { + binding: 2, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(std::mem::size_of::() as u64), + }, + count: None, + }; + + let radix_sort_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("radix_sort_layout"), + entries: &[ + BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(std::mem::size_of::() as u64), + }, + count: None, + }, + sorting_buffer_entry, + draw_indirect_buffer_entry, + BindGroupLayoutEntry { + binding: 3, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64), + }, + count: None, + }, + BindGroupLayoutEntry { + binding: 4, + visibility: ShaderStages::COMPUTE, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64), + }, + count: None, + }, + ], + }); + + let sorted_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("sorted_layout"), + entries: &vec![ + BindGroupLayoutEntry { + binding: 5, + visibility: ShaderStages::VERTEX_FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64), + }, + count: None, + }, + ], + }); + + let compute_layout = vec![ + view_layout.clone(), + gaussian_uniform_layout.clone(), + gaussian_cloud_layout.clone(), + radix_sort_layout.clone(), + ]; + let shader = GAUSSIAN_SHADER_HANDLE.typed(); + let 
shader_defs = shader_defs(false, false); + + let pipeline_cache = render_world.resource::(); + let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("radix_sort_a".into()), + layout: compute_layout.clone(), + push_constant_ranges: vec![], + shader: shader.clone(), + shader_defs: shader_defs.clone(), + entry_point: "radix_sort_a".into(), + }); + + let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("radix_sort_b".into()), + layout: compute_layout.clone(), + push_constant_ranges: vec![], + shader: shader.clone(), + shader_defs: shader_defs.clone(), + entry_point: "radix_sort_b".into(), + }); + + let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("radix_sort_c".into()), + layout: compute_layout.clone(), + push_constant_ranges: vec![], + shader: shader.clone(), + shader_defs: shader_defs.clone(), + entry_point: "radix_sort_c".into(), + }); + + + let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("temporal_sort_flip".into()), + layout: compute_layout.clone(), + push_constant_ranges: vec![], + shader: shader.clone(), + shader_defs: shader_defs.clone(), + entry_point: "temporal_sort_flip".into(), + }); + + let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("temporal_sort_flop".into()), + layout: compute_layout.clone(), + push_constant_ranges: vec![], + shader: shader.clone(), + shader_defs: shader_defs.clone(), + entry_point: "temporal_sort_flop".into(), + }); + GaussianCloudPipeline { gaussian_cloud_layout, gaussian_uniform_layout, view_layout, - shader: GAUSSIAN_SHADER_HANDLE.typed(), + shader: shader.clone(), + radix_sort_layout, + radix_sort_pipelines: [ + radix_sort_a, + radix_sort_b, + radix_sort_c, + ], + temporal_sort_pipelines: [ + temporal_sort_flip, + temporal_sort_flop, + ], + sorted_layout, + } + } +} + +// TODO: 
allow setting shader defines via API +struct ShaderDefines { + radix_bits_per_digit: u32, + radix_digit_places: u32, + radix_base: u32, + entries_per_invocation_a: u32, + entries_per_invocation_c: u32, + workgroup_invocations_a: u32, + workgroup_invocations_c: u32, + workgroup_entries_a: u32, + workgroup_entries_c: u32, + max_tile_count_c: u32, + sorting_buffer_size: usize, + + temporal_sort_window_size: u32, +} + +impl Default for ShaderDefines { + fn default() -> Self { + let radix_bits_per_digit = 8; + let radix_digit_places = 32 / radix_bits_per_digit; + let radix_base = 1 << radix_bits_per_digit; + let entries_per_invocation_a = 4; + let entries_per_invocation_c = 4; + let workgroup_invocations_a = radix_base * radix_digit_places; + let workgroup_invocations_c = radix_base; + let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a; + let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c; + let max_tile_count_c = (10000000 + workgroup_entries_c - 1) / workgroup_entries_c; + let sorting_buffer_size = ( + radix_base as usize * + (radix_digit_places as usize + max_tile_count_c as usize) * + std::mem::size_of::() + ) + std::mem::size_of::() * 5; + + Self { + radix_bits_per_digit, + radix_digit_places, + radix_base, + entries_per_invocation_a, + entries_per_invocation_c, + workgroup_invocations_a, + workgroup_invocations_c, + workgroup_entries_a, + workgroup_entries_c, + max_tile_count_c, + sorting_buffer_size, + + temporal_sort_window_size: 16, } } } +fn shader_defs( + aabb: bool, + visualize_bounding_box: bool, +) -> Vec { + let defines = ShaderDefines::default(); + let mut shader_defs = vec![ + ShaderDefVal::UInt("MAX_SH_COEFF_COUNT".into(), MAX_SH_COEFF_COUNT as u32), + ShaderDefVal::UInt("RADIX_BASE".into(), defines.radix_base), + ShaderDefVal::UInt("RADIX_BITS_PER_DIGIT".into(), defines.radix_bits_per_digit), + ShaderDefVal::UInt("RADIX_DIGIT_PLACES".into(), defines.radix_digit_places), + 
ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_A".into(), defines.entries_per_invocation_a), + ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_C".into(), defines.entries_per_invocation_c), + ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a), + ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c), + ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c), + ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c), + + ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size), + ]; + + if aabb { + shader_defs.push("USE_AABB".into()); + } + + if !aabb { + shader_defs.push("USE_OBB".into()); + } + + if visualize_bounding_box { + shader_defs.push("VISUALIZE_BOUNDING_BOX".into()); + } + + shader_defs +} + #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub struct GaussianCloudPipelineKey { pub aabb: bool, @@ -288,28 +604,18 @@ impl SpecializedRenderPipeline for GaussianCloudPipeline { type Key = GaussianCloudPipelineKey; fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor { - let mut shader_defs = vec![ - ShaderDefVal::UInt("MAX_SH_COEFF_COUNT".into(), MAX_SH_COEFF_COUNT as u32), - ]; - - if key.aabb { - shader_defs.push("USE_AABB".into()); - } - - if !key.aabb { - shader_defs.push("USE_OBB".into()); - } - - if key.visualize_bounding_box { - shader_defs.push("VISUALIZE_BOUNDING_BOX".into()); - } + let shader_defs = shader_defs( + key.aabb, + key.visualize_bounding_box, + ); RenderPipelineDescriptor { - label: Some("gaussian cloud pipeline".into()), + label: Some("gaussian cloud render pipeline".into()), layout: vec![ self.view_layout.clone(), self.gaussian_uniform_layout.clone(), self.gaussian_cloud_layout.clone(), + self.sorted_layout.clone(), ], vertex: VertexState { shader: self.shader.clone(), @@ -428,7 +734,9 @@ pub struct GaussianUniformBindGroups { #[derive(Component)] pub struct GaussianCloudBindGroup { - pub bind_group: 
BindGroup, + pub cloud_bind_group: BindGroup, + pub radix_sort_bind_groups: [BindGroup; 4], + pub sorted_bind_group: BindGroup, } pub fn queue_gaussian_bind_group( @@ -476,21 +784,101 @@ pub fn queue_gaussian_bind_group( let cloud = gaussian_cloud_res.get(cloud_handle).unwrap(); + let sorting_global_entry = BindGroupEntry { + binding: 1, + resource: BindingResource::Buffer(BufferBinding { + buffer: &cloud.sorting_global_buffer, + offset: 0, + size: BufferSize::new(cloud.sorting_global_buffer.size()), + }), + }; + + let draw_indirect_entry = BindGroupEntry { + binding: 2, + resource: BindingResource::Buffer(BufferBinding { + buffer: &cloud.draw_indirect_buffer, + offset: 0, + size: BufferSize::new(cloud.draw_indirect_buffer.size()), + }), + }; + + let radix_sort_bind_groups: [BindGroup; 4] = (0..4) + .map(|idx| { + render_device.create_bind_group(&BindGroupDescriptor { + label: format!("radix_sort_bind_group {}", idx).as_str().into(), + layout: &gaussian_cloud_pipeline.radix_sort_layout, + entries: &[ + BindGroupEntry { + binding: 0, + resource: BindingResource::Buffer(BufferBinding { + buffer: &cloud.sorting_pass_buffers[idx], + offset: 0, + size: BufferSize::new(std::mem::size_of::() as u64), + }), + }, + sorting_global_entry.clone(), + draw_indirect_entry.clone(), + BindGroupEntry { + binding: 3, + resource: BindingResource::Buffer(BufferBinding { + buffer: if idx % 2 == 0 { + &cloud.entry_buffer_a + } else { + &cloud.entry_buffer_b + }, + offset: 0, + size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64), + }), + }, + BindGroupEntry { + binding: 4, + resource: BindingResource::Buffer(BufferBinding { + buffer: if idx % 2 == 0 { + &cloud.entry_buffer_b + } else { + &cloud.entry_buffer_a + }, + offset: 0, + size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64), + }), + }, + ], + }) + }) + .collect::>() + .try_into() + .unwrap(); + commands.entity(entity).insert(GaussianCloudBindGroup { - 
bind_group: render_device.create_bind_group(&BindGroupDescriptor { + cloud_bind_group: render_device.create_bind_group(&BindGroupDescriptor { entries: &[ BindGroupEntry { binding: 0, resource: BindingResource::Buffer(BufferBinding { - buffer: &cloud.buffer, + buffer: &cloud.gaussian_buffer, offset: 0, - size: BufferSize::new(cloud.buffer.size()), + size: BufferSize::new(cloud.gaussian_buffer.size()), }), }, ], layout: &gaussian_cloud_pipeline.gaussian_cloud_layout, label: Some("gaussian_cloud_bind_group"), }), + radix_sort_bind_groups, + sorted_bind_group: render_device.create_bind_group(&BindGroupDescriptor { + entries: &[ + BindGroupEntry { + binding: 5, + resource: BindingResource::Buffer(BufferBinding { + buffer: &cloud.entry_buffer_a, + offset: 0, + size: BufferSize::new((cloud.count as usize * std::mem::size_of::<(u32, u32)>()) as u64), + }), + }, + ], + layout: &gaussian_cloud_pipeline.sorted_layout, + label: Some("render_sorted_bind_group"), + }), }); } } @@ -622,7 +1010,13 @@ impl RenderCommand

for DrawGaussianInstanced { fn render<'w>( _item: &P, _view: (), - (handle, bind_group): (&'w Handle, &'w GaussianCloudBindGroup), + ( + handle, + bind_groups, + ): ( + &'w Handle, + &'w GaussianCloudBindGroup, + ), gaussian_clouds: SystemParamItem<'w, '_, Self::Param>, pass: &mut TrackedRenderPass<'w>, ) -> RenderCommandResult { @@ -631,24 +1025,194 @@ impl RenderCommand

for DrawGaussianInstanced { None => return RenderCommandResult::Failure, }; - pass.set_bind_group(2, &bind_group.bind_group, &[]); + pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]); + pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]); + + pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0); + + RenderCommandResult::Success + } +} + - match &gpu_gaussian_cloud.buffer_info { - GpuBufferInfo::Indexed { - buffer, - index_format, - count, - } => { - pass.set_index_buffer(buffer.slice(..), 0, *index_format); - pass.draw_indexed(0..*count, 0, 0..gpu_gaussian_cloud.count as u32); + + + +struct RadixSortNode { + gaussian_clouds: QueryState<( + &'static Handle, + &'static GaussianCloudBindGroup + )>, + initialized: bool, + pipeline_idx: Option, + view_bind_group: QueryState<( + &'static GaussianViewBindGroup, + &'static ViewUniformOffset, + )>, +} + +impl FromWorld for RadixSortNode { + fn from_world(world: &mut World) -> Self { + Self { + gaussian_clouds: world.query(), + initialized: false, + pipeline_idx: None, + view_bind_group: world.query(), + } + } +} + +impl render_graph::Node for RadixSortNode { + fn update(&mut self, world: &mut World) { + let pipeline = world.resource::(); + let pipeline_cache = world.resource::(); + + if !self.initialized { + let mut pipelines_loaded = true; + for sort_pipeline in pipeline.radix_sort_pipelines.iter() { + if let CachedPipelineState::Ok(_) = + pipeline_cache.get_compute_pipeline_state(*sort_pipeline) + { + continue; + } + + pipelines_loaded = false; } - GpuBufferInfo::NonIndexed => { - pass.draw(0..4, 0..gpu_gaussian_cloud.count as u32); + + self.initialized = pipelines_loaded; + + if !self.initialized { + return; } - // TODO: add support for indirect draw and match over sort methods } - RenderCommandResult::Success + + if self.pipeline_idx.is_none() { + self.pipeline_idx = Some(0); + } else { + self.pipeline_idx = Some((self.pipeline_idx.unwrap() + 1) % pipeline.radix_sort_pipelines.len() as u32); 
+ } + + self.gaussian_clouds.update_archetypes(world); + self.view_bind_group.update_archetypes(world); } + fn run( + &self, + _graph: &mut render_graph::RenderGraphContext, + render_context: &mut RenderContext, + world: &World, + ) -> Result<(), render_graph::NodeRunError> { + if !self.initialized || self.pipeline_idx.is_none() { + return Ok(()); + } + + let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort + + let pipeline_cache = world.resource::(); + let pipeline = world.resource::(); + let gaussian_uniforms = world.resource::(); + let command_encoder = render_context.command_encoder(); + + for ( + view_bind_group, + view_uniform_offset, + ) in self.view_bind_group.iter_manual(world) { + for ( + cloud_handle, + cloud_bind_group + ) in self.gaussian_clouds.iter_manual(world) { + let cloud = world.get_resource::>().unwrap().get(cloud_handle).unwrap(); + + let radix_digit_places = ShaderDefines::default().radix_digit_places; + + command_encoder.clear_buffer( + &cloud.sorting_global_buffer, + 0, + None, + ); + + { + let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + + // TODO: view/global + pass.set_bind_group( + 0, + &view_bind_group.value, + &[view_uniform_offset.offset], + ); + pass.set_bind_group( + 1, + gaussian_uniforms.base_bind_group.as_ref().unwrap(), + &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex + ); + pass.set_bind_group( + 2, + &cloud_bind_group.cloud_bind_group, + &[] + ); + pass.set_bind_group( + 3, + &cloud_bind_group.radix_sort_bind_groups[1], + &[], + ); + + let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap(); + pass.set_pipeline(radix_sort_a); + + let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a; + pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1); + + + let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap(); + 
pass.set_pipeline(radix_sort_b); + + pass.dispatch_workgroups(1, radix_digit_places, 1); + } + + for pass_idx in 0..radix_digit_places { + if pass_idx > 0 { + let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::() as u32; + command_encoder.clear_buffer( + &cloud.sorting_global_buffer, + 0, + std::num::NonZeroU64::new(size as u64).unwrap().into() + ); + } + + let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + + let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap(); + pass.set_pipeline(&radix_sort_c); + + pass.set_bind_group( + 0, + &view_bind_group.value, + &[view_uniform_offset.offset], + ); + pass.set_bind_group( + 1, + gaussian_uniforms.base_bind_group.as_ref().unwrap(), + &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex + ); + pass.set_bind_group( + 2, + &cloud_bind_group.cloud_bind_group, + &[] + ); + pass.set_bind_group( + 3, + &cloud_bind_group.radix_sort_bind_groups[pass_idx as usize], + &[], + ); + + let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c; + pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1); + } + } + } + + + Ok(()) + } } diff --git a/src/render/spherical_harmonics.wgsl b/src/render/spherical_harmonics.wgsl index 3daafa37..e86cfae1 100644 --- a/src/render/spherical_harmonics.wgsl +++ b/src/render/spherical_harmonics.wgsl @@ -29,9 +29,9 @@ fn spherical_harmonics_lookup( color += shc[ 0] * vec3(sh[0], sh[1], sh[2]); - color += shc[ 1] * vec3(sh[3], sh[4], sh[5]) * ray_direction.y; - color += shc[ 2] * vec3(sh[6], sh[7], sh[8]) * ray_direction.z; - color += shc[ 3] * vec3(sh[9], sh[10], sh[11]) * ray_direction.x; + color += shc[ 1] * vec3(sh[ 3], sh[ 4], sh[ 5]) * ray_direction.y; + color += shc[ 2] * vec3(sh[ 6], sh[ 7], sh[ 8]) * ray_direction.z; + color += shc[ 3] * vec3(sh[ 9], sh[10], sh[11]) * ray_direction.x; color 
+= shc[ 4] * vec3(sh[12], sh[13], sh[14]) * ray_direction.x * ray_direction.y; color += shc[ 5] * vec3(sh[15], sh[16], sh[17]) * ray_direction.y * ray_direction.z; diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 00000000..f84b13f8 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,9 @@ +# bevy_gaussian_splatting tools + +## ply to gcloud converter + +convert ply files into bevy_gaussian_splatting gcloud file format (more efficient) + +```bash +cargo run --bin ply_to_gcloud -- assets/scenes/icecream.ply +``` diff --git a/tools/ply_to_gcloud.rs b/tools/ply_to_gcloud.rs index ae167c8e..be65bfe1 100644 --- a/tools/ply_to_gcloud.rs +++ b/tools/ply_to_gcloud.rs @@ -32,6 +32,6 @@ fn main() { // write gloud.gz let gz_file = std::fs::File::create(&gcloud_filename).expect("failed to create file"); let mut gz_writer = std::io::BufWriter::new(gz_file); - let mut gz_encoder = GzEncoder::new(&mut gz_writer, Compression::default()); // TODO: consider switching to fast (or support multiple options), default is a bit slow + let mut gz_encoder = GzEncoder::new(&mut gz_writer, Compression::default()); serialize_into(&mut gz_encoder, &cloud).expect("failed to encode cloud"); } diff --git a/src/main.rs b/viewer/viewer.rs similarity index 87% rename from src/main.rs rename to viewer/viewer.rs index 3d6901fe..4e7b4fb9 100644 --- a/src/main.rs +++ b/viewer/viewer.rs @@ -1,11 +1,13 @@ use bevy::{ prelude::*, app::AppExit, + asset::ChangeWatcher, core::Name, diagnostic::{ DiagnosticsStore, FrameTimeDiagnosticsPlugin, }, + utils::Duration, }; use bevy_inspector_egui::quick::WorldInspectorPlugin; use bevy_panorbit_camera::{ @@ -49,6 +51,8 @@ fn setup_gaussian_cloud( mut commands: Commands, asset_server: Res, mut gaussian_assets: ResMut>, + // mut meshes: ResMut>, + // mut materials: ResMut>, ) { let cloud: Handle; let settings = GaussianCloudSettings { @@ -73,12 +77,22 @@ fn setup_gaussian_cloud( Name::new("gaussian_cloud"), )); + // commands.spawn(PbrBundle { + // mesh: 
meshes.add(Mesh::from(shape::Cube { size: 1.0 })), + // material: materials.add(Color::rgb(0.8, 0.3, 0.6).into()), + // transform: Transform::from_xyz(0.0, 0.0, 0.0), + // ..default() + // }); + commands.spawn(( Camera3dBundle { transform: Transform::from_translation(Vec3::new(0.0, 1.5, 5.0)), ..default() }, - PanOrbitCamera::default(), + PanOrbitCamera{ + allow_upside_down: true, + ..default() + }, )); } @@ -91,6 +105,10 @@ fn example_app() { app.insert_resource(ClearColor(Color::rgb_u8(0, 0, 0))); app.add_plugins( DefaultPlugins + .set(AssetPlugin { + watch_for_changes: ChangeWatcher::with_delay(Duration::from_millis(200)), + ..Default::default() + }) .set(ImagePlugin::default_nearest()) .set(WindowPlugin { primary_window: Some(Window {