diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8ceb28c..db6a205 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,11 +30,36 @@ jobs: with: toolchain: ${{ matrix.rust-toolchain }} + + - name: Install ONNX Runtime on Windows + if: matrix.os == 'windows-latest' + run: | + Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip" + Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP" + echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV + + - name: Install ONNX Runtime on macOS + if: matrix.os == 'macos-latest' + run: | + curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz" + mkdir -p $HOME/onnxruntime + tar -xzf onnxruntime.tgz -C $HOME/onnxruntime + echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV + + + - name: Set ONNX Runtime library path for macOS + if: matrix.os == 'macos-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV + + - name: Set ONNX Runtime library path for Windows + if: matrix.os == 'windows-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/onnxruntime.dll" >> $GITHUB_ENV + + - name: lint run: cargo clippy -- -Dwarnings - name: build - run: cargo build - - # - name: build (web) - # run: cargo build --example=minimal --target wasm32-unknown-unknown --release + run: cargo build --features "ort/load-dynamic" + env: + ORT_DYLIB_PATH: ${{ env.ORT_DYLIB_PATH }} diff --git a/Cargo.toml b/Cargo.toml index c5efd7b..22a0071 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.2.0" +version = "0.4.0" edition = "2021" authors = ["mosure "] license = "MIT" @@ -38,7 +38,6 @@ modnet = [] 
bevy_args = "1.3" image = "0.24" ndarray = "0.15" -ort = "2.0.0-alpha.4" thiserror = "1.0" tokio = { version = "1.36", features = ["full"] } @@ -57,6 +56,14 @@ features = [ ] +[dependencies.ort] +version = "2.0.0-alpha.4" +default-features = true +features = [ + +] + + [profile.dev.package."*"] opt-level = 3 diff --git a/README.md b/README.md index 4cebfea..a9c0279 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library ![person](assets/person.png) ![mask](assets/mask.png) + *> modnet inference example* @@ -20,6 +21,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library - [X] load ONNX models as ORT session assets - [X] initialize ORT with default execution providers - [X] modnet bevy image <-> ort tensor IO (with feature `modnet`) +- [X] batched modnet preprocessing - [ ] compute task pool inference scheduling @@ -32,8 +34,8 @@ use bevy_ort::{ BevyOrtPlugin, inputs, models::modnet::{ - image_to_modnet_input, - modnet_output_to_luma_image, + images_to_modnet_input, + modnet_output_to_luma_images, }, Onnx, }; @@ -81,7 +83,7 @@ fn inference( } let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = image_to_modnet_input(image); + let input = images_to_modnet_input(vec![&image], None); let output: Result, String> = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; @@ -95,7 +97,7 @@ fn inference( Ok(output) => { let output_value: &ort::Value = output.get("output").unwrap(); - let mask_image = modnet_output_to_luma_image(output_value); + let mask_image = modnet_output_to_luma_images(output_value).pop().unwrap(); let mask_image = images.add(mask_image); commands.spawn(NodeBundle { @@ -139,6 +141,13 @@ fn inference( cargo run ``` +use an accelerated execution provider: +- windows - `cargo run --features ort/cuda` or `cargo run --features ort/openvino` +- macos - `cargo run --features ort/coreml` +- linux - `cargo run 
--features ort/tensorrt` or `cargo run --features ort/openvino` + +> see complete list of ort features here: https://github.com/pykeio/ort/blob/0aec4030a5f3470e4ee6c6f4e7e52d4e495ec27a/Cargo.toml#L54 + > note: if you use `pip install onnxruntime`, you may need to run `ORT_STRATEGY=system cargo run`, see: https://docs.rs/ort/latest/ort/#how-to-get-binaries diff --git a/src/models/modnet.rs b/src/models/modnet.rs index 4900494..7c9de00 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -1,89 +1,114 @@ -use std::cmp::{ - max, - min, -}; - use bevy::{prelude::*, render::render_asset::RenderAssetUsages}; use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage}; -use ndarray::{Array, Array4, ArrayView4, Axis}; +use ndarray::{Array, Array4, ArrayView4, Axis, s}; -pub fn modnet_output_to_luma_image( +pub fn modnet_output_to_luma_images( output_value: &ort::Value, -) -> Image { +) -> Vec { let tensor: ort::Tensor = output_value.extract_tensor::().unwrap(); let data = tensor.view(); let shape = data.shape(); + let batch_size = shape[0]; let width = shape[3]; let height = shape[2]; - let tensor_data = ArrayView4::from_shape((1, 1, height, width), data.as_slice().unwrap()) + let tensor_data = ArrayView4::from_shape((batch_size, 1, height, width), data.as_slice().unwrap()) .expect("failed to create ArrayView4 from shape and data"); - let mut imgbuf = ImageBuffer::, Vec>::new(width as u32, height as u32); + let mut images = Vec::new(); + + for i in 0..batch_size { + let mut imgbuf = ImageBuffer::, Vec>::new(width as u32, height as u32); - for y in 0..height { - for x in 0..width { - let pixel_value = tensor_data[(0, 0, y, x)]; - let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; - imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + for y in 0..height { + for x in 0..width { + let pixel_value = tensor_data[(i, 0, y, x)]; + let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; + imgbuf.put_pixel(x as 
u32, y as u32, Luma([pixel_value])); + } } - } - let dyn_img = DynamicImage::ImageLuma8(imgbuf); + let dyn_img = DynamicImage::ImageLuma8(imgbuf); + + images.push(Image::from_dynamic(dyn_img, false, RenderAssetUsages::all())); + } - Image::from_dynamic(dyn_img, false, RenderAssetUsages::all()) + images } -pub fn image_to_modnet_input( - image: &Image, +pub fn images_to_modnet_input( + images: Vec<&Image>, + max_size: Option<(u32, u32)>, ) -> Array4 { - assert_eq!(image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb); + if images.is_empty() { + panic!("no images provided"); + } let ref_size = 512; - let ( - x_scale, - y_scale, - ) = get_scale_factor( - image.height(), - image.width(), - ref_size, - ); - - let resized_image = resize_image( - &image.clone().try_into_dynamic().unwrap(), - x_scale, - y_scale, - ); - - image_to_ndarray(&resized_image) + let &first_image = images.first().unwrap(); + + let image = first_image.to_owned(); + + println!("image: {:?}", image.size()); + + let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size); + let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale); + let first_image_ndarray = image_to_ndarray(&resized_image); + + println!("scale_factor: {:?}", (x_scale, y_scale)); + println!("first_image_ndarray: {:?}", first_image_ndarray.dim()); + + let single_image_shape = first_image_ndarray.dim(); + let n_images = images.len(); + let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3); + + let mut aggregate = Array4::::zeros(batch_shape); + + for (i, &image) in images.iter().enumerate() { + let image = image.to_owned(); + let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size); + let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale); + let image_ndarray = image_to_ndarray(&resized_image); + + let slice = s![i, .., .., 
..]; + aggregate.slice_mut(slice).assign(&image_ndarray.index_axis_move(Axis(0), 0)); + } + + aggregate } -fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32) -> (f32, f32) { - let mut im_rh; - let mut im_rw; +fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, u32)>) -> (f32, f32) { + // Calculate the scale factor based on the maximum size constraints + let scale_factor_max = max_size.map_or(1.0, |(max_w, max_h)| { + f32::min(max_w as f32 / im_w as f32, max_h as f32 / im_h as f32) + }); - if max(im_h, im_w) < ref_size || min(im_h, im_w) > ref_size { - if im_w >= im_h { - im_rh = ref_size; - im_rw = (im_w as f32 / im_h as f32 * ref_size as f32) as u32; - } else { - im_rw = ref_size; - im_rh = (im_h as f32 / im_w as f32 * ref_size as f32) as u32; - } + // Calculate the target dimensions after applying the max scale factor (clipping to max_size) + let (target_h, target_w) = ((im_h as f32 * scale_factor_max).round() as u32, (im_w as f32 * scale_factor_max).round() as u32); + + // Calculate the scale factor to fit within the reference size, considering the target dimensions + let (scale_factor_ref_w, scale_factor_ref_h) = if std::cmp::max(target_h, target_w) < ref_size { + let scale_factor = ref_size as f32 / std::cmp::max(target_h, target_w) as f32; + (scale_factor, scale_factor) } else { - im_rh = im_h; - im_rw = im_w; - } + (1.0, 1.0) // Do not upscale if target dimensions are within reference size + }; + + // Calculate the final scale factor as the minimum of the max scale factor and the reference scale factor + let final_scale_w = f32::min(scale_factor_max, scale_factor_ref_w); + let final_scale_h = f32::min(scale_factor_max, scale_factor_ref_h); - im_rw = im_rw - im_rw % 32; - im_rh = im_rh - im_rh % 32; + // Adjust dimensions to ensure they are multiples of 32 + let final_w = ((im_w as f32 * final_scale_w).round() as u32) - ((im_w as f32 * final_scale_w).round() as u32) % 32; + let final_h = ((im_h as f32 * final_scale_h).round() 
as u32) - ((im_h as f32 * final_scale_h).round() as u32) % 32; - (im_rw as f32 / im_w as f32, im_rh as f32 / im_h as f32) + // Return the scale factors based on the original image dimensions + (final_w as f32 / im_w as f32, final_h as f32 / im_h as f32) } diff --git a/tools/modnet.rs b/tools/modnet.rs index d82dd61..6598704 100644 --- a/tools/modnet.rs +++ b/tools/modnet.rs @@ -4,8 +4,8 @@ use bevy_ort::{ BevyOrtPlugin, inputs, models::modnet::{ - image_to_modnet_input, - modnet_output_to_luma_image, + images_to_modnet_input, + modnet_output_to_luma_images, }, Onnx, }; @@ -24,12 +24,14 @@ fn main() { .run(); } + #[derive(Resource, Default)] pub struct Modnet { pub onnx: Handle, pub input: Handle, } + fn load_modnet( asset_server: Res, mut modnet: ResMut, @@ -54,7 +56,7 @@ fn inference( } let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = image_to_modnet_input(image); + let input = images_to_modnet_input(vec![&image], Some((256, 144))); let output: Result, String> = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; @@ -68,7 +70,7 @@ fn inference( Ok(output) => { let output_value: &ort::Value = output.get("output").unwrap(); - let mask_image = modnet_output_to_luma_image(output_value); + let mask_image = modnet_output_to_luma_images(output_value).pop().unwrap(); let mask_image = images.add(mask_image); commands.spawn(NodeBundle {