diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml
new file mode 100644
index 0000000..d09219d
--- /dev/null
+++ b/.github/workflows/bench.yml
@@ -0,0 +1,67 @@
+name: bench
+
+on:
+  pull_request:
+    types: [ labeled, synchronize ]
+    branches: [ "main" ]
+
+env:
+  CARGO_TERM_COLOR: always
+  RUST_BACKTRACE: 1
+
+jobs:
+  bench:
+    if: contains(github.event.pull_request.labels.*.name, 'bench')
+
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [windows-latest, macos-latest]
+
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 120
+
+    steps:
+    - uses: actions/checkout@v3
+    - uses: actions/cache@v3
+      with:
+        path: |
+          ~/.cargo/bin/
+          ~/.cargo/registry/index/
+          ~/.cargo/registry/cache/
+          ~/.cargo/git/db/
+          target/
+        key: ${{ runner.os }}-cargo-build-stable-${{ hashFiles('**/Cargo.toml') }}
+
+
+    - name: Install ONNX Runtime on Windows
+      if: matrix.os == 'windows-latest'
+      run: |
+        Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip"
+        Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP"
+        echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV
+
+    - name: Install ONNX Runtime on macOS
+      if: matrix.os == 'macos-latest'
+      run: |
+        curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz"
+        mkdir -p $HOME/onnxruntime
+        tar -xzf onnxruntime.tgz -C $HOME/onnxruntime
+        echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV
+
+
+    - name: Set ONNX Runtime library path for macOS
+      if: matrix.os == 'macos-latest'
+      run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV
+
+    - name: Set ONNX Runtime library path for Windows
+      if: matrix.os == 'windows-latest'
+      run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/onnxruntime.dll" >> $GITHUB_ENV
+
+
+    - name: io benchmark
+      uses: boa-dev/criterion-compare-action@v3.2.4
+      with:
+        benchName: "modnet"
+        branchName: ${{ github.base_ref }}
+        token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/Cargo.toml b/Cargo.toml
index 33e65c8..30b711e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "bevy_ort"
 description = "bevy ort (onnxruntime) plugin"
-version = "0.5.0"
+version = "0.6.0"
 edition = "2021"
 authors = ["mosure "]
 license = "MIT"
