-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Labels
Description
In `bevy_ort/src/models/yolo_v8.rs`, line 36 (commit 904f83c):

```rust
// TODO: support yolo input batching
```
}
pub struct YoloPlugin;
impl Plugin for YoloPlugin {
fn build(&self, app: &mut App) {
app.init_resource::<Yolo>();
}
}
/// Bevy resource holding the handle to the YOLO ONNX model asset.
#[derive(Default, Resource)]
pub struct Yolo {
    /// Handle to the `Onnx` model asset. The derived `Default` leaves this as
    /// `Handle::default()`.
    // NOTE(review): presumably replaced with a loaded asset handle elsewhere — confirm.
    pub onnx: Handle<Onnx>,
}
// TODO: support yolo input batching
pub fn yolo_inference(
session: &ort::Session,
image: &Image,
iou_threshold: f32,
) -> Vec<BoundingBox> {
let width = image.width();
let height = image.height();
let model_width = session.inputs[0].input_type.tensor_dimensions().unwrap()[2] as u32;
let model_height = session.inputs[0].input_type.tensor_dimensions().unwrap()[3] as u32;
let input = prepare_input(image, model_width, model_height);
let input_values = inputs!["images" => &input.as_standard_layout()].map_err(|e| e.to_string()).unwrap();
let outputs = session.run(input_values).map_err(|e| e.to_string());
let binding = outputs.ok().unwrap();
let output_value: &ort::Value = binding.get("output0").unwrap();
let detections = process_output(output_value, width, height, model_width, model_height);
nms(&detections, iou_threshold)
}
pub fn prepare_input(
image: &Image,
model_width: u32,