这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions .github/workflows/bench.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Benchmark workflow: runs criterion benchmarks for PRs labeled `bench`
# and compares results against the PR's base branch.
name: bench

on:
  pull_request:
    # re-run when the label is applied or the branch is updated
    types: [ labeled, synchronize ]
    branches: [ "main" ]

env:
  CARGO_TERM_COLOR: always
  RUST_BACKTRACE: 1

jobs:
  bench:
    # only run for PRs that carry the `bench` label
    if: contains(github.event.pull_request.labels.*.name, 'bench')

    strategy:
      # let each OS finish even if the other fails, so both reports are produced
      fail-fast: false
      matrix:
        os: [windows-latest, macos-latest]

    runs-on: ${{ matrix.os }}
    timeout-minutes: 120

    steps:
      # NOTE(review): actions/checkout@v3 and actions/cache@v3 run on a
      # deprecated Node runtime on current runners — consider bumping to v4.
      - uses: actions/checkout@v3
      - uses: actions/cache@v3
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-build-stable-${{ hashFiles('**/Cargo.toml') }}


      # download a prebuilt ONNX Runtime and expose its location via GITHUB_ENV
      - name: Install ONNX Runtime on Windows
        if: matrix.os == 'windows-latest'
        run: |
          Invoke-WebRequest -Uri "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-win-x64-1.17.1.zip" -OutFile "onnxruntime.zip"
          Expand-Archive -Path "onnxruntime.zip" -DestinationPath "$env:RUNNER_TEMP"
          echo "ONNXRUNTIME_DIR=$env:RUNNER_TEMP\onnxruntime-win-x64-1.17.1" | Out-File -Append -Encoding ascii $env:GITHUB_ENV

      # NOTE(review): this fetches the x86_64 build, but `macos-latest` runners
      # are Apple Silicon (arm64) — confirm the dylib actually loads, or pin an
      # Intel runner / use the osx-arm64 archive.
      - name: Install ONNX Runtime on macOS
        if: matrix.os == 'macos-latest'
        run: |
          curl -L "https://github.com/microsoft/onnxruntime/releases/download/v1.17.1/onnxruntime-osx-x86_64-1.17.1.tgz" -o "onnxruntime.tgz"
          mkdir -p $HOME/onnxruntime
          tar -xzf onnxruntime.tgz -C $HOME/onnxruntime
          echo "ONNXRUNTIME_DIR=$HOME/onnxruntime/onnxruntime-osx-x86_64-1.17.1" >> $GITHUB_ENV


      # point the `ort` crate's dynamic loader at the extracted library
      - name: Set ONNX Runtime library path for macOS
        if: matrix.os == 'macos-latest'
        run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/libonnxruntime.dylib" >> $GITHUB_ENV

      - name: Set ONNX Runtime library path for Windows
        if: matrix.os == 'windows-latest'
        run: echo "ORT_DYLIB_PATH=$ONNXRUNTIME_DIR/onnxruntime.dll" >> $GITHUB_ENV


      # run the `modnet` criterion bench on this branch and the base branch,
      # posting a comparison to the PR
      - name: io benchmark
        uses: boa-dev/criterion-compare-action@v3.2.4
        with:
          benchName: "modnet"
          branchName: ${{ github.base_ref }}
          token: ${{ secrets.GITHUB_TOKEN }}
13 changes: 11 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
authors = ["mosure <mitchell@mosure.me>"]
license = "MIT"
Expand Down Expand Up @@ -32,13 +32,14 @@ default = [
"modnet",
]

modnet = []
modnet = ["rayon"]


[dependencies]
bevy_args = "1.3"
image = "0.24"
ndarray = "0.15"
rayon = { version = "1.8", optional = true }
thiserror = "1.0"
tokio = { version = "1.36", features = ["full"] }

Expand Down Expand Up @@ -68,6 +69,10 @@ features = [
]


[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }


[profile.dev.package."*"]
opt-level = 3

Expand All @@ -88,3 +93,7 @@ path = "src/lib.rs"
name = "modnet"
path = "tools/modnet.rs"


[[bench]]
name = "modnet"
harness = false
115 changes: 115 additions & 0 deletions benches/modnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use criterion::{
BenchmarkId,
criterion_group,
criterion_main,
Criterion,
Throughput,
};

use bevy::{
prelude::*,
render::{
render_asset::RenderAssetUsages,
render_resource::{
Extent3d,
TextureDimension,
},
},
};
use bevy_ort::{
inputs,
models::modnet::{
modnet_output_to_luma_images,
images_to_modnet_input,
},
Session,
};
use ort::GraphOptimizationLevel;