@@ -26,18 +26,20 @@ exclude = [
 default-run = "modnet"
 
+
 [features]
 default = [
   "modnet",
 ]
-modnet = []
+modnet = ["rayon"]
 
 [dependencies]
 bevy_args = "1.3"
 image = "0.24"
 ndarray = "0.15"
+rayon = { version = "1.8", optional = true }
 thiserror = "1.0"
 tokio = { version = "1.36", features = ["full"] }
@@ -57,14 +59,19 @@ features = [
 
 
 [dependencies.ort]
-version = "2.0.0-alpha.4"
+version = "2.0.0-rc.0"
 default-features = false
 features = [
+    "cuda",
     "load-dynamic",
     "ndarray",
 ]
 
+[dev-dependencies]
+criterion = { version = "0.5", features = ["html_reports"] }
+
+
 [profile.dev.package."*"]
 opt-level = 3
@@ -85,3 +92,7 @@ path = "src/lib.rs"
 name = "modnet"
 path = "tools/modnet.rs"
+
+[[bench]]
+name = "modnet"
+harness = false
diff --git a/README.md b/README.md
index 2f81f48..7628871 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,7 @@ fn inference(
     }
 
     let image = images.get(&modnet.input).expect("failed to get image asset");
-    let input = images_to_modnet_input(vec![&image]);
+    let input = images_to_modnet_input(vec![&image], None);
 
     let mask_image: Result<Image, String> = (|| {
         let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;
diff --git a/benches/modnet.rs b/benches/modnet.rs
new file mode 100644
index 0000000..0bb6500
--- /dev/null
+++ b/benches/modnet.rs
@@ -0,0 +1,115 @@
+use criterion::{
+    BenchmarkId,
+    criterion_group,
+    criterion_main,
+    Criterion,
+    Throughput,
+};
+
+use bevy::{
+    prelude::*,
+    render::{
+        render_asset::RenderAssetUsages,
+        render_resource::{
+            Extent3d,
+            TextureDimension,
+        },
+    },
+};
+use
bevy_ort::{
+    inputs,
+    models::modnet::{
+        modnet_output_to_luma_images,
+        images_to_modnet_input,
+    },
+    Session,
+};
+use ort::GraphOptimizationLevel;
+
+
+const MAX_RESOLUTIONS: [(u32, u32); 4] = [
+    (256, 256),
+    (512, 512),
+    (1024, 1024),
+    (2048, 2048),
+];
+
+const STREAM_COUNT: usize = 16;
+
+
+fn images_to_modnet_input_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("images_to_modnet_input");
+
+    MAX_RESOLUTIONS.iter()
+        .for_each(|(width, height)| {
+            let data = vec![0u8; (1920 * 1080 * 4) as usize];
+
+            let images = (0..STREAM_COUNT)
+                .map(|_|{
+                    Image::new(
+                        Extent3d {
+                            width: 1920,
+                            height: 1080,
+                            depth_or_array_layers: 1,
+                        },
+                        TextureDimension::D2,
+                        data.clone(),
+                        bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
+                        RenderAssetUsages::all(),
+                    )
+                })
+                .collect::<Vec<_>>();
+
+            group.throughput(Throughput::Elements(STREAM_COUNT as u64));
+            group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &images, |b, images| {
+                let views = images.iter().map(|image| image).collect::<Vec<_>>();
+
+                b.iter(|| images_to_modnet_input(views.as_slice(), Some((*width, *height))));
+            });
+        });
+}
+
+
+fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("modnet_output_to_luma_images");
+
+    let session = Session::builder().unwrap()
+        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
+        .with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();
+
+    let data = vec![0u8; (1920 * 1080 * 4) as usize];
+    let image: Image = Image::new(
+        Extent3d {
+            width: 1920,
+            height: 1080,
+            depth_or_array_layers: 1,
+        },
+        TextureDimension::D2,
+        data.clone(),
+        bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
+        RenderAssetUsages::all(),
+    );
+
+    MAX_RESOLUTIONS.iter()
+        .for_each(|size_limit| {
+            let input = images_to_modnet_input(&[&image; STREAM_COUNT], size_limit.clone().into());
+            let input_values = inputs!["input" =>
input.view()].map_err(|e| e.to_string()).unwrap();
+
+            let outputs = session.run(input_values).map_err(|e| e.to_string());
+            let binding = outputs.ok().unwrap();
+            let output_value: &ort::Value = binding.get("output").unwrap();
+
+            group.throughput(Throughput::Elements(STREAM_COUNT as u64));
+            group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", size_limit.0, size_limit.1)), &output_value, |b, output_value| {
+                b.iter(|| modnet_output_to_luma_images(output_value));
+            });
+        });
+}
+
+
+criterion_group!{
+    name = modnet_benches;
+    config = Criterion::default().sample_size(10);
+    targets = images_to_modnet_input_benchmark, modnet_output_to_luma_images_benchmark
+}
+criterion_main!(modnet_benches);
diff --git a/src/models/modnet.rs b/src/models/modnet.rs
index 7721100..094569c 100644
--- a/src/models/modnet.rs
+++ b/src/models/modnet.rs
@@ -1,13 +1,13 @@
 use bevy::{prelude::*, render::render_asset::RenderAssetUsages};
 use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage};
