From ba89bc26e2de2ebb59de927946c12aa733626bae Mon Sep 17 00:00:00 2001 From: mosure Date: Mon, 25 Nov 2024 20:30:16 -0600 Subject: [PATCH 1/5] feat: web scaffold --- .github/workflows/deploy-pages.yml | 51 ++++++++++++++++++++++++++++++ Cargo.toml | 7 ---- crates/bevy_burn_dino/Cargo.toml | 2 +- crates/bevy_burn_dino/src/main.rs | 40 +++++++++++++++++++---- 4 files changed, 85 insertions(+), 15 deletions(-) create mode 100644 .github/workflows/deploy-pages.yml diff --git a/.github/workflows/deploy-pages.yml b/.github/workflows/deploy-pages.yml new file mode 100644 index 0000000..ba25cb7 --- /dev/null +++ b/.github/workflows/deploy-pages.yml @@ -0,0 +1,51 @@ +name: deploy github pages + +on: + push: + branches: + - main + workflow_dispatch: + +permissions: + contents: write + +env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + + +jobs: + deploy: + runs-on: macos-latest + + steps: + - name: checkout repository + uses: actions/checkout@v3 + + - name: setup nightly rust toolchain with caching + uses: brndnmtthws/rust-action@v1 + with: + toolchain: nightly + components: rustfmt, clippy + enable-sccache: "false" + + - name: install wasm32-unknown-unknown + run: rustup target add wasm32-unknown-unknown + + - name: install wasm-bindgen-cli + run: cargo install wasm-bindgen-cli + + - name: build wasm artifacts + run: cargo build -p bevy_burn_dino --target wasm32-unknown-unknown --release --no-default-features --features "web" + + - name: generate bindings with wasm-bindgen + run: wasm-bindgen --out-dir ./crates/bevy_burn_dino/www/out/ --target web ./target/wasm32-unknown-unknown/release/bevy_burn_dino.wasm + + - name: copy assets + run: mkdir -p ./crates/bevy_burn_dino/www/assets && cp -r ./crates/bevy_burn_dino/assets/* ./crates/bevy_burn_dino/www/assets/ + + - name: deploy to github pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: ./www + branch: www diff --git a/Cargo.toml b/Cargo.toml index 5d3bf5f..7999472 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,15 +51,8 @@ serde = { version = "1.0", optional = true } [dependencies.burn] version = "0.15" features = [ - "autodiff", "autotune", - "dataset", "fusion", - "metrics", - "ndarray", - "network", - "train", - "tui", "wgpu", ] diff --git a/crates/bevy_burn_dino/Cargo.toml b/crates/bevy_burn_dino/Cargo.toml index 7142784..ab5025d 100644 --- a/crates/bevy_burn_dino/Cargo.toml +++ b/crates/bevy_burn_dino/Cargo.toml @@ -29,7 +29,7 @@ perftest = [] [dependencies] # bevy_args = "1.6" bevy_args = { git = "https://github.com/mosure/bevy_args.git", branch = "burn" } -burn_dino = { path = "../../" } +burn_dino = { path = "../../", default-features = false } clap = { version = "4.5", features = ["derive"] } futures-intrusive = "0.5" image = { version = "0.25.2", default-features = false, features = ["png"] } diff --git a/crates/bevy_burn_dino/src/main.rs b/crates/bevy_burn_dino/src/main.rs index cbad70e..96b9e49 100644 --- a/crates/bevy_burn_dino/src/main.rs +++ b/crates/bevy_burn_dino/src/main.rs @@ -316,6 +316,39 @@ mod native { } +#[cfg(feature = "web")] +mod web { + use std::cell::RefCell; + + use image::RgbImage; + use wasm_bindgen::prelude::*; + pub use web_sys::ImageData; + + thread_local! { + static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); + } + + #[wasm_bindgen] + pub fn frame_input(js_image_data: ImageData) { + // Get the dimensions of the image + let width = js_image_data.width() as u32; + let height = js_image_data.height() as u32; + + // Get the raw pixel data + let data = js_image_data.data(); + + // Convert the raw data into an RgbImage + let rgb_image = RgbImage::from_raw(width, height, data.to_vec()) + .expect("failed to create RgbImage"); + + // Store the image in the global receiver + SAMPLE_RECEIVER.with(|receiver| { + *receiver.borrow_mut() = Some(rgb_image); + }); + } +} + + #[derive(Resource)] struct DinoModel { config: DinoVisionTransformerConfig, @@ -363,13 +396,6 @@ fn process_frames( } -// TODO: web-sys ffi -#[cfg(feature = "web")] -fn frame_input() { - -} - - fn setup_ui( mut commands: Commands, dino: Res>, From 3addedf84e46f33fc92c12ab4057ebc245a6a35a Mon Sep 17 00:00:00 2001 From: mosure Date: Mon, 25 Nov 2024 21:35:16 -0600 Subject: [PATCH 2/5] docs: todo --- crates/bevy_burn_dino/src/main.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/bevy_burn_dino/src/main.rs b/crates/bevy_burn_dino/src/main.rs index 96b9e49..c6d81f4 100644 --- a/crates/bevy_burn_dino/src/main.rs +++ b/crates/bevy_burn_dino/src/main.rs @@ -184,6 +184,7 @@ fn to_image( } +// TODO: benchmark process_frame fn process_frame( input: RgbImage, dino: Res>, From 21268c5404cc949ffb99fb70b92f0a00891971a0 Mon Sep 17 00:00:00 2001 From: mosure Date: Tue, 26 Nov 2024 00:14:07 -0600 Subject: [PATCH 3/5] feat: seems to build fine but freezes chrome --- .github/workflows/deploy-pages.yml | 6 ++ Cargo.toml | 4 ++ crates/bevy_burn_dino/Cargo.toml | 6 +- crates/bevy_burn_dino/src/main.rs | 83 ++++++++++++++++++++-------- crates/bevy_burn_dino/www/index.html | 27 ++++++++- 5 files changed, 102 insertions(+), 24 deletions(-) diff --git a/.github/workflows/deploy-pages.yml b/.github/workflows/deploy-pages.yml index ba25cb7..1b7a942 100644 --- a/.github/workflows/deploy-pages.yml +++ b/.github/workflows/deploy-pages.yml @@ -35,9 +35,15 @@ jobs: - name: install wasm-bindgen-cli run: cargo install wasm-bindgen-cli + # - name: install wasm-opt + # run: cargo install wasm-opt --locked + - name: build wasm artifacts run: cargo build -p bevy_burn_dino --target wasm32-unknown-unknown --release --no-default-features --features "web" + # - name: optimize wasm artifacts + # run: wasm-opt -O -ol 100 -s 100 -o ./target/wasm32-unknown-unknown/release/bevy_burn_dino_opt.wasm ./target/wasm32-unknown-unknown/release/bevy_burn_dino.wasm + - name: generate bindings with wasm-bindgen run: wasm-bindgen --out-dir ./crates/bevy_burn_dino/www/out/ --target web ./target/wasm32-unknown-unknown/release/bevy_burn_dino.wasm diff --git a/Cargo.toml b/Cargo.toml index 7999472..7737546 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,15 +91,19 @@ opt-level = 3 opt-level = 1 [profile.release] +strip = "symbols" lto = "thin" codegen-units = 1 opt-level = 3 +panic = "abort" [profile.wasm-release] +strip = "symbols" inherits = "release" opt-level = "z" lto = "fat" codegen-units = 1 +panic = "abort" [lib] diff --git a/crates/bevy_burn_dino/Cargo.toml b/crates/bevy_burn_dino/Cargo.toml index ab5025d..9441c6f 100644 --- a/crates/bevy_burn_dino/Cargo.toml +++ b/crates/bevy_burn_dino/Cargo.toml @@ -31,6 +31,7 @@ perftest = [] bevy_args = { git = "https://github.com/mosure/bevy_args.git", branch = "burn" } burn_dino = { path = "../../", default-features = false } clap = { version = "4.5", features = ["derive"] } +futures = "0.3" futures-intrusive = "0.5" image = { version = "0.25.2", default-features = false, features = ["png"] } nokhwa = { version = "0.10", features = ["input-native", "output-threaded"], optional = true } @@ -51,8 +52,11 @@ features = [ "bevy_text", "bevy_ui", "bevy_winit", + "custom_cursor", "default_font", "png", + # "webgpu", + "x11", ] [dependencies.burn] @@ -82,5 +86,5 @@ features = [ [[bin]] -name = "ui" +name = "bevy_burn_dino" path = "src/main.rs" diff --git a/crates/bevy_burn_dino/src/main.rs b/crates/bevy_burn_dino/src/main.rs index c6d81f4..9c9b8dd 100644 --- a/crates/bevy_burn_dino/src/main.rs +++ b/crates/bevy_burn_dino/src/main.rs @@ -30,9 +30,10 @@ use bevy_args::{ }; use burn::{ prelude::*, - backend::wgpu::Wgpu, + backend::wgpu::{init_async, AutoGraphicsApi, Wgpu}, record::{FullPrecisionSettings, NamedMpkBytesRecorder, Recorder}, }; +use futures::executor::block_on; use image::{ DynamicImage, ImageBuffer, @@ -326,23 +327,19 @@ mod web { pub use web_sys::ImageData; thread_local! { - static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); + pub static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); } #[wasm_bindgen] pub fn frame_input(js_image_data: ImageData) { - // Get the dimensions of the image let width = js_image_data.width() as u32; let height = js_image_data.height() as u32; - // Get the raw pixel data let data = js_image_data.data(); - // Convert the raw data into an RgbImage let rgb_image = RgbImage::from_raw(width, height, data.to_vec()) .expect("failed to create RgbImage"); - // Store the image in the global receiver SAMPLE_RECEIVER.with(|receiver| { *receiver.borrow_mut() = Some(rgb_image); }); @@ -396,6 +393,28 @@ fn process_frames( } } +#[cfg(feature = "web")] +fn process_frames( + dino_model: Res>, + pca_transform: Res>, + pca_features_handle: Res, + images: ResMut>, +) { + web::SAMPLE_RECEIVER.with(|receiver| { + let mut receiver = receiver.borrow_mut(); + + if let Some(image) = receiver.take() { + process_frame( + image, + dino_model, + pca_transform, + &pca_features_handle.image, + images, + ); + } + }); +} + fn setup_ui( mut commands: Commands, @@ -424,20 +443,6 @@ fn setup_ui( ..default() }) .with_children(|builder| { - // TODO: view input - // builder.spawn(ImageBundle { - // style: Style { - // width: Val::Percent(100.0), - // height: Val::Percent(100.0), - // ..default() - // }, - // image: UiImage { - // texture: foreground, - // ..default() - // }, - // ..default() - // }); - builder.spawn(UiImage { image: pca_image.image.clone(), image_mode: NodeImageMode::Stretch, @@ -454,13 +459,15 @@ pub fn viewer_app() -> App { let mut app = App::new(); app.insert_resource(args.clone()); + let title = "bevy_burn_dino".to_string(); + #[cfg(target_arch = "wasm32")] let primary_window = Some(Window { // fit_canvas_to_parent: true, canvas: Some("#bevy".to_string()), mode: bevy::window::WindowMode::Windowed, prevent_default_event_handling: true, - title: args.name.clone(), + title, #[cfg(feature = "perftest")] present_mode: bevy::window::PresentMode::AutoNoVsync, @@ -475,7 +482,7 @@ pub fn viewer_app() -> App { mode: bevy::window::WindowMode::Windowed, prevent_default_event_handling: false, resolution: (1024.0, 1024.0).into(), - title: "bevy_burn_dino".to_string(), + title, #[cfg(feature = "perftest")] present_mode: bevy::window::PresentMode::AutoNoVsync, @@ -582,19 +589,29 @@ fn fps_update_system( } fn run_app() { + log("running app..."); + // TODO: move model load to startup/async task let device = Default::default(); + block_on(init_async::(&device, Default::default())); + + log("device created"); + let config = DinoVisionTransformerConfig { ..DinoVisionTransformerConfig::vits(None, None) // TODO: supply image size fron config }; let dino = load_model::(&config, &device); + log("dino model loaded"); + let pca_config = PcaTransformConfig::new( config.embedding_dimension, 3, ); let pca_transform = load_pca_model::(&pca_config, &device); + log("pca model loaded"); + let mut app = viewer_app(); app.init_resource::(); @@ -610,15 +627,37 @@ fn run_app() { app.add_systems(Startup, setup_ui); app.add_systems(Update, process_frames); + log("running bevy app..."); + app.run(); } +pub fn log(_msg: &str) { + #[cfg(debug_assertions)] + #[cfg(target_arch = "wasm32")] + { + web_sys::console::log_1(&_msg.into()); + } + #[cfg(debug_assertions)] + #[cfg(not(target_arch = "wasm32"))] + { + println!("{}", _msg); + } +} + + fn main() { #[cfg(feature = "native")] { std::thread::spawn(native::native_camera_thread); } + #[cfg(debug_assertions)] + #[cfg(target_arch = "wasm32")] + { + console_error_panic_hook::set_once(); + } + run_app(); } diff --git a/crates/bevy_burn_dino/www/index.html b/crates/bevy_burn_dino/www/index.html index 4e6189d..81ab0c6 100644 --- a/crates/bevy_burn_dino/www/index.html +++ b/crates/bevy_burn_dino/www/index.html @@ -10,15 +10,39 @@ padding: 0; overflow: hidden; } + .loading { + position: absolute; + top: 50%; + left: 50%; + margin: -15px 0 0 -15px; + height: 30px; + width: 30px; + border: 2px solid #ddd; + border-left-color: #009688; + border-radius: 30px; + -webkit-animation: animation-rotate 950ms cubic-bezier(.64,2,.56,.6) infinite; + animation: animation-rotate 950ms cubic-bezier(.64,2,.56,.6) infinite; + } + @-webkit-keyframes animation-rotate { + 100% { + -webkit-transform: rotate(360deg); + } + } + @keyframes animation-rotate { + 100% { + -webkit-transform: rotate(360deg); + transform: rotate(360deg); + } + } +
From 96daa67b6f34bbc37dd52ebc8a87e024fbf6e03c Mon Sep 17 00:00:00 2001 From: mosure Date: Tue, 26 Nov 2024 14:04:38 -0600 Subject: [PATCH 4/5] feat: web demo --- .github/workflows/build.yml | 2 +- .github/workflows/clippy.yml | 2 +- .github/workflows/test.yml | 2 +- .gitignore | 1 + Cargo.toml | 35 ++++- crates/bevy_burn_dino/Cargo.toml | 12 +- crates/bevy_burn_dino/src/main.rs | 185 +++++++++++++++++---------- crates/bevy_burn_dino/www/index.html | 60 ++++++++- 8 files changed, 221 insertions(+), 78 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e58f769..93c642e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, macos-14] + os: [windows-latest, macos-latest] rust-toolchain: - nightly diff --git a/.github/workflows/clippy.yml b/.github/workflows/clippy.yml index 91746aa..fe9ac3e 100644 --- a/.github/workflows/clippy.yml +++ b/.github/workflows/clippy.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, macos-14] + os: [windows-latest, macos-latest] rust-toolchain: - nightly diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0c519b7..7f47cbf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest, macos-14] + os: [windows-latest, macos-latest] rust-toolchain: - nightly diff --git a/.gitignore b/.gitignore index b6c2102..a9168da 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ Cargo.lock screenshots/ headless_output/ www/assets/ +crates/bevy_burn_dino/www/assets/ assets/images/ assets/models/ diff --git a/Cargo.toml b/Cargo.toml index 7737546..2ca594e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ resolver = "2" [features] default = ["import"] + +benchmark = ["burn/autotune", "burn/fusion"] import = ["bevy_args", "burn-import", "clap", "serde"] @@ -50,12 +52,41 @@ serde = { version = "1.0", optional = true } [dependencies.burn] version = "0.15" +default-features = false features = [ - "autotune", - "fusion", + # "autotune", + # "fusion", + "std", "wgpu", ] +[dependencies.burn-wgpu] +version = "0.15" +default-features = false +features = [ + "fusion", + "std", + # "template", +] + +[dependencies.cubecl] +version = "0.3" +default-features = false +features = [ + "linalg", + "std", + # "template", +] + +[dependencies.cubecl-runtime] +version = "0.3" +default-features = false +features = [ + "std", + "channel-mpsc", + # "template", +] + [target.'cfg(target_arch = "wasm32")'.dependencies] console_error_panic_hook = "0.1" diff --git a/crates/bevy_burn_dino/Cargo.toml b/crates/bevy_burn_dino/Cargo.toml index 9441c6f..1ec948d 100644 --- a/crates/bevy_burn_dino/Cargo.toml +++ b/crates/bevy_burn_dino/Cargo.toml @@ -18,7 +18,7 @@ categories = [ [features] default = ["native"] -native = ["nokhwa"] +native = ["burn/autotune", "futures", "nokhwa"] web = [] editor = [] @@ -31,7 +31,7 @@ perftest = [] bevy_args = { git = "https://github.com/mosure/bevy_args.git", branch = "burn" } burn_dino = { path = "../../", default-features = false } clap = { version = "4.5", features = ["derive"] } -futures = "0.3" +futures = { version = "0.3", optional = true } futures-intrusive = "0.5" image = { version = "0.25.2", default-features = false, features = ["png"] } nokhwa = { version = "0.10", features = ["input-native", "output-threaded"], optional = true } @@ -61,9 +61,12 @@ features = [ [dependencies.burn] version = "0.15" +default-features = false features = [ - "autotune", - "fusion", + # "autotune", + # "fusion", + "std", + # "template", "wgpu", ] @@ -71,6 +74,7 @@ features = [ [target.'cfg(target_arch = "wasm32")'.dependencies] console_error_panic_hook = "0.1" wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" [dependencies.web-sys] diff --git a/crates/bevy_burn_dino/src/main.rs b/crates/bevy_burn_dino/src/main.rs index 9c9b8dd..273536f 100644 --- a/crates/bevy_burn_dino/src/main.rs +++ b/crates/bevy_burn_dino/src/main.rs @@ -7,6 +7,7 @@ use bevy::{ DiagnosticsStore, FrameTimeDiagnosticsPlugin, }, + ecs::{system::SystemState, world::CommandQueue}, render::{ render_asset::RenderAssetUsages, render_resource::{ @@ -21,6 +22,7 @@ use bevy::{ }, RenderPlugin, }, + tasks::{block_on, futures_lite::future, AsyncComputeTaskPool, Task}, }; use bevy_args::{ parse_args, @@ -33,7 +35,6 @@ use burn::{ backend::wgpu::{init_async, AutoGraphicsApi, Wgpu}, record::{FullPrecisionSettings, NamedMpkBytesRecorder, Recorder}, }; -use futures::executor::block_on; use image::{ DynamicImage, ImageBuffer, @@ -74,7 +75,7 @@ pub struct DinoImportConfig { #[arg(long, default_value = "true")] pub press_esc_to_close: bool, - #[arg(long, default_value = "true")] + #[arg(long, default_value = "false")] pub show_fps: bool, #[arg(long, default_value = "518")] @@ -89,7 +90,7 @@ impl Default for DinoImportConfig { Self { pca_only: true, press_esc_to_close: true, - show_fps: true, + show_fps: false, // TODO: display inference fps (UI fps is decoupled via async compute pool) inference_height: 518, inference_width: 518, } @@ -164,7 +165,7 @@ fn preprocess_image( } -fn to_image( +async fn to_image( image: Tensor, upsample_height: u32, upsample_width: u32, @@ -172,7 +173,7 @@ fn to_image( let height = image.shape().dims[1]; let width = image.shape().dims[2]; - let image = image.to_data().to_vec::().unwrap(); + let image = image.to_data_async().await.to_vec::().unwrap(); let image = ImageBuffer::, Vec>::from_raw( width as u32, height as u32, @@ -186,16 +187,16 @@ fn to_image( // TODO: benchmark process_frame -fn process_frame( +async fn process_frame( input: RgbImage, - dino: Res>, - pca_transform: Res>, - image_handle: &Handle, - mut images: ResMut>, -) { - let input_tensor: Tensor = preprocess_image(input, &dino.config, &dino.device); + dino_config: DinoVisionTransformerConfig, + dino_model: Arc>>, + pca_model: Arc>>, + device: B::Device, +) -> Vec { + let input_tensor: Tensor = preprocess_image(input, &dino_config, &device); - let model = dino.model.lock().unwrap(); + let model = dino_model.lock().unwrap(); let dino_features = model.forward(input_tensor.clone(), None).x_norm_patchtokens; let batch = dino_features.shape().dims[0]; @@ -206,15 +207,17 @@ fn process_frame( let x = dino_features.reshape([n_samples, embedding_dim]); - let pca_transform = pca_transform.model.lock().unwrap(); + let pca_transform = pca_model.lock().unwrap(); let mut pca_features = pca_transform.forward(x.clone()); // pca min-max scaling for i in 0..3 { let slice = pca_features.clone().slice([0..n_samples, i..i+1]); - let slice_min = slice.clone().min().into_scalar(); - let slice_max = slice.clone().max().into_scalar(); - let scaled = slice.sub_scalar(slice_min).div_scalar(slice_max - slice_min); + let slice_min = slice.clone().min(); + let slice_max = slice.clone().max(); + let scaled = slice + .sub(slice_min.clone().unsqueeze()) + .div((slice_max - slice_min).unsqueeze()); pca_features = pca_features.slice_assign( [0..n_samples, i..i+1], @@ -226,13 +229,11 @@ fn process_frame( let pca_features = to_image( pca_features, - dino.config.image_size as u32, - dino.config.image_size as u32, - ); + dino_config.image_size as u32, + dino_config.image_size as u32, + ).await; - // TODO: share wgpu io between bevy/burn - let existing_image = images.get_mut(image_handle).unwrap(); - existing_image.data = pca_features.into_raw(); + pca_features.into_raw() } @@ -244,6 +245,7 @@ mod native { mpsc::{ self, Sender, + SyncSender, Receiver, TryRecvError, }, @@ -264,7 +266,7 @@ mod native { use once_cell::sync::OnceCell; pub static SAMPLE_RECEIVER: OnceCell>>> = OnceCell::new(); - pub static SAMPLE_SENDER: OnceCell> = OnceCell::new(); + pub static SAMPLE_SENDER: OnceCell> = OnceCell::new(); pub static APP_RUN_RECEIVER: OnceCell>>> = OnceCell::new(); pub static APP_RUN_SENDER: OnceCell> = OnceCell::new(); @@ -273,7 +275,7 @@ mod native { let ( sample_sender, sample_receiver, - ) = mpsc::channel(); + ) = mpsc::sync_channel(1); SAMPLE_RECEIVER.set(Arc::new(Mutex::new(sample_receiver))).unwrap(); SAMPLE_SENDER.set(sample_sender).unwrap(); @@ -322,24 +324,27 @@ mod native { mod web { use std::cell::RefCell; - use image::RgbImage; + use image::{ + DynamicImage, + RgbImage, + RgbaImage, + }; use wasm_bindgen::prelude::*; - pub use web_sys::ImageData; thread_local! { pub static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); } #[wasm_bindgen] - pub fn frame_input(js_image_data: ImageData) { - let width = js_image_data.width() as u32; - let height = js_image_data.height() as u32; - - let data = js_image_data.data(); - - let rgb_image = RgbImage::from_raw(width, height, data.to_vec()) + pub fn frame_input(pixel_data: &[u8], width: u32, height: u32) { + let rgba_image = RgbaImage::from_raw(width, height, pixel_data.to_vec()) .expect("failed to create RgbImage"); + // TODO: perform video element -> burn's webgpu texture conversion directly + // TODO: perform this conversion from tensors + let dynamic_image = DynamicImage::ImageRgba8(rgba_image); + let rgb_image: RgbImage = dynamic_image.to_rgb8(); + SAMPLE_RECEIVER.with(|receiver| { *receiver.borrow_mut() = Some(rgb_image); }); @@ -351,7 +356,7 @@ mod web { struct DinoModel { config: DinoVisionTransformerConfig, device: B::Device, - model: Arc::>>, + model: Arc>>, } #[derive(Resource)] @@ -365,13 +370,12 @@ struct PcaFeatures { image: Handle, } + +#[derive(Component)] +struct ProcessImage(Task); + #[cfg(feature = "native")] -fn process_frames( - dino_model: Res>, - pca_transform: Res>, - pca_features_handle: Res, - images: ResMut>, -) { +fn receive_image() -> Option { let receiver = native::SAMPLE_RECEIVER.get().unwrap(); let mut last_image = None; @@ -382,37 +386,78 @@ fn process_frames( } } - if let Some(image) = last_image { - process_frame( - image, - dino_model, - pca_transform, - &pca_features_handle.image, - images, - ); - } + last_image } #[cfg(feature = "web")] +fn receive_image() -> Option { + web::SAMPLE_RECEIVER.with(|receiver| { + receiver.borrow_mut().take() + }) +} + fn process_frames( + mut commands: Commands, dino_model: Res>, pca_transform: Res>, pca_features_handle: Res, - images: ResMut>, + active_tasks: Query<&ProcessImage>, ) { - web::SAMPLE_RECEIVER.with(|receiver| { - let mut receiver = receiver.borrow_mut(); + // TODO: move to config + // TODO: fix multiple in flight deadlock + let inference_max_in_flight = 1; + if active_tasks.iter().count() >= inference_max_in_flight { + return; + } + + if let Some(image) = receive_image() { + let thread_pool = AsyncComputeTaskPool::get(); + let entity = commands.spawn_empty().id(); + + let device = dino_model.device.clone(); + let dino_config = dino_model.config.clone(); + let dino_model = dino_model.model.clone(); + let image_handle = pca_features_handle.image.clone(); + let pca_model = pca_transform.model.clone(); - if let Some(image) = receiver.take() { - process_frame( + let task = thread_pool.spawn(async move { + let img_data = process_frame( image, + dino_config, dino_model, - pca_transform, - &pca_features_handle.image, - images, - ); + pca_model, + device, + ).await; + + let mut command_queue = CommandQueue::default(); + command_queue.push(move |world: &mut World| { + let mut system_state = SystemState::< + ResMut>, + >::new(world); + let mut images = system_state.get_mut(world); + + // TODO: share wgpu io between bevy/burn + let existing_image = images.get_mut(&image_handle).unwrap(); + existing_image.data = img_data; + + world + .entity_mut(entity) + .remove::(); + }); + + command_queue + }); + + commands.entity(entity).insert(ProcessImage(task)); + } +} + +fn handle_tasks(mut commands: Commands, mut transform_tasks: Query<&mut ProcessImage>) { + for mut task in &mut transform_tasks { + if let Some(mut commands_queue) = block_on(future::poll_once(&mut task.0)) { + commands.append(&mut commands_queue); } - }); + } } @@ -588,12 +633,11 @@ fn fps_update_system( } } -fn run_app() { +async fn run_app() { log("running app..."); - // TODO: move model load to startup/async task let device = Default::default(); - block_on(init_async::(&device, Default::default())); + init_async::(&device, Default::default()).await; log("device created"); @@ -625,7 +669,13 @@ fn run_app() { }); app.add_systems(Startup, setup_ui); - app.add_systems(Update, process_frames); + app.add_systems( + Update, + ( + handle_tasks, + process_frames, + ), + ); log("running bevy app..."); @@ -651,13 +701,14 @@ fn main() { #[cfg(feature = "native")] { std::thread::spawn(native::native_camera_thread); + futures::executor::block_on(run_app()); } - #[cfg(debug_assertions)] #[cfg(target_arch = "wasm32")] { + #[cfg(debug_assertions)] console_error_panic_hook::set_once(); - } - run_app(); + wasm_bindgen_futures::spawn_local(run_app()); + } } diff --git a/crates/bevy_burn_dino/www/index.html b/crates/bevy_burn_dino/www/index.html index 81ab0c6..6652545 100644 --- a/crates/bevy_burn_dino/www/index.html +++ b/crates/bevy_burn_dino/www/index.html @@ -1,7 +1,7 @@ - burn_dino + bevy_burn_dino