// Bounding resolutions (width, height) passed to `images_to_modnet_input`
// as the `max_size` limit for each benchmark case.
const MAX_RESOLUTIONS: [(u32, u32); 4] = [
    (256, 256),
    (512, 512),
    (1024, 1024),
    (2048, 2048),
];

// Number of identical 1920x1080 input images batched per benchmark iteration.
const STREAM_COUNT: usize = 16;


fn images_to_modnet_input_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("images_to_modnet_input");

MAX_RESOLUTIONS.iter()
.for_each(|(width, height)| {
let data = vec![0u8; (1920 * 1080 * 4) as usize];

let images = (0..STREAM_COUNT)
.map(|_|{
Image::new(
Extent3d {
width: 1920,
height: 1080,
depth_or_array_layers: 1,
},
TextureDimension::D2,
data.clone(),
bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
RenderAssetUsages::all(),
)
})
.collect::<Vec<_>>();

group.throughput(Throughput::Elements(STREAM_COUNT as u64));
group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &images, |b, images| {
let views = images.iter().map(|image| image).collect::<Vec<_>>();

b.iter(|| images_to_modnet_input(views.as_slice(), Some((*width, *height))));
});
});
}


/// Benchmarks `modnet_output_to_luma_images` on real session output:
/// runs MODNet inference once per `MAX_RESOLUTIONS` size limit, then times
/// only the output-to-image conversion.
///
/// Requires `assets/modnet_photographic_portrait_matting.onnx` on disk.
fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("modnet_output_to_luma_images");

    let session = Session::builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
        .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();

    let data = vec![0u8; 1920 * 1080 * 4];
    let image: Image = Image::new(
        Extent3d {
            width: 1920,
            height: 1080,
            depth_or_array_layers: 1,
        },
        TextureDimension::D2,
        // `data` is not used again, so move it instead of cloning
        data,
        bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
        RenderAssetUsages::all(),
    );

    for size_limit in MAX_RESOLUTIONS.iter() {
        // setup (not timed): preprocess the batch and run inference once
        let input = images_to_modnet_input(&[&image; STREAM_COUNT], Some(*size_limit));
        let input_values = inputs!["input" => input.view()]
            .expect("failed to build session inputs");

        // `.map_err(...).ok().unwrap()` in the original was a roundabout panic;
        // `expect` keeps the same failure behavior with a clear message.
        let outputs = session.run(input_values).expect("modnet inference failed");
        let output_value: &ort::Value = outputs.get("output").unwrap();

        group.throughput(Throughput::Elements(STREAM_COUNT as u64));
        group.bench_with_input(
            BenchmarkId::from_parameter(format!("{}x{}", size_limit.0, size_limit.1)),
            &output_value,
            |b, output_value| {
                b.iter(|| modnet_output_to_luma_images(output_value));
            },
        );
    }
}


// Register both benchmarks in one group; sample_size is reduced to 10 because
// each iteration processes a 16-image batch (and the second bench needs a
// real inference pass for setup).
criterion_group!{
    name = modnet_benches;
    config = Criterion::default().sample_size(10);
    targets = images_to_modnet_input_benchmark, modnet_output_to_luma_images_benchmark
}
criterion_main!(modnet_benches);
71 changes: 28 additions & 43 deletions src/models/modnet.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use bevy::{prelude::*, render::render_asset::RenderAssetUsages};
use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage};
use ndarray::{Array, Array4, ArrayView4, Axis, s};
use ndarray::{Array, Array4, ArrayView4};
use rayon::prelude::*;


