diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e58f769..93c642e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, macos-14] + os: [windows-latest, macos-latest] rust-toolchain: - nightly diff --git a/.github/workflows/clippy.yml b/.github/workflows/clippy.yml index 91746aa..fe9ac3e 100644 --- a/.github/workflows/clippy.yml +++ b/.github/workflows/clippy.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, macos-14] + os: [windows-latest, macos-latest] rust-toolchain: - nightly diff --git a/.github/workflows/deploy-pages.yml b/.github/workflows/deploy-pages.yml new file mode 100644 index 0000000..1b7a942 --- /dev/null +++ b/.github/workflows/deploy-pages.yml @@ -0,0 +1,57 @@ +name: deploy github pages + +on: + push: + branches: + - main + workflow_dispatch: + +permissions: + contents: write + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + + +jobs: + deploy: + runs-on: macos-latest + + steps: + - name: checkout repository + uses: actions/checkout@v3 + + - name: setup nightly rust toolchain with caching + uses: brndnmtthws/rust-action@v1 + with: + toolchain: nightly + components: rustfmt, clippy + enable-sccache: "false" + + - name: install wasm32-unknown-unknown + run: rustup target add wasm32-unknown-unknown + + - name: install wasm-bindgen-cli + run: cargo install wasm-bindgen-cli + + # - name: install wasm-opt + # run: cargo install wasm-opt --locked + + - name: build wasm artifacts + run: cargo build -p bevy_burn_dino --target wasm32-unknown-unknown --release --no-default-features --features "web" + + # - name: optimize wasm artifacts + # run: wasm-opt -O -ol 100 -s 100 -o ./target/wasm32-unknown-unknown/release/bevy_burn_dino_opt.wasm ./target/wasm32-unknown-unknown/release/bevy_burn_dino.wasm + + - name: generate bindings with wasm-bindgen + run: wasm-bindgen --out-dir ./crates/bevy_burn_dino/www/out/ --target web ./target/wasm32-unknown-unknown/release/bevy_burn_dino.wasm + + - name: copy assets + run: mkdir -p ./crates/bevy_burn_dino/www/assets && cp -r ./crates/bevy_burn_dino/assets/* ./crates/bevy_burn_dino/www/assets/ + + - name: deploy to github pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: ./www + branch: www diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0c519b7..7f47cbf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, macos-14] + os: [windows-latest, macos-latest] rust-toolchain: - nightly diff --git a/.gitignore b/.gitignore index b6c2102..a9168da 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ Cargo.lock screenshots/ headless_output/ www/assets/ +crates/bevy_burn_dino/www/assets/ assets/images/ assets/models/ diff --git a/Cargo.toml b/Cargo.toml index 5d3bf5f..2ca594e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ resolver = "2" [features] default = ["import"] + +benchmark = ["burn/autotune", "burn/fusion"] import = ["bevy_args", "burn-import", "clap", "serde"] @@ -50,19 +52,41 @@ serde = { version = "1.0", optional = true } [dependencies.burn] version = "0.15" +default-features = false features = [ - "autodiff", - "autotune", - "dataset", - "fusion", - "metrics", - "ndarray", - "network", - "train", - "tui", + # "autotune", + # "fusion", + "std", "wgpu", ] +[dependencies.burn-wgpu] +version = "0.15" +default-features = false +features = [ + "fusion", + "std", + # "template", +] + +[dependencies.cubecl] +version = "0.3" +default-features = false +features = [ + "linalg", + "std", + # "template", +] + +[dependencies.cubecl-runtime] +version = "0.3" +default-features = false +features = [ + "std", + "channel-mpsc", + # "template", +] + [target.'cfg(target_arch = "wasm32")'.dependencies] console_error_panic_hook = "0.1" @@ -98,15 +122,19 @@ opt-level = 3 opt-level = 1 [profile.release] +strip = "symbols" lto = "thin" codegen-units = 1 opt-level = 3 +panic = "abort" [profile.wasm-release] +strip = "symbols" inherits = "release" opt-level = "z" lto = "fat" codegen-units = 1 +panic = "abort" [lib] diff --git a/crates/bevy_burn_dino/Cargo.toml b/crates/bevy_burn_dino/Cargo.toml index 7142784..1ec948d 100644 --- a/crates/bevy_burn_dino/Cargo.toml +++ b/crates/bevy_burn_dino/Cargo.toml @@ -18,7 +18,7 @@ categories = [ [features] default = ["native"] -native = ["nokhwa"] +native = ["burn/autotune", "futures", "nokhwa"] web = [] editor = [] @@ -29,8 +29,9 @@ perftest = [] [dependencies] # bevy_args = "1.6" bevy_args = { git = "https://github.com/mosure/bevy_args.git", branch = "burn" } -burn_dino = { path = "../../" } +burn_dino = { path = "../../", default-features = false } clap = { version = "4.5", features = ["derive"] } +futures = { version = "0.3", optional = true } futures-intrusive = "0.5" image = { version = "0.25.2", default-features = false, features = ["png"] } nokhwa = { version = "0.10", features = ["input-native", "output-threaded"], optional = true } @@ -51,15 +52,21 @@ features = [ "bevy_text", "bevy_ui", "bevy_winit", + "custom_cursor", "default_font", "png", + # "webgpu", + "x11", ] [dependencies.burn] version = "0.15" +default-features = false features = [ - "autotune", - "fusion", + # "autotune", + # "fusion", + "std", + # "template", "wgpu", ] @@ -67,6 +74,7 @@ features = [ [target.'cfg(target_arch = "wasm32")'.dependencies] console_error_panic_hook = "0.1" wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" [dependencies.web-sys] @@ -82,5 +90,5 @@ features = [ [[bin]] -name = "ui" +name = "bevy_burn_dino" path = "src/main.rs" diff --git a/crates/bevy_burn_dino/src/main.rs b/crates/bevy_burn_dino/src/main.rs index cbad70e..197a637 100644 --- a/crates/bevy_burn_dino/src/main.rs +++ b/crates/bevy_burn_dino/src/main.rs @@ -7,6 +7,7 @@ use bevy::{ DiagnosticsStore, FrameTimeDiagnosticsPlugin, }, + ecs::{system::SystemState, world::CommandQueue}, render::{ render_asset::RenderAssetUsages, render_resource::{ @@ -21,6 +22,7 @@ use bevy::{ }, RenderPlugin, }, + tasks::{block_on, futures_lite::future, AsyncComputeTaskPool, Task}, }; use bevy_args::{ parse_args, @@ -30,7 +32,7 @@ use bevy_args::{ }; use burn::{ prelude::*, - backend::wgpu::Wgpu, + backend::wgpu::{init_async, AutoGraphicsApi, Wgpu}, record::{FullPrecisionSettings, NamedMpkBytesRecorder, Recorder}, }; use image::{ @@ -73,7 +75,7 @@ pub struct DinoImportConfig { #[arg(long, default_value = "true")] pub press_esc_to_close: bool, - #[arg(long, default_value = "true")] + #[arg(long, default_value = "false")] pub show_fps: bool, #[arg(long, default_value = "518")] @@ -88,7 +90,7 @@ impl Default for DinoImportConfig { Self { pca_only: true, press_esc_to_close: true, - show_fps: true, + show_fps: false, // TODO: display inference fps (UI fps is decoupled via async compute pool) inference_height: 518, inference_width: 518, } @@ -163,7 +165,7 @@ fn preprocess_image( } -fn to_image( +async fn to_image( image: Tensor, upsample_height: u32, upsample_width: u32, @@ -171,7 +173,7 @@ fn to_image( let height = image.shape().dims[1]; let width = image.shape().dims[2]; - let image = image.to_data().to_vec::().unwrap(); + let image = image.to_data_async().await.to_vec::().unwrap(); let image = ImageBuffer::, Vec>::from_raw( width as u32, height as u32, @@ -184,17 +186,20 @@ fn to_image( } -fn process_frame( +// TODO: benchmark process_frame +async fn process_frame( input: RgbImage, - dino: Res>, - pca_transform: Res>, - image_handle: &Handle, - mut images: ResMut>, -) { - let input_tensor: Tensor = preprocess_image(input, &dino.config, &dino.device); + dino_config: DinoVisionTransformerConfig, + dino_model: Arc>>, + pca_model: Arc>>, + device: B::Device, +) -> Vec { + let input_tensor: Tensor = preprocess_image(input, &dino_config, &device); - let model = dino.model.lock().unwrap(); - let dino_features = model.forward(input_tensor.clone(), None).x_norm_patchtokens; + let dino_features = { + let model = dino_model.lock().unwrap(); + model.forward(input_tensor.clone(), None).x_norm_patchtokens + }; let batch = dino_features.shape().dims[0]; let elements = dino_features.shape().dims[1]; @@ -204,15 +209,19 @@ fn process_frame( let x = dino_features.reshape([n_samples, embedding_dim]); - let pca_transform = pca_transform.model.lock().unwrap(); - let mut pca_features = pca_transform.forward(x.clone()); + let mut pca_features = { + let pca_transform = pca_model.lock().unwrap(); + pca_transform.forward(x.clone()) + }; // pca min-max scaling for i in 0..3 { let slice = pca_features.clone().slice([0..n_samples, i..i+1]); - let slice_min = slice.clone().min().into_scalar(); - let slice_max = slice.clone().max().into_scalar(); - let scaled = slice.sub_scalar(slice_min).div_scalar(slice_max - slice_min); + let slice_min = slice.clone().min(); + let slice_max = slice.clone().max(); + let scaled = slice + .sub(slice_min.clone().unsqueeze()) + .div((slice_max - slice_min).unsqueeze()); pca_features = pca_features.slice_assign( [0..n_samples, i..i+1], @@ -224,13 +233,11 @@ fn process_frame( let pca_features = to_image( pca_features, - dino.config.image_size as u32, - dino.config.image_size as u32, - ); + dino_config.image_size as u32, + dino_config.image_size as u32, + ).await; - // TODO: share wgpu io between bevy/burn - let existing_image = images.get_mut(image_handle).unwrap(); - existing_image.data = pca_features.into_raw(); + pca_features.into_raw() } @@ -242,6 +249,7 @@ mod native { mpsc::{ self, Sender, + SyncSender, Receiver, TryRecvError, }, @@ -262,7 +270,7 @@ mod native { use once_cell::sync::OnceCell; pub static SAMPLE_RECEIVER: OnceCell>>> = OnceCell::new(); - pub static SAMPLE_SENDER: OnceCell> = OnceCell::new(); + pub static SAMPLE_SENDER: OnceCell> = OnceCell::new(); pub static APP_RUN_RECEIVER: OnceCell>>> = OnceCell::new(); pub static APP_RUN_SENDER: OnceCell> = OnceCell::new(); @@ -271,7 +279,7 @@ mod native { let ( sample_sender, sample_receiver, - ) = mpsc::channel(); + ) = mpsc::sync_channel(1); SAMPLE_RECEIVER.set(Arc::new(Mutex::new(sample_receiver))).unwrap(); SAMPLE_SENDER.set(sample_sender).unwrap(); @@ -316,11 +324,43 @@ mod native { } +#[cfg(feature = "web")] +mod web { + use std::cell::RefCell; + + use image::{ + DynamicImage, + RgbImage, + RgbaImage, + }; + use wasm_bindgen::prelude::*; + + thread_local! { + pub static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); + } + + #[wasm_bindgen] + pub fn frame_input(pixel_data: &[u8], width: u32, height: u32) { + let rgba_image = RgbaImage::from_raw(width, height, pixel_data.to_vec()) + .expect("failed to create RgbImage"); + + // TODO: perform video element -> burn's webgpu texture conversion directly + // TODO: perform this conversion from tensors + let dynamic_image = DynamicImage::ImageRgba8(rgba_image); + let rgb_image: RgbImage = dynamic_image.to_rgb8(); + + SAMPLE_RECEIVER.with(|receiver| { + *receiver.borrow_mut() = Some(rgb_image); + }); + } +} + + #[derive(Resource)] struct DinoModel { config: DinoVisionTransformerConfig, device: B::Device, - model: Arc::>>, + model: Arc>>, } #[derive(Resource)] @@ -334,13 +374,12 @@ struct PcaFeatures { image: Handle, } + +#[derive(Component)] +struct ProcessImage(Task); + #[cfg(feature = "native")] -fn process_frames( - dino_model: Res>, - pca_transform: Res>, - pca_features_handle: Res, - images: ResMut>, -) { +fn receive_image() -> Option { let receiver = native::SAMPLE_RECEIVER.get().unwrap(); let mut last_image = None; @@ -351,22 +390,78 @@ fn process_frames( } } - if let Some(image) = last_image { - process_frame( - image, - dino_model, - pca_transform, - &pca_features_handle.image, - images, - ); - } + last_image } - -// TODO: web-sys ffi #[cfg(feature = "web")] -fn frame_input() { +fn receive_image() -> Option { + web::SAMPLE_RECEIVER.with(|receiver| { + receiver.borrow_mut().take() + }) +} + +fn process_frames( + mut commands: Commands, + dino_model: Res>, + pca_transform: Res>, + pca_features_handle: Res, + active_tasks: Query<&ProcessImage>, +) { + // TODO: move to config + // TODO: fix multiple in flight deadlock + let inference_max_in_flight = 1; + if active_tasks.iter().count() >= inference_max_in_flight { + return; + } + + if let Some(image) = receive_image() { + let thread_pool = AsyncComputeTaskPool::get(); + let entity = commands.spawn_empty().id(); + + let device = dino_model.device.clone(); + let dino_config = dino_model.config.clone(); + let dino_model = dino_model.model.clone(); + let image_handle = pca_features_handle.image.clone(); + let pca_model = pca_transform.model.clone(); + + let task = thread_pool.spawn(async move { + let img_data = process_frame( + image, + dino_config, + dino_model, + pca_model, + device, + ).await; + + let mut command_queue = CommandQueue::default(); + command_queue.push(move |world: &mut World| { + let mut system_state = SystemState::< + ResMut>, + >::new(world); + let mut images = system_state.get_mut(world); + + // TODO: share wgpu io between bevy/burn + let existing_image = images.get_mut(&image_handle).unwrap(); + existing_image.data = img_data; + + world + .entity_mut(entity) + .remove::(); + }); + + command_queue + }); + commands.entity(entity).insert(ProcessImage(task)); + } +} + +fn handle_tasks(mut commands: Commands, mut transform_tasks: Query<&mut ProcessImage>) { + for mut task in &mut transform_tasks { + if let Some(mut commands_queue) = block_on(future::poll_once(&mut task.0)) { + commands.append(&mut commands_queue); + } + } } @@ -397,20 +492,6 @@ fn setup_ui( ..default() }) .with_children(|builder| { - // TODO: view input - // builder.spawn(ImageBundle { - // style: Style { - // width: Val::Percent(100.0), - // height: Val::Percent(100.0), - // ..default() - // }, - // image: UiImage { - // texture: foreground, - // ..default() - // }, - // ..default() - // }); - builder.spawn(UiImage { image: pca_image.image.clone(), image_mode: NodeImageMode::Stretch, @@ -418,7 +499,7 @@ fn setup_ui( }); }); - commands.spawn(Camera2d::default()); + commands.spawn(Camera2d); } pub fn viewer_app() -> App { @@ -427,13 +508,15 @@ pub fn viewer_app() -> App { let mut app = App::new(); app.insert_resource(args.clone()); + let title = "bevy_burn_dino".to_string(); + #[cfg(target_arch = "wasm32")] let primary_window = Some(Window { // fit_canvas_to_parent: true, canvas: Some("#bevy".to_string()), mode: bevy::window::WindowMode::Windowed, prevent_default_event_handling: true, - title: args.name.clone(), + title, #[cfg(feature = "perftest")] present_mode: bevy::window::PresentMode::AutoNoVsync, @@ -448,7 +531,7 @@ pub fn viewer_app() -> App { mode: bevy::window::WindowMode::Windowed, prevent_default_event_handling: false, resolution: (1024.0, 1024.0).into(), - title: "bevy_burn_dino".to_string(), + title, #[cfg(feature = "perftest")] present_mode: bevy::window::PresentMode::AutoNoVsync, @@ -554,26 +637,35 @@ fn fps_update_system( } } -fn run_app() { - // TODO: move model load to startup/async task +async fn run_app() { + log("running app..."); + let device = Default::default(); + init_async::(&device, Default::default()).await; + + log("device created"); + let config = DinoVisionTransformerConfig { ..DinoVisionTransformerConfig::vits(None, None) // TODO: supply image size fron config }; let dino = load_model::(&config, &device); + log("dino model loaded"); + let pca_config = PcaTransformConfig::new( config.embedding_dimension, 3, ); let pca_transform = load_pca_model::(&pca_config, &device); + log("pca model loaded"); + let mut app = viewer_app(); app.init_resource::(); app.insert_resource(DinoModel { config, - device: device, + device, model: Arc::new(Mutex::new(dino)), }); app.insert_resource(PcaTransformModel { @@ -581,17 +673,46 @@ fn run_app() { }); app.add_systems(Startup, setup_ui); - app.add_systems(Update, process_frames); + app.add_systems( + Update, + ( + handle_tasks, + process_frames, + ), + ); + + log("running bevy app..."); app.run(); } +pub fn log(_msg: &str) { + #[cfg(debug_assertions)] + #[cfg(target_arch = "wasm32")] + { + web_sys::console::log_1(&_msg.into()); + } + #[cfg(debug_assertions)] + #[cfg(not(target_arch = "wasm32"))] + { + println!("{}", _msg); + } +} + + fn main() { #[cfg(feature = "native")] { std::thread::spawn(native::native_camera_thread); + futures::executor::block_on(run_app()); } - run_app(); + #[cfg(target_arch = "wasm32")] + { + #[cfg(debug_assertions)] + console_error_panic_hook::set_once(); + + wasm_bindgen_futures::spawn_local(run_app()); + } } diff --git a/crates/bevy_burn_dino/www/index.html b/crates/bevy_burn_dino/www/index.html index 4e6189d..6652545 100644 --- a/crates/bevy_burn_dino/www/index.html +++ b/crates/bevy_burn_dino/www/index.html @@ -1,7 +1,7 @@ - burn_dino + bevy_burn_dino +
+ diff --git a/src/layers/attention.rs b/src/layers/attention.rs index 9f1a8e8..70508a2 100644 --- a/src/layers/attention.rs +++ b/src/layers/attention.rs @@ -81,7 +81,7 @@ impl Attention { } } - #[allow(non_snake_case)] + #[allow(non_snake_case, clippy::single_range_in_vec_init)] pub fn forward(&self, x: Tensor) -> Tensor { let [B, N, C] = x.shape().dims(); @@ -108,8 +108,6 @@ impl Attention { .reshape([B, N, C]); let x = self.proj.forward(x); - let x = self.proj_drop.forward(x); - - x + self.proj_drop.forward(x) } } diff --git a/src/layers/block.rs b/src/layers/block.rs index 8685846..718e4e7 100644 --- a/src/layers/block.rs +++ b/src/layers/block.rs @@ -68,7 +68,7 @@ impl Block { // self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() // self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() let ls1 = if let Some(layer_scale_config) = &config.layer_scale { - layer_scale_config.init::(&device).into() + layer_scale_config.init::(device).into() } else { None }; @@ -82,7 +82,7 @@ impl Block { .init::(device); let ls2 = if let Some(layer_scale_config) = &config.layer_scale { - layer_scale_config.init::(&device).into() + layer_scale_config.init::(device).into() } else { None }; diff --git a/src/layers/layer_norm.rs b/src/layers/layer_norm.rs index ce34650..9d00244 100644 --- a/src/layers/layer_norm.rs +++ b/src/layers/layer_norm.rs @@ -18,7 +18,7 @@ impl Default for LayerNormConfig { impl LayerNormConfig { pub fn init(&self, device: &B::Device) -> LayerNorm { - LayerNorm::new(device, &self) + LayerNorm::new(device, self) } } diff --git a/src/layers/layer_scale.rs b/src/layers/layer_scale.rs index 1fd21e5..a7f0924 100644 --- a/src/layers/layer_scale.rs +++ b/src/layers/layer_scale.rs @@ -18,7 +18,7 @@ impl Default for LayerScaleConfig { impl LayerScaleConfig { pub fn init(&self, device: &B::Device) -> LayerScale { - LayerScale::new(device, &self) + LayerScale::new(device, self) } } diff --git a/src/layers/mlp.rs b/src/layers/mlp.rs index 85d39fb..47c1789 100644 --- a/src/layers/mlp.rs +++ b/src/layers/mlp.rs @@ -56,7 +56,6 @@ impl Mlp { let x = self.act.forward(x); let x = self.dropout.forward(x); let x = self.fc2.forward(x); - let x = self.dropout.forward(x); - x + self.dropout.forward(x) } } diff --git a/src/model/dino.rs b/src/model/dino.rs index 630a538..f990dc9 100644 --- a/src/model/dino.rs +++ b/src/model/dino.rs @@ -75,7 +75,6 @@ impl DinoVisionTransformerConfig { }, layer_scale: LayerScaleConfig { dim, - ..Default::default() }.into(), ..Default::default() }, @@ -100,7 +99,6 @@ impl DinoVisionTransformerConfig { }, layer_scale: LayerScaleConfig { dim: embedding_dimension, - ..Default::default() }.into(), ..Default::default() }, @@ -125,7 +123,6 @@ impl DinoVisionTransformerConfig { }, layer_scale: LayerScaleConfig { dim: embedding_dimension, - ..Default::default() }.into(), ..Default::default() }, @@ -146,7 +143,6 @@ impl DinoVisionTransformerConfig { }, layer_scale: LayerScaleConfig { dim: embedding_dimension, - ..Default::default() }.into(), ..Default::default() }, @@ -273,13 +269,13 @@ impl DinoVisionTransformer { patch_pos_embed.reshape([1, M, M, dim]).permute([0, 3, 1, 2]), ).permute([0, 2, 3, 1]).reshape([1_i32, -1, dim as i32]); - return Tensor::cat( + Tensor::cat( vec![ class_pos_embed.unsqueeze_dim(0), patch_pos_embed, ], 1, - ); + ) } #[allow(non_snake_case)] @@ -304,9 +300,7 @@ impl DinoVisionTransformer { ); let residual = self.interpolate_pos_encoding(x.clone(), W, H); - let x = x + residual.clone(); - - x + x + residual.clone() } #[allow(non_snake_case)] @@ -367,7 +361,7 @@ impl DinoVisionTransformer { x_norm_clstoken, x_norm_patchtokens, x_prenorm: x, - masks: masks.into(), + masks, } } } diff --git a/src/model/mod.rs b/src/model/mod.rs index 8803275..c887118 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,2 +1,3 @@ +#[allow(clippy::too_many_arguments)] pub mod dino; pub mod pca; diff --git a/src/model/pca.rs b/src/model/pca.rs index a834d58..d0456c1 100644 --- a/src/model/pca.rs +++ b/src/model/pca.rs @@ -22,7 +22,7 @@ impl Default for PcaTransformConfig { impl PcaTransformConfig { pub fn init(&self, device: &B::Device) -> PcaTransform { - PcaTransform::new(device, &self) + PcaTransform::new(device, self) } }