From 8235def89496767e3be5153f851d0c74d55ed3a2 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 19:58:19 -0600 Subject: [PATCH 01/13] docs: accelerated execution providers --- Cargo.toml | 9 ++++++++- README.md | 7 +++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c5efd7b..80c8b3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,6 @@ modnet = [] bevy_args = "1.3" image = "0.24" ndarray = "0.15" -ort = "2.0.0-alpha.4" thiserror = "1.0" tokio = { version = "1.36", features = ["full"] } @@ -57,6 +56,14 @@ features = [ ] +[dependencies.ort] +version = "2.0.0-alpha.4" +default-features = true +features = [ + +] + + [profile.dev.package."*"] opt-level = 3 diff --git a/README.md b/README.md index 4cebfea..4a929bb 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,13 @@ fn inference( cargo run ``` +use an accelerated execution provider: +- windows - `cargo run --features ort/cuda` or `cargo run --features ort/openvino` +- macos - `cargo run --features ort/coreml` +- linux - `cargo run --features ort/tensorrt` or `cargo run --features ort/openvino` + +> see complete list of ort features here: https://github.com/pykeio/ort/blob/0aec4030a5f3470e4ee6c6f4e7e52d4e495ec27a/Cargo.toml#L54 + > note: if you use `pip install onnxruntime`, you may need to run `ORT_STRATEGY=system cargo run`, see: https://docs.rs/ort/latest/ort/#how-to-get-binaries From 72ba07ad45668337a98f41267f92442a9ed6b693 Mon Sep 17 00:00:00 2001 From: Mitchell Mosure Date: Fri, 8 Mar 2024 20:11:27 -0600 Subject: [PATCH 02/13] nit: bad formatting --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 4a929bb..bffd229 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library ![person](assets/person.png) ![mask](assets/mask.png) + *> modnet inference example* From 4b3a145646e62eac332d80d8f6008a3185b22266 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 20:17:16 -0600 Subject: [PATCH 03/13] chore: try to install onnxruntime in ci --- .github/workflows/test.yml | 6 ++++++ tools/modnet.rs | 2 ++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8ceb28c..56a87e4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,6 +25,12 @@ jobs: steps: - uses: actions/checkout@v3 + - uses: actions/setup-python@v5 + with: + python-version: '3.9' + cache: 'pip' + - run: pip install onnxruntime + - name: Setup ${{ matrix.rust-toolchain }} rust toolchain with caching uses: brndnmtthws/rust-action@v1 with: diff --git a/tools/modnet.rs b/tools/modnet.rs index d82dd61..6e4f6eb 100644 --- a/tools/modnet.rs +++ b/tools/modnet.rs @@ -24,12 +24,14 @@ fn main() { .run(); } + #[derive(Resource, Default)] pub struct Modnet { pub onnx: Handle, pub input: Handle, } + fn load_modnet( asset_server: Res, mut modnet: ResMut, From ad703e9f6bdec3698aba943da3d76527ec07be0b Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 20:42:55 -0600 Subject: [PATCH 04/13] chore: ORT_STRATEGY=system --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 56a87e4..8f38927 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,10 +37,10 @@ jobs: toolchain: ${{ matrix.rust-toolchain }} - name: lint - run: cargo clippy -- -Dwarnings + run: clippy -- -Dwarnings - name: build - run: cargo build + run: ORT_STRATEGY=system cargo build # - name: build (web) # run: cargo build --example=minimal --target wasm32-unknown-unknown --release From 67b9cdbcce7c1caf39f08334102a923d2f867f3d Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 20:43:33 -0600 Subject: [PATCH 05/13] fix: cargo clippy --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8f38927..6cd232f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,7 +37,7 @@ jobs: toolchain: ${{ matrix.rust-toolchain }} - name: lint - run: clippy -- -Dwarnings + run: cargo clippy -- -Dwarnings - name: build run: ORT_STRATEGY=system cargo build From 15a6ecc58575693d7c7a2e9e662b5cf5d2ea9d73 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 20:46:43 -0600 Subject: [PATCH 06/13] chore: disable pip cache --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6cd232f..e37fd2c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,6 @@ jobs: - uses: actions/setup-python@v5 with: python-version: '3.9' - cache: 'pip' - run: pip install onnxruntime - name: Setup ${{ matrix.rust-toolchain }} rust toolchain with caching From cb2fb41695c103c2ffb61ac090d38839d4f1c365 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 21:00:05 -0600 Subject: [PATCH 07/13] chore: install pre-built onnxruntime --- .github/workflows/test.yml | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e37fd2c..2eaa4d9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, macos-latest] + os: [windows-latest] rust-toolchain: - nightly @@ -25,21 +25,34 @@ jobs: steps: - uses: actions/checkout@v3 - - uses: actions/setup-python@v5 - with: - python-version: '3.9' - - run: pip install onnxruntime - - name: Setup ${{ matrix.rust-toolchain }} rust toolchain with caching uses: brndnmtthws/rust-action@v1 with: toolchain: ${{ matrix.rust-toolchain }} + + - name: Install ONNX Runtime on Windows + if: matrix.os == 'windows-latest' + run: | + Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip" + Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP" + echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV + + - name: Install ONNX Runtime on macOS + if: matrix.os == 'macos-latest' + run: | + curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x64-1.17.1.tgz" -o "onnxruntime.tgz" + tar -xzf onnxruntime.tgz + cp -r onnxruntime-osx-x64-1.17.1 /usr/local/ + echo "ONNXRUNTIME_DIR=/usr/local/onnxruntime-osx-x64-1.17.1" >> $GITHUB_ENV + + + - name: lint run: cargo clippy -- -Dwarnings - name: build - run: ORT_STRATEGY=system cargo build + run: cargo build # - name: build (web) # run: cargo build --example=minimal --target wasm32-unknown-unknown --release From 9bec0e72fcea8beefd18543eeff00069b0d814ab Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 21:10:40 -0600 Subject: [PATCH 08/13] chore: set ORT_DYLIB_PATH and feature ort/load-dynamic --- .github/workflows/test.yml | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2eaa4d9..3fa7968 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest] + os: [windows-latest, macos-latest] rust-toolchain: - nightly @@ -41,18 +41,25 @@ jobs: - name: Install ONNX Runtime on macOS if: matrix.os == 'macos-latest' run: | - curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x64-1.17.1.tgz" -o "onnxruntime.tgz" + curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz" tar -xzf onnxruntime.tgz cp -r onnxruntime-osx-x64-1.17.1 /usr/local/ echo "ONNXRUNTIME_DIR=/usr/local/onnxruntime-osx-x64-1.17.1" >> $GITHUB_ENV + - name: Set ONNX Runtime library path for macOS + if: matrix.os == 'macos-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV + + - name: Set ONNX Runtime library path for Windows + if: matrix.os == 'windows-latest' + run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/onnxruntime.dll" >> $GITHUB_ENV + - name: lint run: cargo clippy -- -Dwarnings - name: build - run: cargo build - - # - name: build (web) - # run: cargo build --example=minimal --target wasm32-unknown-unknown --release + run: cargo build --features "ort/load-dynamic" + env: + ORT_DYLIB_PATH: ${{ env.ORT_DYLIB_PATH }} From 00311ff73570ee7a6e400ec5a1b1967922714419 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 21:15:12 -0600 Subject: [PATCH 09/13] chore: simplify unpack --- .github/workflows/test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3fa7968..0f46cd8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,8 +42,7 @@ jobs: if: matrix.os == 'macos-latest' run: | curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz" - tar -xzf onnxruntime.tgz - cp -r onnxruntime-osx-x64-1.17.1 /usr/local/ + tar -xzf onnxruntime.tgz -C /usr/local/ echo "ONNXRUNTIME_DIR=/usr/local/onnxruntime-osx-x64-1.17.1" >> $GITHUB_ENV From 364e721e8a54b1259635482186e1bf349e527252 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 21:24:53 -0600 Subject: [PATCH 10/13] chore: install in home directory --- .github/workflows/test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0f46cd8..db6a205 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,8 +42,9 @@ jobs: if: matrix.os == 'macos-latest' run: | curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz" - tar -xzf onnxruntime.tgz -C /usr/local/ - echo "ONNXRUNTIME_DIR=/usr/local/onnxruntime-osx-x64-1.17.1" >> $GITHUB_ENV + mkdir -p $HOME/onnxruntime + tar -xzf onnxruntime.tgz -C $HOME/onnxruntime + echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV - name: Set ONNX Runtime library path for macOS From 1bdc671c4706d60edb0f95420ddd7beb763adcf5 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 23:36:45 -0600 Subject: [PATCH 11/13] feat: batch inference --- Cargo.toml | 2 +- README.md | 8 ++--- src/models/modnet.rs | 84 +++++++++++++++++++++++++++----------------- tools/modnet.rs | 8 ++--- 4 files changed, 60 insertions(+), 42 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 80c8b3c..12c521a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.2.0" +version = "0.3.0" edition = "2021" authors = ["mosure "] license = "MIT" diff --git a/README.md b/README.md index bffd229..93d9aaa 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ use bevy_ort::{ BevyOrtPlugin, inputs, models::modnet::{ - image_to_modnet_input, - modnet_output_to_luma_image, + images_to_modnet_input, + modnet_output_to_luma_images, }, Onnx, }; @@ -82,7 +82,7 @@ fn inference( } let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = image_to_modnet_input(image); + let input = images_to_modnet_input(vec![&image]); let output: Result, String> = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; @@ -96,7 +96,7 @@ fn inference( Ok(output) => { let output_value: &ort::Value = output.get("output").unwrap(); - let mask_image = modnet_output_to_luma_image(output_value); + let mask_image = modnet_output_to_luma_images(output_value).pop().unwrap(); let mask_image = images.add(mask_image); commands.spawn(NodeBundle { diff --git a/src/models/modnet.rs b/src/models/modnet.rs index 4900494..5a1cd71 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -5,61 +5,79 @@ use std::cmp::{ use bevy::{prelude::*, render::render_asset::RenderAssetUsages}; use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage}; -use ndarray::{Array, Array4, ArrayView4, Axis}; +use ndarray::{Array, Array4, ArrayView4, Axis, s}; -pub fn modnet_output_to_luma_image( +pub fn modnet_output_to_luma_images( output_value: &ort::Value, -) -> Image { +) -> Vec { let tensor: ort::Tensor = output_value.extract_tensor::().unwrap(); let data = tensor.view(); let shape = data.shape(); + let batch_size = shape[0]; let width = shape[3]; let height = shape[2]; - let tensor_data = ArrayView4::from_shape((1, 1, height, width), data.as_slice().unwrap()) + let tensor_data = ArrayView4::from_shape((batch_size, 1, height, width), data.as_slice().unwrap()) .expect("failed to create ArrayView4 from shape and data"); - let mut imgbuf = ImageBuffer::, Vec>::new(width as u32, height as u32); + let mut images = Vec::new(); - for y in 0..height { - for x in 0..width { - let pixel_value = tensor_data[(0, 0, y, x)]; - let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; - imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + for i in 0..batch_size { + let mut imgbuf = ImageBuffer::, Vec>::new(width as u32, height as u32); + + for y in 0..height { + for x in 0..width { + let pixel_value = tensor_data[(i, 0, y, x)]; + let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; + imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + } } - } - let dyn_img = DynamicImage::ImageLuma8(imgbuf); + let dyn_img = DynamicImage::ImageLuma8(imgbuf); - Image::from_dynamic(dyn_img, false, RenderAssetUsages::all()) -} + images.push(Image::from_dynamic(dyn_img, false, RenderAssetUsages::all())); + } + images +} -pub fn image_to_modnet_input( - image: &Image, +pub fn images_to_modnet_input( + images: Vec<&Image>, ) -> Array4 { - assert_eq!(image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb); + // TODO: better error handling + if images.is_empty() { + panic!("no images provided"); + } let ref_size = 512; - let ( - x_scale, - y_scale, - ) = get_scale_factor( - image.height(), - image.width(), - ref_size, - ); - - let resized_image = resize_image( - &image.clone().try_into_dynamic().unwrap(), - x_scale, - y_scale, - ); - - image_to_ndarray(&resized_image) + + let &first_image = images.first().unwrap(); + assert_eq!(first_image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb); + + let dynamic_image = first_image.clone().try_into_dynamic().unwrap(); + let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size); + let resized_image = resize_image(&dynamic_image, x_scale, y_scale); + let first_image_ndarray = image_to_ndarray(&resized_image); + let single_image_shape = first_image_ndarray.dim(); + let n_images = images.len(); + let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3); + + let mut aggregate = Array4::::zeros(batch_shape); + + for (i, &image) in images.iter().enumerate() { + let dynamic_image = image.clone().try_into_dynamic().unwrap(); + let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size); + let resized_image = resize_image(&dynamic_image, x_scale, y_scale); + let image_ndarray = image_to_ndarray(&resized_image); + + let slice = s![i, .., .., ..]; + aggregate.slice_mut(slice).assign(&image_ndarray.index_axis_move(Axis(0), 0)); + } + + aggregate } diff --git a/tools/modnet.rs b/tools/modnet.rs index 6e4f6eb..3e82e05 100644 --- a/tools/modnet.rs +++ b/tools/modnet.rs @@ -4,8 +4,8 @@ use bevy_ort::{ BevyOrtPlugin, inputs, models::modnet::{ - image_to_modnet_input, - modnet_output_to_luma_image, + images_to_modnet_input, + modnet_output_to_luma_images, }, Onnx, }; @@ -56,7 +56,7 @@ fn inference( } let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = image_to_modnet_input(image); + let input = images_to_modnet_input(vec![&image]); let output: Result, String> = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; @@ -70,7 +70,7 @@ fn inference( Ok(output) => { let output_value: &ort::Value = output.get("output").unwrap(); - let mask_image = modnet_output_to_luma_image(output_value); + let mask_image = modnet_output_to_luma_images(output_value).pop().unwrap(); let mask_image = images.add(mask_image); commands.spawn(NodeBundle { From 6b7b48d38ef794bbeab81dca82e7ffd7aa895b90 Mon Sep 17 00:00:00 2001 From: mosure Date: Fri, 8 Mar 2024 23:37:21 -0600 Subject: [PATCH 12/13] docs: capability --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 93d9aaa..a9c0279 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library - [X] load ONNX models as ORT session assets - [X] initialize ORT with default execution providers - [X] modnet bevy image <-> ort tensor IO (with feature `modnet`) +- [X] batched modnet preprocessing - [ ] compute task pool inference scheduling From d3389ca6b829aecb6514fd375d0b0f1f93e69f8a Mon Sep 17 00:00:00 2001 From: mosure Date: Sat, 9 Mar 2024 00:28:37 -0600 Subject: [PATCH 13/13] feat: specify max inference size --- Cargo.toml | 2 +- src/models/modnet.rs | 69 ++++++++++++++++++++++++-------------------- tools/modnet.rs | 2 +- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 12c521a..22a0071 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.3.0" +version = "0.4.0" edition = "2021" authors = ["mosure "] license = "MIT" diff --git a/src/models/modnet.rs b/src/models/modnet.rs index 5a1cd71..7c9de00 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -1,8 +1,3 @@ -use std::cmp::{ - max, - min, -}; - use bevy::{prelude::*, render::render_asset::RenderAssetUsages}; use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage}; use ndarray::{Array, Array4, ArrayView4, Axis, s}; @@ -44,23 +39,29 @@ pub fn modnet_output_to_luma_images( images } + pub fn images_to_modnet_input( images: Vec<&Image>, + max_size: Option<(u32, u32)>, ) -> Array4 { - // TODO: better error handling if images.is_empty() { panic!("no images provided"); } let ref_size = 512; - let &first_image = images.first().unwrap(); - assert_eq!(first_image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb); - let dynamic_image = first_image.clone().try_into_dynamic().unwrap(); - let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size); - let resized_image = resize_image(&dynamic_image, x_scale, y_scale); + let image = first_image.to_owned(); + + println!("image: {:?}", image.size()); + + let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size); + let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale); let first_image_ndarray = image_to_ndarray(&resized_image); + + println!("scale_factor: {:?}", (x_scale, y_scale)); + println!("first_image_ndarray: {:?}", first_image_ndarray.dim()); + let single_image_shape = first_image_ndarray.dim(); let n_images = images.len(); let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3); @@ -68,9 +69,9 @@ pub fn images_to_modnet_input( let mut aggregate = Array4::::zeros(batch_shape); for (i, &image) in images.iter().enumerate() { - let dynamic_image = image.clone().try_into_dynamic().unwrap(); - let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size); - let resized_image = resize_image(&dynamic_image, x_scale, y_scale); + let image = image.to_owned(); + let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size); + let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale); let image_ndarray = image_to_ndarray(&resized_image); let slice = s![i, .., .., ..]; @@ -81,27 +82,33 @@ pub fn images_to_modnet_input( } -fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32) -> (f32, f32) { - let mut im_rh; - let mut im_rw; +fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, u32)>) -> (f32, f32) { + // Calculate the scale factor based on the maximum size constraints + let scale_factor_max = max_size.map_or(1.0, |(max_w, max_h)| { + f32::min(max_w as f32 / im_w as f32, max_h as f32 / im_h as f32) + }); - if max(im_h, im_w) < ref_size || min(im_h, im_w) > ref_size { - if im_w >= im_h { - im_rh = ref_size; - im_rw = (im_w as f32 / im_h as f32 * ref_size as f32) as u32; - } else { - im_rw = ref_size; - im_rh = (im_h as f32 / im_w as f32 * ref_size as f32) as u32; - } + // Calculate the target dimensions after applying the max scale factor (clipping to max_size) + let (target_h, target_w) = ((im_h as f32 * scale_factor_max).round() as u32, (im_w as f32 * scale_factor_max).round() as u32); + + // Calculate the scale factor to fit within the reference size, considering the target dimensions + let (scale_factor_ref_w, scale_factor_ref_h) = if std::cmp::max(target_h, target_w) < ref_size { + let scale_factor = ref_size as f32 / std::cmp::max(target_h, target_w) as f32; + (scale_factor, scale_factor) } else { - im_rh = im_h; - im_rw = im_w; - } + (1.0, 1.0) // Do not upscale if target dimensions are within reference size + }; + + // Calculate the final scale factor as the minimum of the max scale factor and the reference scale factor + let final_scale_w = f32::min(scale_factor_max, scale_factor_ref_w); + let final_scale_h = f32::min(scale_factor_max, scale_factor_ref_h); - im_rw = im_rw - im_rw % 32; - im_rh = im_rh - im_rh % 32; + // Adjust dimensions to ensure they are multiples of 32 + let final_w = ((im_w as f32 * final_scale_w).round() as u32) - ((im_w as f32 * final_scale_w).round() as u32) % 32; + let final_h = ((im_h as f32 * final_scale_h).round() as u32) - ((im_h as f32 * final_scale_h).round() as u32) % 32; - (im_rw as f32 / im_w as f32, im_rh as f32 / im_h as f32) + // Return the scale factors based on the original image dimensions + (final_w as f32 / im_w as f32, final_h as f32 / im_h as f32) } diff --git a/tools/modnet.rs b/tools/modnet.rs index 3e82e05..6598704 100644 --- a/tools/modnet.rs +++ b/tools/modnet.rs @@ -56,7 +56,7 @@ fn inference( } let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = images_to_modnet_input(vec![&image]); + let input = images_to_modnet_input(vec![&image], Some((256, 144))); let output: Result, String> = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;