pub fn modnet_output_to_luma_images(
Expand Down Expand Up @@ -39,7 +40,7 @@ pub fn modnet_output_to_luma_images(


pub fn images_to_modnet_input(
images: Vec<&Image>,
images: &[&Image],
max_size: Option<(u32, u32)>,
) -> Array4<f32> {
if images.is_empty() {
Expand All @@ -49,86 +50,70 @@ pub fn images_to_modnet_input(
let ref_size = 512;
let &first_image = images.first().unwrap();

let image = first_image.to_owned();
let (x_scale, y_scale) = get_scale_factor(first_image.height(), first_image.width(), ref_size, max_size);

let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size);
let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale);
let first_image_ndarray = image_to_ndarray(&resized_image);

let single_image_shape = first_image_ndarray.dim();
let n_images = images.len();
let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3);

let mut aggregate = Array4::<f32>::zeros(batch_shape);

for (i, &image) in images.iter().enumerate() {
let image = image.to_owned();
let (x_scale, y_scale) = get_scale_factor(image.height(), image.width(), ref_size, max_size);
let resized_image = resize_image(&image.try_into_dynamic().unwrap(), x_scale, y_scale);
let image_ndarray = image_to_ndarray(&resized_image);
let processed_images: Vec<Array4<f32>> = images
.par_iter()
.map(|&image| {
let resized_image = resize_image(&image.clone().try_into_dynamic().unwrap(), x_scale, y_scale);
image_to_ndarray(&resized_image)
})
.collect();

let slice = s![i, .., .., ..];
aggregate.slice_mut(slice).assign(&image_ndarray.index_axis_move(Axis(0), 0));
}
let aggregate = Array::from_shape_vec(
(processed_images.len(), processed_images[0].shape()[1], processed_images[0].shape()[2], processed_images[0].shape()[3]),
processed_images.iter().flat_map(|a| a.iter().cloned()).collect(),
).unwrap();

aggregate
}


/// Computes per-axis scale factors (x, y) for resizing an `im_w` x `im_h`
/// image so that:
/// - it fits within `max_size` (width, height) when given,
/// - it is scaled up toward `ref_size` only when that does not exceed the
///   `max_size`-derived scale, and
/// - the resulting dimensions are floored to multiples of 32 (a MODNet
///   input-size requirement).
fn get_scale_factor(im_h: u32, im_w: u32, ref_size: u32, max_size: Option<(u32, u32)>) -> (f32, f32) {
    // scale needed to fit inside max_size; 1.0 when unconstrained
    let scale_factor_max = max_size.map_or(1.0, |(max_w, max_h)| {
        f32::min(max_w as f32 / im_w as f32, max_h as f32 / im_h as f32)
    });

    // dimensions after clipping to max_size
    let (target_h, target_w) = ((im_h as f32 * scale_factor_max).round() as u32, (im_w as f32 * scale_factor_max).round() as u32);

    // upscale toward ref_size only if the clipped image is smaller than it
    let (scale_factor_ref_w, scale_factor_ref_h) = if std::cmp::max(target_h, target_w) < ref_size {
        let scale_factor = ref_size as f32 / std::cmp::max(target_h, target_w) as f32;
        (scale_factor, scale_factor)
    } else {
        (1.0, 1.0)
    };

    // never exceed the max_size-derived scale
    let final_scale_w = f32::min(scale_factor_max, scale_factor_ref_w);
    let final_scale_h = f32::min(scale_factor_max, scale_factor_ref_h);

    // floor each final dimension to a multiple of 32
    let final_w = ((im_w as f32 * final_scale_w).round() as u32) - ((im_w as f32 * final_scale_w).round() as u32) % 32;
    let final_h = ((im_h as f32 * final_scale_h).round() as u32) - ((im_h as f32 * final_scale_h).round() as u32) % 32;

    // express the snapped dimensions as scale factors of the original size
    (final_w as f32 / im_w as f32, final_h as f32 / im_h as f32)
}


/// Converts an RGB image into an NCHW float tensor of shape
/// [1, 3, height, width], with each channel value normalized from
/// [0, 255] to [-1, 1] (the range MODNet expects).
fn image_to_ndarray(img: &RgbImage) -> Array4<f32> {
    let (width, height) = img.dimensions();

    // build directly in NCHW order; each element samples one channel of the
    // corresponding pixel, avoiding a separate HWC buffer + axis permutation
    Array::from_shape_fn((1, 3, height as usize, width as usize), |(_, c, y, x)| {
        let pixel = img.get_pixel(x as u32, y as u32);
        let channel_value = match c {
            0 => pixel[0],
            1 => pixel[1],
            2 => pixel[2],
            _ => unreachable!(),
        };
        (channel_value as f32 - 127.5) / 127.5
    })
}

/// Resizes `image` by independent x/y scale factors and returns an RGB8 buffer.
///
/// NOTE(review): `Nearest` is the fastest filter but visibly lower quality
/// than `Triangle` for downscaling matting inputs — confirm this speed/quality
/// trade-off is intentional.
fn resize_image(image: &DynamicImage, x_scale: f32, y_scale: f32) -> RgbImage {
    let (width, height) = image.dimensions();
    let new_width = (width as f32 * x_scale) as u32;
    let new_height = (height as f32 * y_scale) as u32;

    image.resize_exact(new_width, new_height, FilterType::Nearest).to_rgb8()
}
2 changes: 1 addition & 1 deletion tools/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn inference(
}

let image = images.get(&modnet.input).expect("failed to get image asset");
let input = images_to_modnet_input(vec![&image], None);
let input = images_to_modnet_input(&[image], Some((256, 256)));

let mask_image: Result<Image, String> = (|| {
let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;
Expand Down