-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Labels
Description
burn_dino/crates/bevy_burn_dino/src/main.rs
Line 410 in 0a59eb7
// TODO: move to config |
}
}
last_image
}
/// Takes the pending frame (if any) from `web::SAMPLE_RECEIVER`.
///
/// The slot is emptied via `take`, so a given frame is handed out at most
/// once; `None` means no new frame has arrived since the last call.
/// NOTE(review): `SAMPLE_RECEIVER` is presumably a `thread_local!` holding a
/// `RefCell<Option<RgbImage>>` (inferred from `.with` + `borrow_mut`) — confirm in `web`.
#[cfg(feature = "web")]
fn receive_image() -> Option<RgbImage> {
    web::SAMPLE_RECEIVER.with(|slot| slot.borrow_mut().take())
}
/// Starts an asynchronous DINO + PCA inference on the next available frame.
///
/// Does nothing if `inference_max_in_flight` tasks are already running
/// (counted via live `ProcessImage` components) or if `receive_image`
/// has no frame to offer.
///
/// Otherwise the work runs on Bevy's `AsyncComputeTaskPool`. The task
/// resolves to a `CommandQueue` that:
/// 1. writes the processed pixel data into the image referenced by
///    `PcaFeatures::image`, and
/// 2. despawns the helper entity that tracked the task.
///
/// `handle_tasks` is responsible for polling the task and applying that
/// queue to the world.
fn process_frames(
    mut commands: Commands,
    dino_model: Res<DinoModel<Wgpu>>,
    pca_transform: Res<PcaTransformModel<Wgpu>>,
    pca_features_handle: Res<PcaFeatures>,
    active_tasks: Query<&ProcessImage>,
) {
    // TODO: move to config
    // TODO: fix multiple in flight deadlock
    let inference_max_in_flight = 1;
    if active_tasks.iter().count() >= inference_max_in_flight {
        return;
    }

    let Some(image) = receive_image() else {
        return;
    };

    let thread_pool = AsyncComputeTaskPool::get();
    // Helper entity whose only purpose is to own the `ProcessImage` task.
    let entity = commands.spawn_empty().id();

    // Clone everything the task needs up front so the future can be `'static`.
    let device = dino_model.device.clone();
    let dino_config = dino_model.config.clone();
    let dino_model = dino_model.model.clone();
    let image_handle = pca_features_handle.image.clone();
    let pca_model = pca_transform.model.clone();

    let task = thread_pool.spawn(async move {
        let img_data = process_frame(
            image,
            dino_config,
            dino_model,
            pca_model,
            device,
        ).await;

        let mut command_queue = CommandQueue::default();
        command_queue.push(move |world: &mut World| {
            // Scope the `SystemState` borrow of `world` so it ends before the
            // despawn below needs `world` again.
            {
                let mut system_state = SystemState::<
                    ResMut<Assets<Image>>,
                >::new(world);
                let mut images = system_state.get_mut(world);
                // TODO: share wgpu io between bevy/burn
                let existing_image = images.get_mut(&image_handle).unwrap();
                existing_image.data = img_data;
            }

            // Despawn (rather than just removing `ProcessImage`) so that the
            // helper entity does not outlive its task; otherwise an empty
            // entity would leak for every processed frame.
            world.despawn(entity);
        });
        command_queue
    });

    commands.entity(entity).insert(ProcessImage(task));
}
/// Polls each in-flight `ProcessImage` task exactly once.
///
/// When a task has finished, the `CommandQueue` it produced is moved into
/// this system's `Commands`, so its world mutations are applied when this
/// system's commands are applied. Unfinished tasks are left untouched and
/// will be polled again on the next run.
fn handle_tasks(mut commands: Commands, mut transform_tasks: Query<&mut ProcessImage>) {
    for mut pending in transform_tasks.iter_mut() {
        let Some(mut finished_queue) = block_on(future::poll_once(&mut pending.0)) else {
            continue;
        };
        commands.append(&mut finished_queue);
    }
}