diff --git a/.gitignore b/.gitignore index 28fc69c..d74f255 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,6 @@ www/assets/ mediamtx/ onnxruntime/ + +*.onnx +*.bin diff --git a/Cargo.toml b/Cargo.toml index 6d53cdd..608f7c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "bevy_ort" description = "bevy ort (onnxruntime) plugin" -version = "0.9.0" +version = "0.12.9" edition = "2021" authors = ["mosure "] license = "MIT" @@ -29,11 +29,16 @@ default-run = "modnet" [features] default = [ + "flame", + "flame_viewer", "lightglue", "modnet", "yolo_v8", ] +flame_viewer = ["bevy_panorbit_camera"] + +flame = [] lightglue = [] modnet = ["rayon"] yolo_v8 = [] @@ -41,13 +46,14 @@ yolo_v8 = [] [dependencies] bevy_args = "1.3" +bevy_panorbit_camera = { version = "0.18", optional = true } +bytemuck = "1.15" image = "0.24" # upgrade with bevy +include_bytes_aligned = "0.1" ndarray = "0.15" rayon = { version = "1.8", optional = true } -serde = "1.0.197" +serde = "1.0" thiserror = "1.0" -tokio = { version = "1.36", features = ["full"] } - [dependencies.bevy] version = "0.13" @@ -55,20 +61,21 @@ default-features = false features = [ "bevy_asset", "bevy_core_pipeline", + "bevy_pbr", "bevy_render", "bevy_ui", "bevy_winit", "multi-threaded", "png", + "tonemapping_luts", ] [dependencies.ort] -version = "2.0.0-rc.0" +version = "2.0.0-rc.2" default-features = false features = [ - "cuda", - "load-dynamic", + "download-binaries", "ndarray", ] @@ -93,6 +100,11 @@ opt-level = 3 path = "src/lib.rs" +[[bin]] +name = "flame" +path = "tools/flame.rs" +required-features = ["flame", "flame_viewer"] + [[bin]] name = "lightglue" path = "tools/lightglue.rs" diff --git a/README.md b/README.md index f78698a..18e0a75 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library - [X] lightglue (feature matching) - [X] modnet (photographic portrait matting) - [X] yolo_v8 (object detection) +- [X] flame (parametric head model) ## library usage @@ -33,12 +34,12 @@ use bevy::prelude::*; use bevy_ort::{ BevyOrtPlugin, - inputs, - models::modnet::{ - images_to_modnet_input, - modnet_output_to_luma_images, + models::flame::{ + FlameInput, + FlameOutput, + Flame, + FlamePlugin, }, - Onnx, }; @@ -47,88 +48,51 @@ fn main() { .add_plugins(( DefaultPlugins, BevyOrtPlugin, + FlamePlugin, )) - .init_resource::() - .add_systems(Startup, load_modnet) - .add_systems(Update, inference) + .add_systems(Startup, load_flame) + .add_systems(Startup, setup) + .add_systems(Update, on_flame_output) .run(); } -#[derive(Resource, Default)] -pub struct Modnet { - pub onnx: Handle, - pub input: Handle, -} -fn load_modnet( +fn load_flame( asset_server: Res, - mut modnet: ResMut, + mut flame: ResMut, ) { - let modnet_handle: Handle = asset_server.load("modnet_photographic_portrait_matting.onnx"); - modnet.onnx = modnet_handle; + flame.onnx = asset_server.load("models/flame.onnx"); +} - let input_handle: Handle = asset_server.load("images/person.png"); - modnet.input = input_handle; + +fn setup( + mut commands: Commands, +) { + commands.spawn(FlameInput::default()); + commands.spawn(Camera3dBundle::default()); } -fn inference( +#[derive(Debug, Component, Reflect)] +struct HandledFlameOutput; + +fn on_flame_output( mut commands: Commands, - modnet: Res, - onnx_assets: Res>, - mut images: ResMut>, - mut complete: Local, + flame_outputs: Query< + ( + Entity, + &FlameOutput, + ), + Without, + >, ) { - if *complete { - return; - } + for (entity, flame_output) in flame_outputs.iter() { + commands.entity(entity) + .insert(HandledFlameOutput); - let image = images.get(&modnet.input).expect("failed to get image asset"); - - let mask_image: Result = (|| { - let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?; - let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; - let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?; - - Ok(modnet_inference(session, &[image], None).pop().unwrap()) - })(); - - match mask_image { - Ok(mask_image) => { - let mask_image = images.add(mask_image); - - commands.spawn(NodeBundle { - style: Style { - display: Display::Grid, - width: Val::Percent(100.0), - height: Val::Percent(100.0), - grid_template_columns: RepeatedGridTrack::flex(1, 1.0), - grid_template_rows: RepeatedGridTrack::flex(1, 1.0), - ..default() - }, - background_color: BackgroundColor(Color::DARK_GRAY), - ..default() - }) - .with_children(|builder| { - builder.spawn(ImageBundle { - style: Style { - ..default() - }, - image: UiImage::new(mask_image.clone()), - ..default() - }); - }); - - commands.spawn(Camera2dBundle::default()); - - *complete = true; - } - Err(e) => { - println!("inference failed: {}", e); - } + println!("{:?}", flame_output); } } - ``` diff --git a/benches/modnet.rs b/benches/modnet.rs index 853f23b..60e13df 100644 --- a/benches/modnet.rs +++ b/benches/modnet.rs @@ -23,6 +23,7 @@ use bevy_ort::{ modnet_output_to_luma_images, images_to_modnet_input, }, + OrtSession, Session, }; use ort::GraphOptimizationLevel; @@ -86,7 +87,7 @@ fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); let data = vec![0u8; (1920 * 1080 * 4) as usize]; let image: Image = Image::new( @@ -123,7 +124,8 @@ fn modnet_inference_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + .commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap(); + let session: bevy_ort::OrtSession = OrtSession::Session(session); MAX_RESOLUTIONS.iter().for_each(|(width, height)| { let data = vec![0u8; *width as usize * *height as usize * 4]; diff --git a/benches/yolo_v8.rs b/benches/yolo_v8.rs index 6e0439f..dbc29f0 100644 --- a/benches/yolo_v8.rs +++ b/benches/yolo_v8.rs @@ -23,6 +23,7 @@ use bevy_ort::{ process_output, yolo_inference, }, + OrtSession, Session, }; use ort::GraphOptimizationLevel; @@ -80,7 +81,7 @@ fn process_output_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/yolov8n.onnx").unwrap(); + .commit_from_file("assets/yolov8n.onnx").unwrap(); RESOLUTIONS.iter() .for_each(|(width, height)| { @@ -117,7 +118,8 @@ fn inference_benchmark(c: &mut Criterion) { let session = Session::builder().unwrap() .with_optimization_level(GraphOptimizationLevel::Level3).unwrap() - .with_model_from_file("assets/yolov8n.onnx").unwrap(); + .commit_from_file("assets/yolov8n.onnx").unwrap(); + let session = OrtSession::Session(session); RESOLUTIONS.iter().for_each(|(width, height)| { let data = vec![0u8; *width as usize * *height as usize * 4]; diff --git a/src/lib.rs b/src/lib.rs index 62038ed..09935c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,9 +51,57 @@ impl Plugin for BevyOrtPlugin { } -#[derive(Asset, Debug, Default, TypePath)] +pub enum OrtSession { + Session(ort::Session), + InMemory(ort::InMemorySession<'static>), +} + +impl OrtSession { + pub fn run<'s, 'i, 'v: 'i, const N: usize>( + &'s self, + input_values: impl Into>, + ) -> Result { + match self { + OrtSession::Session(session) => session.run(input_values), + OrtSession::InMemory(session) => session.run(input_values), + } + } + + pub fn inputs(&self) -> &Vec { + match self { + OrtSession::Session(session) => &session.inputs, + OrtSession::InMemory(session) => &session.inputs, + } + } + + pub fn outputs(&self) -> &Vec { + match self { + OrtSession::Session(session) => &session.outputs, + OrtSession::InMemory(session) => &session.outputs, + } + } +} + +#[derive(Asset, Default, TypePath)] pub struct Onnx { - pub session: Arc>>, + pub session_data: Vec, + pub session: Arc>>, +} + +impl Onnx { + pub fn from_session(session: Session) -> Self { + Self { + session_data: Vec::new(), + session: Arc::new(Mutex::new(Some(OrtSession::Session(session)))), + } + } + + pub fn from_in_memory(session: ort::InMemorySession<'static>) -> Self { + Self { + session_data: Vec::new(), + session: Arc::new(Mutex::new(Some(OrtSession::InMemory(session)))), + } + } } @@ -88,11 +136,9 @@ impl AssetLoader for OnnxLoader { // TODO: add session configuration let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_model_from_memory(&bytes)?; + .commit_from_memory(&bytes)?; - Ok(Onnx { - session: Arc::new(Mutex::new(Some(session))), - }) + Ok(Onnx::from_session(session)) }, _ => Err(BevyOrtError::Io(std::io::Error::new(ErrorKind::Other, "only .onnx supported"))), } diff --git a/src/models/flame.rs b/src/models/flame.rs new file mode 100644 index 0000000..f366552 --- /dev/null +++ b/src/models/flame.rs @@ -0,0 +1,237 @@ +use bevy::{ + prelude::*, + render::{ + mesh::{ + Indices, + Mesh, + Meshable, + PrimitiveTopology, + }, + render_asset::RenderAssetUsages, + }, +}; +use bytemuck::cast_slice; +use include_bytes_aligned::include_bytes_aligned; +use ndarray::Array2; +use serde::{Deserialize, Serialize}; + +use crate::{ + inputs, + Onnx, + OrtSession, +}; + + +pub static INDEX_BUFFER: &[u8] = include_bytes_aligned!(4, "flame_index_buffer.bin"); + + +pub struct FlamePlugin; +impl Plugin for FlamePlugin { + fn build(&self, app: &mut App) { + app.init_resource::(); + app.add_systems(PreUpdate, flame_inference_system); + } +} + + + +#[derive(Resource, Default)] +pub struct Flame { + pub onnx: Handle, +} + + +fn flame_inference_system( + mut commands: Commands, + flame: Res, + onnx_assets: Res>, + flame_inputs: Query< + ( + Entity, + &FlameInput, + ), + Without, + >, +) { + for (entity, flame_input) in flame_inputs.iter() { + let flame_output: Result = (|| { + let onnx = onnx_assets.get(&flame.onnx).ok_or("failed to get flame ONNX asset")?; + let session_lock = onnx.session.lock().map_err(|e| e.to_string())?; + let session = session_lock.as_ref().ok_or("failed to get flame session from flame ONNX asset")?; + + Ok(flame_inference( + session, + flame_input, + )) + })(); + + match flame_output { + Ok(flame_output) => { + commands.entity(entity) + .insert(flame_output); + } + Err(_e) => { + return; + } + } + } +} + + +const FLAME_BATCH_SIZE: usize = 1; + +#[derive( + Debug, + Clone, + Component, + Reflect, +)] +pub struct FlameInput { + pub shape: [[f32; 100]; FLAME_BATCH_SIZE], + pub pose: [[f32; 6]; FLAME_BATCH_SIZE], + pub expression: [[f32; 50]; FLAME_BATCH_SIZE], + pub neck: [[f32; 3]; FLAME_BATCH_SIZE], + pub eye: [[f32; 6]; FLAME_BATCH_SIZE], +} + +impl Default for FlameInput { + fn default() -> Self { + Self { + shape: [[0.0; 100]; FLAME_BATCH_SIZE], + pose: [[0.0; 6]; FLAME_BATCH_SIZE], + expression: [[0.0; 50]; FLAME_BATCH_SIZE], + neck: [[0.0; 3]; FLAME_BATCH_SIZE], + eye: [[0.0; 6]; FLAME_BATCH_SIZE], + } + } +} + + +#[derive( + Debug, + Clone, + Component, + Deserialize, + Serialize, + Reflect, +)] +pub struct FlameOutput { + pub vertices: Vec<[f32; 3]>, // TODO: use Vec3 for binding + // pub landmarks: Vec<[f32; 3]>, +} + +impl Default for FlameOutput { + fn default() -> Self { + Self { + vertices: vec![[0.0; 3]; 5023], + // landmarks: vec![[0.0; 3]; 68], + } + } +} + +impl Meshable for FlameOutput { + type Output = Mesh; + + fn mesh(&self) -> Self::Output { + let indices = Indices::U32(cast_slice(INDEX_BUFFER).to_vec()); + + Mesh::new( + PrimitiveTopology::TriangleList, + RenderAssetUsages::default(), + ) + .with_inserted_attribute(Mesh::ATTRIBUTE_POSITION, self.vertices.clone()) + .with_inserted_indices(indices) + } +} + + +pub fn flame_inference( + session: &OrtSession, + input: &FlameInput, +) -> FlameOutput { + let PreparedInput { + shape, + expression, + pose, + neck, + eye, + } = prepare_input(input); + + let input_values = inputs![ + "shape" => shape.view(), + "expression" => expression.view(), + "pose" => pose.view(), + "neck" => neck.view(), + "eye" => eye.view(), + ].map_err(|e| e.to_string()).unwrap(); + let outputs = session.run(input_values).map_err(|e| e.to_string()); + let binding = outputs.ok().unwrap(); + + let vertices: &ort::Value = binding.get("vertices").unwrap(); + // let landmarks: &ort::Value = binding.get("landmarks").unwrap(); + + post_process( + vertices, + // landmarks, + ) +} + + +pub struct PreparedInput { + pub shape: Array2, + pub pose: Array2, + pub expression: Array2, + pub neck: Array2, + pub eye: Array2, +} + +pub fn prepare_input( + input: &FlameInput, +) -> PreparedInput { + let shape = Array2::from_shape_vec((FLAME_BATCH_SIZE, 100), input.shape.concat()).unwrap(); + let pose = Array2::from_shape_vec((FLAME_BATCH_SIZE, 6), input.pose.concat()).unwrap(); + let expression = Array2::from_shape_vec((FLAME_BATCH_SIZE, 50), input.expression.concat()).unwrap(); + let neck = Array2::from_shape_vec((FLAME_BATCH_SIZE, 3), input.neck.concat()).unwrap(); + let eye = Array2::from_shape_vec((FLAME_BATCH_SIZE, 6), input.eye.concat()).unwrap(); + + PreparedInput { + shape, + expression, + pose, + neck, + eye, + } +} + + +pub fn post_process( + vertices: &ort::Value, + // landmarks: &ort::Value, +) -> FlameOutput { + let vertices_tensor = vertices.try_extract_tensor::().unwrap(); + let vertices_view = vertices_tensor.view(); // [FLAME_BATCH_SIZE, 5023, 3] + + // let landmarks_tensor = landmarks.try_extract_tensor::().unwrap(); + // let landmarks_view = landmarks_tensor.view(); // [FLAME_BATCH_SIZE, 68, 3] + + let vertices = vertices_view.outer_iter() + .flat_map(|subtensor| { + subtensor.outer_iter().map(|row| { + [row[0], row[1], row[2]] + }).collect::>() + }) + .collect::>(); + + // let landmarks = landmarks_view.outer_iter() + // .flat_map(|subtensor| { + // subtensor.outer_iter().map(|row| { + // [row[0], row[1], row[2]] + // }).collect::>() + // }) + // .collect::>(); + + FlameOutput { + vertices, + // landmarks, + } +} diff --git a/src/models/flame_index_buffer.bin b/src/models/flame_index_buffer.bin new file mode 100644 index 0000000..fef98f5 Binary files /dev/null and b/src/models/flame_index_buffer.bin differ diff --git a/src/models/lightglue.rs b/src/models/lightglue.rs index 911bc12..6180455 100644 --- a/src/models/lightglue.rs +++ b/src/models/lightglue.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::{ inputs, Onnx, + OrtSession, }; @@ -33,7 +34,7 @@ pub struct GluedPair { pub fn lightglue_inference( - session: &ort::Session, + session: &OrtSession, images: &[&Image], ) -> Vec<(usize, usize, Vec)> { let unique_unordered_pairs = images.iter().enumerate() @@ -100,13 +101,13 @@ pub fn post_process( kpts1: &ort::Value, matches: &ort::Value, ) -> Result, &'static str> { - let kpts0_tensor = kpts0.extract_tensor::().unwrap(); + let kpts0_tensor = kpts0.try_extract_tensor::().unwrap(); let kpts0_view = kpts0_tensor.view(); - let kpts1_tensor = kpts1.extract_tensor::().unwrap(); + let kpts1_tensor = kpts1.try_extract_tensor::().unwrap(); let kpts1_view = kpts1_tensor.view(); - let matches = matches.extract_tensor::().unwrap(); + let matches = matches.try_extract_tensor::().unwrap(); let matches_view = matches.view(); Ok( diff --git a/src/models/mod.rs b/src/models/mod.rs index 32e9875..6c07ed6 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "flame")] +pub mod flame; + #[cfg(feature = "lightglue")] pub mod lightglue; diff --git a/src/models/modnet.rs b/src/models/modnet.rs index d26b0ac..db4a0ec 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -16,6 +16,7 @@ use rayon::prelude::*; use crate::{ inputs, Onnx, + OrtSession, }; @@ -34,7 +35,7 @@ pub struct Modnet { pub fn modnet_inference( - session: &ort::Session, + session: &OrtSession, images: &[&Image], max_size: Option<(u32, u32)>, ) -> Vec { @@ -52,7 +53,7 @@ pub fn modnet_inference( pub fn modnet_output_to_luma_images( output_value: &ort::Value, ) -> Vec { - let tensor = output_value.extract_tensor::().unwrap(); + let tensor = output_value.try_extract_tensor::().unwrap(); let data = tensor.view(); let shape = data.shape(); diff --git a/src/models/yolo_v8.rs b/src/models/yolo_v8.rs index eb24863..e9c5164 100644 --- a/src/models/yolo_v8.rs +++ b/src/models/yolo_v8.rs @@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::{ inputs, Onnx, + OrtSession, }; @@ -35,15 +36,15 @@ pub struct Yolo { // TODO: support yolo input batching pub fn yolo_inference( - session: &ort::Session, + session: &OrtSession, image: &Image, iou_threshold: f32, ) -> Vec { let width = image.width(); let height = image.height(); - let model_width = session.inputs[0].input_type.tensor_dimensions().unwrap()[2] as u32; - let model_height = session.inputs[0].input_type.tensor_dimensions().unwrap()[3] as u32; + let model_width = session.inputs()[0].input_type.tensor_dimensions().unwrap()[2] as u32; + let model_height = session.inputs()[0].input_type.tensor_dimensions().unwrap()[3] as u32; let input = prepare_input(image, model_width, model_height); @@ -90,7 +91,7 @@ pub fn process_output( ) -> Vec { let mut boxes = Vec::new(); - let tensor = output.extract_tensor::().unwrap(); + let tensor = output.try_extract_tensor::().unwrap(); let data = tensor.view().t().into_owned(); for detection in data.axis_iter(Axis(0)) { diff --git a/tools/flame.rs b/tools/flame.rs new file mode 100644 index 0000000..815521f --- /dev/null +++ b/tools/flame.rs @@ -0,0 +1,78 @@ +use bevy::prelude::*; +use bevy_panorbit_camera::{ + PanOrbitCamera, + PanOrbitCameraPlugin, +}; + +use bevy_ort::{ + BevyOrtPlugin, + models::flame::{ + FlameInput, + FlameOutput, + Flame, + FlamePlugin, + }, +}; + + +fn main() { + App::new() + .add_plugins(( + DefaultPlugins, + BevyOrtPlugin, + FlamePlugin, + PanOrbitCameraPlugin, + )) + .add_systems(Startup, load_flame) + .add_systems(Startup, setup) + .add_systems(Update, on_flame_output) + .run(); +} + + +fn load_flame( + asset_server: Res, + mut flame: ResMut, +) { + flame.onnx = asset_server.load("models/flame.onnx"); +} + + +fn setup( + mut commands: Commands, +) { + commands.spawn(FlameInput::default()); + commands.spawn(( + Camera3dBundle::default(), + PanOrbitCamera { + allow_upside_down: true, + ..default() + }, + )); +} + + +#[derive(Debug, Component, Reflect)] +struct HandledFlameOutput; + +fn on_flame_output( + mut commands: Commands, + mut meshes: ResMut>, + flame_outputs: Query< + ( + Entity, + &FlameOutput, + ), + Without, + >, +) { + for (entity, flame_output) in flame_outputs.iter() { + commands.entity(entity) + .insert(HandledFlameOutput); + + commands.spawn(PbrBundle { + mesh: meshes.add(flame_output.mesh()), + ..default() + }); + } +}