diff --git a/Cargo.toml b/Cargo.toml index a8b7a51..f8de177 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.6.1" +version = "0.6.3" edition = "2021" authors = ["mosure "] license = "MIT" diff --git a/README.md b/README.md index 7628871..506e6d9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,6 @@ # bevy_ort 🪨 [![test](https://github.com/mosure/bevy_ort/workflows/test/badge.svg)](https://github.com/Mosure/bevy_ort/actions?query=workflow%3Atest) [![GitHub License](https://img.shields.io/github/license/mosure/bevy_ort)](https://raw.githubusercontent.com/mosure/bevy_ort/main/LICENSE) -[![GitHub Last Commit](https://img.shields.io/github/last-commit/mosure/bevy_ort)](https://github.com/mosure/bevy_ort) -[![GitHub Releases](https://img.shields.io/github/v/release/mosure/bevy_ort?include_prereleases&sort=semver)](https://github.com/mosure/bevy_ort/releases) -[![GitHub Issues](https://img.shields.io/github/issues/mosure/bevy_ort)](https://github.com/mosure/bevy_ort/issues) -[![Average time to resolve an issue](https://isitmaintained.com/badge/resolution/mosure/bevy_ort.svg)](http://isitmaintained.com/project/mosure/bevy_ort) [![crates.io](https://img.shields.io/crates/v/bevy_ort.svg)](https://crates.io/crates/bevy_ort) a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library diff --git a/benches/modnet.rs b/benches/modnet.rs index 0bb6500..58db29c 100644 --- a/benches/modnet.rs +++ b/benches/modnet.rs @@ -37,6 +37,16 @@ const MAX_RESOLUTIONS: [(u32, u32); 4] = [ const STREAM_COUNT: usize = 16; +criterion_group!{ + name = modnet_benches; + config = Criterion::default().sample_size(10); + targets = images_to_modnet_input_benchmark, + modnet_output_to_luma_images_benchmark, + modnet_inference_benchmark, +} +criterion_main!(modnet_benches); + + fn images_to_modnet_input_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("images_to_modnet_input"); @@
-107,9 +117,40 @@ fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) { } -criterion_group!{ - name = modnet_benches; - config = Criterion::default().sample_size(10); - targets = images_to_modnet_input_benchmark, modnet_output_to_luma_images_benchmark +fn modnet_inference_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("modnet_inference"); + + let session = Session::builder().unwrap() + .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() + .with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + + MAX_RESOLUTIONS.iter().for_each(|(width, height)| { + let data = vec![0u8; *width as usize * *height as usize * 4]; + let image = Image::new( + Extent3d { + width: *width, + height: *height, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + data, + bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb, + RenderAssetUsages::all(), + ); + + let input = images_to_modnet_input(&[&image], Some((*width, *height))); + + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &(width, height), |b, _| { + b.iter(|| { + let input_values = inputs!["input" => input.view()].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.expect("modnet inference failed"); + let output_value: &ort::Value = binding.get("output").unwrap(); + modnet_output_to_luma_images(output_value) + }); + }); + }); + + group.finish(); } -criterion_main!(modnet_benches); diff --git a/src/models/modnet.rs b/src/models/modnet.rs index 6cbfe6f..48c82c0 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -1,4 +1,14 @@ -use bevy::{prelude::*, render::render_asset::RenderAssetUsages}; +use bevy::{ + prelude::*, + render::{ + render_asset::RenderAssetUsages, + render_resource::{ + Extent3d, + TextureDimension, + TextureFormat, + }, + }, +}; use image::{DynamicImage, GenericImageView,
imageops::FilterType, ImageBuffer, Luma, RgbImage}; use ndarray::{Array, Array4, ArrayView4}; use rayon::prelude::*; @@ -18,25 +28,32 @@ pub fn modnet_output_to_luma_images( let tensor_data = ArrayView4::from_shape((batch_size, 1, height, width), data.as_slice().unwrap()) .expect("failed to create ArrayView4 from shape and data"); - let mut images = Vec::new(); - - for i in 0..batch_size { - let mut imgbuf = ImageBuffer::<Luma<u8>, Vec<u8>>::new(width as u32, height as u32); - - for y in 0..height { - for x in 0..width { - let pixel_value = tensor_data[(i, 0, y, x)]; - let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; - imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + (0..batch_size) + .into_par_iter() + .map(|i| { + let mut imgbuf = ImageBuffer::<Luma<u8>, Vec<u8>>::new(width as u32, height as u32); + + for y in 0..height { + for x in 0..width { + let pixel_value = tensor_data[(i, 0, y, x)]; + let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8; + imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value])); + } } - } - - let dyn_img = DynamicImage::ImageLuma8(imgbuf); - images.push(Image::from_dynamic(dyn_img, false, RenderAssetUsages::all())); - } - - images + Image::new( + Extent3d { + width: width as u32, + height: height as u32, + depth_or_array_layers: 1, + }, + TextureDimension::D2, + imgbuf.into_raw(), + TextureFormat::R8Unorm, + RenderAssetUsages::all(), + ) + }) + .collect::<Vec<_>>() } diff --git a/tools/modnet.rs b/tools/modnet.rs index 5a87db5..34eab18 100644 --- a/tools/modnet.rs +++ b/tools/modnet.rs @@ -56,7 +56,7 @@ fn inference( } let image = images.get(&modnet.input).expect("failed to get image asset"); - let input = images_to_modnet_input(&[image], Some((256, 256))); + let input = images_to_modnet_input(&[image], None); let mask_image: Result = (|| { let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;