这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ www/assets/

mediamtx/
onnxruntime/

*.onnx
12 changes: 10 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "bevy_ort"
description = "bevy ort (onnxruntime) plugin"
version = "0.9.0"
version = "0.10.0"
edition = "2021"
authors = ["mosure <mitchell@mosure.me>"]
license = "MIT"
Expand Down Expand Up @@ -29,11 +29,13 @@ default-run = "modnet"

[features]
default = [
"flame",
"lightglue",
"modnet",
"yolo_v8",
]

flame = []
lightglue = []
modnet = ["rayon"]
yolo_v8 = []
Expand All @@ -60,11 +62,12 @@ features = [
"bevy_winit",
"multi-threaded",
"png",
"tonemapping_luts",
]


[dependencies.ort]
version = "2.0.0-rc.0"
version = "2.0.0-rc.2"
default-features = false
features = [
"cuda",
Expand Down Expand Up @@ -93,6 +96,11 @@ opt-level = 3
path = "src/lib.rs"


[[bin]]
name = "flame"
path = "tools/flame.rs"
required-features = ["flame"]

[[bin]]
name = "lightglue"
path = "tools/lightglue.rs"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ a bevy plugin for the [ort](https://docs.rs/ort/latest/ort/) library
- [X] lightglue (feature matching)
- [X] modnet (photographic portrait matting)
- [X] yolo_v8 (object detection)
- [X] flame (parametric head model)


## library usage
Expand Down
4 changes: 2 additions & 2 deletions benches/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn modnet_output_to_luma_images_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();
.commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();

let data = vec![0u8; (1920 * 1080 * 4) as usize];
let image: Image = Image::new(
Expand Down Expand Up @@ -123,7 +123,7 @@ fn modnet_inference_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();
.commit_from_file("assets/modnet_photographic_portrait_matting.onnx").unwrap();

MAX_RESOLUTIONS.iter().for_each(|(width, height)| {
let data = vec![0u8; *width as usize * *height as usize * 4];
Expand Down
4 changes: 2 additions & 2 deletions benches/yolo_v8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn process_output_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/yolov8n.onnx").unwrap();
.commit_from_file("assets/yolov8n.onnx").unwrap();

RESOLUTIONS.iter()
.for_each(|(width, height)| {
Expand Down Expand Up @@ -117,7 +117,7 @@ fn inference_benchmark(c: &mut Criterion) {

let session = Session::builder().unwrap()
.with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
.with_model_from_file("assets/yolov8n.onnx").unwrap();
.commit_from_file("assets/yolov8n.onnx").unwrap();

RESOLUTIONS.iter().for_each(|(width, height)| {
let data = vec![0u8; *width as usize * *height as usize * 4];
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl AssetLoader for OnnxLoader {
// TODO: add session configuration
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_model_from_memory(&bytes)?;
.commit_from_memory(&bytes)?;

Ok(Onnx {
session: Arc::new(Mutex::new(Some(session))),
Expand Down
163 changes: 163 additions & 0 deletions src/models/flame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use bevy::prelude::*;
use ndarray::Array2;
use serde::{Deserialize, Serialize};

use crate::{
inputs,
Onnx,
};



/// Bevy plugin for the FLAME model integration.
///
/// Adding it to an [`App`] makes sure a [`Flame`] resource exists.
pub struct FlamePlugin;

impl Plugin for FlamePlugin {
    /// Initializes the [`Flame`] resource on `app`.
    fn build(&self, app: &mut App) {
        app.init_resource::<Flame>();
    }
}

/// Resource holding the handle to the ONNX asset used for FLAME inference.
///
/// Starts out with a default handle; the application is expected to assign
/// the handle of the loaded model to `onnx`.
#[derive(Default, Resource)]
pub struct Flame {
    /// Handle to the ONNX model asset.
    pub onnx: Handle<Onnx>,
}


/// Batched inputs for the FLAME model.
///
/// Every field holds exactly 8 rows (one per batch entry); the row width is
/// fixed per field by the array types below.
#[derive(Debug, Clone)]
pub struct FlameInput {
    /// 100 values per batch entry (presumably shape coefficients — confirm against the model).
    pub shape: [[f32; 100]; 8],
    /// 6 values per batch entry.
    pub pose: [[f32; 6]; 8],
    /// 50 values per batch entry (presumably expression coefficients — confirm against the model).
    pub expression: [[f32; 50]; 8],
    /// 3 values per batch entry.
    pub neck: [[f32; 3]; 8],
    /// 6 values per batch entry.
    pub eye: [[f32; 6]; 8],
}

impl Default for FlameInput {
    /// Builds an all-zero input, except for the second pose component of each
    /// batch entry, which is set to a fixed angle (given in degrees below and
    /// converted to radians).
    fn default() -> Self {
        // conversion factor: degrees -> radians
        let deg_to_rad = std::f32::consts::PI / 180.0;

        // angle (in degrees) placed in pose[i][1] for each of the 8 batch entries
        let angles_deg: [f32; 8] = [30.0, -30.0, 85.0, -48.0, 10.0, -15.0, 0.0, 0.0];

        let pose = angles_deg.map(|deg| [0.0, deg * deg_to_rad, 0.0, 0.0, 0.0, 0.0]);

        Self {
            shape: [[0.0; 100]; 8],
            pose,
            expression: [[0.0; 50]; 8],
            neck: [[0.0; 3]; 8],
            eye: [[0.0; 6]; 8],
        }
    }
}


/// Flattened results of a FLAME inference run.
///
/// Both lists contain one `[x, y, z]` triple per point, with all batch
/// entries concatenated in batch order.
#[derive(Debug, Default, Clone, Deserialize, Serialize)]
pub struct FlameOutput {
    /// Points read from the model's `vertices` output.
    // TODO: use Vec3 for binding
    pub vertices: Vec<[f32; 3]>,
    /// Points read from the model's `landmarks` output.
    pub landmarks: Vec<[f32; 3]>,
}


/// Runs the FLAME model held by `session` on `input`.
///
/// The input arrays are bound to the model inputs named `shape`,
/// `expression`, `pose`, `neck` and `eye`; the outputs named `vertices` and
/// `landmarks` are read back and flattened via [`post_process`].
///
/// # Panics
///
/// Panics with a descriptive message when the inputs cannot be built, when
/// the session fails to run, or when either expected output is missing.
/// [`post_process`] may additionally panic if an output is not an `f32`
/// tensor of the expected layout.
pub fn flame_inference(
    session: &ort::Session,
    input: &FlameInput,
) -> FlameOutput {
    let PreparedInput {
        shape,
        expression,
        pose,
        neck,
        eye,
    } = prepare_input(input);

    let input_values = inputs![
        "shape" => shape.view(),
        "expression" => expression.view(),
        "pose" => pose.view(),
        "neck" => neck.view(),
        "eye" => eye.view(),
    ].unwrap_or_else(|e| panic!("failed to build flame model inputs: {e}"));

    // keep the underlying ort error in the panic message instead of
    // collapsing it into an anonymous `None`
    let outputs = session.run(input_values)
        .unwrap_or_else(|e| panic!("flame inference failed: {e}"));

    let vertices: &ort::Value = outputs.get("vertices")
        .expect("flame model did not produce a `vertices` output");
    let landmarks: &ort::Value = outputs.get("landmarks")
        .expect("flame model did not produce a `landmarks` output");

    post_process(
        vertices,
        landmarks,
    )
}


/// Model inputs converted to 2-D `ndarray` arrays, ready to be bound to the
/// session.
///
/// As produced by [`prepare_input`], each array has 8 rows and the row width
/// of the corresponding [`FlameInput`] field.
pub struct PreparedInput {
    /// 8 x 100 array built from `FlameInput::shape`.
    pub shape: Array2<f32>,
    /// 8 x 6 array built from `FlameInput::pose`.
    pub pose: Array2<f32>,
    /// 8 x 50 array built from `FlameInput::expression`.
    pub expression: Array2<f32>,
    /// 8 x 3 array built from `FlameInput::neck`.
    pub neck: Array2<f32>,
    /// 8 x 6 array built from `FlameInput::eye`.
    pub eye: Array2<f32>,
}

pub fn prepare_input(
input: &FlameInput,
) -> PreparedInput {
let shape = Array2::from_shape_vec((8, 100), input.shape.concat()).unwrap();
let pose = Array2::from_shape_vec((8, 6), input.pose.concat()).unwrap();
let expression = Array2::from_shape_vec((8, 50), input.expression.concat()).unwrap();
let neck = Array2::from_shape_vec((8, 3), input.neck.concat()).unwrap();
let eye = Array2::from_shape_vec((8, 6), input.eye.concat()).unwrap();

PreparedInput {
shape,
expression,
pose,
neck,
eye,
}
}


/// Converts the raw `vertices` and `landmarks` model outputs into flat lists
/// of `[x, y, z]` triples.
///
/// Both values are read as `f32` tensors and walked batch by batch, row by
/// row; the first three components of each row form one point.
///
/// # Panics
///
/// Panics if either value cannot be extracted as an `f32` tensor, or if a row
/// cannot be indexed at positions 0..=2.
pub fn post_process(
    vertices: &ort::Value,
    landmarks: &ort::Value,
) -> FlameOutput {
    let vertex_tensor = vertices.try_extract_tensor::<f32>().unwrap();
    let vertex_view = vertex_tensor.view(); // [8, 5023, 3]

    let landmark_tensor = landmarks.try_extract_tensor::<f32>().unwrap();
    let landmark_view = landmark_tensor.view(); // [8, 68, 3]

    let mut vertex_points: Vec<[f32; 3]> = Vec::new();
    for batch in vertex_view.outer_iter() {
        for row in batch.outer_iter() {
            vertex_points.push([row[0], row[1], row[2]]);
        }
    }

    let mut landmark_points: Vec<[f32; 3]> = Vec::new();
    for batch in landmark_view.outer_iter() {
        for row in batch.outer_iter() {
            landmark_points.push([row[0], row[1], row[2]]);
        }
    }

    FlameOutput {
        vertices: vertex_points,
        landmarks: landmark_points,
    }
}
6 changes: 3 additions & 3 deletions src/models/lightglue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ pub fn post_process(
kpts1: &ort::Value,
matches: &ort::Value,
) -> Result<Vec<GluedPair>, &'static str> {
let kpts0_tensor = kpts0.extract_tensor::<i64>().unwrap();
let kpts0_tensor = kpts0.try_extract_tensor::<i64>().unwrap();
let kpts0_view = kpts0_tensor.view();

let kpts1_tensor = kpts1.extract_tensor::<i64>().unwrap();
let kpts1_tensor = kpts1.try_extract_tensor::<i64>().unwrap();
let kpts1_view = kpts1_tensor.view();

let matches = matches.extract_tensor::<i64>().unwrap();
let matches = matches.try_extract_tensor::<i64>().unwrap();
let matches_view = matches.view();

Ok(
Expand Down
3 changes: 3 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(feature = "flame")]
pub mod flame;

#[cfg(feature = "lightglue")]
pub mod lightglue;

Expand Down
2 changes: 1 addition & 1 deletion src/models/modnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub fn modnet_inference(
pub fn modnet_output_to_luma_images(
output_value: &ort::Value,
) -> Vec<Image> {
let tensor = output_value.extract_tensor::<f32>().unwrap();
let tensor = output_value.try_extract_tensor::<f32>().unwrap();
let data = tensor.view();

let shape = data.shape();
Expand Down
2 changes: 1 addition & 1 deletion src/models/yolo_v8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub fn process_output(
) -> Vec<BoundingBox> {
let mut boxes = Vec::new();

let tensor = output.extract_tensor::<f32>().unwrap();
let tensor = output.try_extract_tensor::<f32>().unwrap();
let data = tensor.view().t().into_owned();

for detection in data.axis_iter(Axis(0)) {
Expand Down
70 changes: 70 additions & 0 deletions tools/flame.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use bevy::prelude::*;

use bevy_ort::{
BevyOrtPlugin,
models::flame::{
FlameInput,
FlameOutput,
flame_inference,
Flame,
FlamePlugin,
},
Onnx,
};


/// Entry point: builds the Bevy app with the ort and FLAME plugins, loads the
/// model at startup and runs the inference system every update.
fn main() {
    let mut app = App::new();

    app.add_plugins((
        DefaultPlugins,
        BevyOrtPlugin,
        FlamePlugin,
    ));
    app.add_systems(Startup, load_flame);
    app.add_systems(Update, inference);

    app.run();
}


/// Startup system: requests the FLAME ONNX asset and stores its handle in the
/// [`Flame`] resource.
fn load_flame(
    asset_server: Res<AssetServer>,
    mut flame: ResMut<Flame>,
) {
    flame.onnx = asset_server.load("models/flame.onnx");
}


/// Update system: runs FLAME inference once, as soon as the ONNX asset is
/// available.
///
/// `complete` is a per-system flag; once inference has succeeded the system
/// becomes a no-op. While the asset is not yet present in `onnx_assets` the
/// system silently waits for the next frame instead of reporting an error
/// every frame. Genuine failures (poisoned lock, missing session) are still
/// printed to stderr and retried on the next update.
fn inference(
    mut commands: Commands,
    flame: Res<Flame>,
    onnx_assets: Res<Assets<Onnx>>,
    mut complete: Local<bool>,
) {
    if *complete {
        return;
    }

    // the asset loads asynchronously; absence here just means "not ready yet"
    let Some(onnx) = onnx_assets.get(&flame.onnx) else {
        return;
    };

    let flame_output: Result<FlameOutput, String> = (|| {
        let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
        let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

        Ok(flame_inference(
            session,
            &FlameInput::default(),
        ))
    })();

    match flame_output {
        Ok(_flame_output) => {
            // TODO: insert mesh
            // TODO: insert pan orbit camera
            commands.spawn(Camera3dBundle::default());
            *complete = true;
        }
        Err(e) => {
            eprintln!("inference failed: {}", e);
        }
    }
}