-use ndarray::{Array, Array4, ArrayView4, Axis, s};
+use ndarray::{Array, Array4, ArrayView4};
+use rayon::prelude::*;
 
 
 pub fn modnet_output_to_luma_images(
     output_value: &ort::Value,
 ) -> Vec<Image> {
-    let tensor: ort::Tensor<f32> = output_value.extract_tensor::<f32>().unwrap();
-
+    let tensor = output_value.extract_tensor::<f32>().unwrap();
     let data = tensor.view();
 
     let shape = data.shape();
@@ -41,7 +41,7 @@ pub fn modnet_output_to_luma_images(
 
 
 pub fn images_to_modnet_input(
-    images: Vec<&Image>,
+    images: &[&Image],
     max_size: Option<(u32, u32)>,
 ) -> Array4<f32> {
     if images.is_empty() {
@@ -51,58 +51,45 @@ pub fn images_to_modnet_input(
     let ref_size = 512;
 
     let &first_image = images.first().unwrap();
-    let image = first_image.to_owned();
-
-    let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size);
-    let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale);
-    let first_image_ndarray =
image_to_ndarray(&resized_image);
-
-    let single_image_shape = first_image_ndarray.dim();
-    let n_images = images.len();
-    let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3);
+    let (x_scale, y_scale) = get_scale_factor(first_image.height(), first_image.width(), ref_size, max_size);
 
-    let mut aggregate = Array4::<f32>::zeros(batch_shape);
+    let processed_images: Vec<Array4<f32>> = images
+        .par_iter()
+        .map(|&image| {
+            let resized_image = resize_image(&image.clone().try_into_dynamic().unwrap(), x_scale, y_scale);
+            image_to_ndarray(&resized_image)
+        })
+        .collect();
 
-    for (i, &image) in images.iter().enumerate() {
-        let image = image.to_owned();
-        let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size);
-        let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale);
-        let image_ndarray = image_to_ndarray(&resized_image);
-
-        let slice = s![i, .., .., ..];
-        aggregate.slice_mut(slice).assign(&image_ndarray.index_axis_move(Axis(0), 0));
-    }
+    let aggregate = Array::from_shape_vec(
+        (processed_images.len(), processed_images[0].shape()[1], processed_images[0].shape()[2], processed_images[0].shape()[3]),
+        processed_images.iter().flat_map(|a| a.iter().cloned()).collect(),
+    ).unwrap();
 
     aggregate
 }
 
 
 fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, u32)>) -> (f32, f32) {
-    // Calculate the scale factor based on the maximum size constraints
     let scale_factor_max = max_size.map_or(1.0, |(max_w, max_h)| {
         f32::min(max_w as f32 / im_w as f32, max_h as f32 / im_h as f32)
     });
 
-    // Calculate the target dimensions after applying the max scale factor (clipping to max_size)
     let (target_h, target_w) = ((im_h as f32 * scale_factor_max).round() as u32, (im_w as f32 * scale_factor_max).round() as u32);
 
-    // Calculate the scale factor to fit within the reference size, considering the target dimensions
     let (scale_factor_ref_w, scale_factor_ref_h) = if
std::cmp::max(target_h, target_w) < ref_size {
         let scale_factor = ref_size as f32 / std::cmp::max(target_h, target_w) as f32;
         (scale_factor, scale_factor)
     } else {
-        (1.0, 1.0) // Do not upscale if target dimensions are within reference size
+        (1.0, 1.0)
     };
 
-    // Calculate the final scale factor as the minimum of the max scale factor and the reference scale factor
     let final_scale_w = f32::min(scale_factor_max, scale_factor_ref_w);
     let final_scale_h = f32::min(scale_factor_max, scale_factor_ref_h);
 
-    // Adjust dimensions to ensure they are multiples of 32
     let final_w = ((im_w as f32 * final_scale_w).round() as u32) - ((im_w as f32 * final_scale_w).round() as u32) % 32;
     let final_h = ((im_h as f32 * final_scale_h).round() as u32) - ((im_h as f32 * final_scale_h).round() as u32) % 32;
 
-    // Return the scale factors based on the original image dimensions
     (final_w as f32 / im_w as f32, final_h as f32 / im_h as f32)
 }
@@ -110,21 +97,18 @@ fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32,
 fn image_to_ndarray(img: &RgbImage) -> Array4<f32> {
     let (width, height) = img.dimensions();
 
-    // convert RgbImage to a Vec of f32 values normalized to [-1, 1]
-    let raw: Vec<f32> = img.pixels()
-        .flat_map(|p| {
-            p.0.iter().map(|&e| {
-                (e as f32 - 127.5) / 127.5
-            })
-        })
-        .collect();
-
-    // create a 3D array from the raw pixel data
-    let arr = Array::from_shape_vec((height as usize, width as usize, 3), raw)
-        .expect("failed to create ndarray from image raw data");
+    let arr = Array::from_shape_fn((1, 3, height as usize, width as usize), |(_, c, y, x)| {
+        let pixel = img.get_pixel(x as u32, y as u32);
+        let channel_value = match c {
+            0 => pixel[0],
+            1 => pixel[1],
+            2 => pixel[2],
+            _ => unreachable!(),
+        };
+        (channel_value as f32 - 127.5) / 127.5
+    });
 
-    // rearrange the dimensions from [height, width, channels] to [1, channels, height, width]
-    arr.permuted_axes([2, 0, 1]).insert_axis(Axis(0))
+    arr
 }
 
 
 fn resize_image(image: &DynamicImage, x_scale: f32,
y_scale: f32) -> RgbImage {
@@ -132,5 +116,5 @@ fn resize_image(image: &DynamicImage, x_scale: f32, y_scale: f32) -> RgbImage {
     let new_width = (width as f32 * x_scale) as u32;
     let new_height = (height as f32 * y_scale) as u32;
 
-    image.resize_exact(new_width, new_height, FilterType::Triangle).to_rgb8()
+    image.resize_exact(new_width, new_height, FilterType::Nearest).to_rgb8()
 }
diff --git a/tools/modnet.rs b/tools/modnet.rs
index 43a6957..5a87db5 100644
--- a/tools/modnet.rs
+++ b/tools/modnet.rs
@@ -56,7 +56,7 @@ fn inference(
     }
 
     let image = images.get(&modnet.input).expect("failed to get image asset");
-    let input = images_to_modnet_input(vec![&image], Some((256, 144)));
+    let input = images_to_modnet_input(&[image], Some((256, 256)));
 
     let mask_image: Result<Image, String> = (|| {
         let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;