diff --git a/Cargo.toml b/Cargo.toml index 2ca594e..734b7f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ import = ["bevy_args", "burn-import", "clap", "serde"] bevy_args = { git = "https://github.com/mosure/bevy_args.git", branch = "burn", optional = true } burn-import = { version = "0.15", features = ["pytorch"], optional = true } clap = { version = "4.5", features = ["derive"], optional = true } +ndarray = "0.16" serde = { version = "1.0", optional = true } @@ -110,7 +111,6 @@ bhtsne = "0.5.3" criterion = { version = "0.5", features = ["html_reports"] } futures-intrusive = { version = "0.5.0" } image = { version = "0.25", default-features = false, features = ["png"] } -ndarray = "0.16" pollster = { version = "0.4.0" } safetensors = "0.4" diff --git a/assets/models/dinov2.mpk b/assets/models/dinov2.mpk index eb8da7e..86098c5 100644 Binary files a/assets/models/dinov2.mpk and b/assets/models/dinov2.mpk differ diff --git a/assets/models/face_pca.mpk b/assets/models/face_pca.mpk new file mode 100644 index 0000000..a2b0188 Binary files /dev/null and b/assets/models/face_pca.mpk differ diff --git a/assets/models/pca.mpk b/assets/models/pca.mpk deleted file mode 100644 index 94ecba7..0000000 Binary files a/assets/models/pca.mpk and /dev/null differ diff --git a/crates/bevy_burn_dino/Cargo.toml b/crates/bevy_burn_dino/Cargo.toml index b51423f..5dc3d56 100644 --- a/crates/bevy_burn_dino/Cargo.toml +++ b/crates/bevy_burn_dino/Cargo.toml @@ -94,6 +94,9 @@ features = [ ] +[lib] +path = "src/lib.rs" + [[bin]] name = "bevy_burn_dino" path = "src/main.rs" diff --git a/crates/bevy_burn_dino/src/lib.rs b/crates/bevy_burn_dino/src/lib.rs new file mode 100644 index 0000000..d7b317a --- /dev/null +++ b/crates/bevy_burn_dino/src/lib.rs @@ -0,0 +1,134 @@ +use std::sync::{Arc, Mutex}; + +use burn::prelude::*; +use image::{ + DynamicImage, + ImageBuffer, + Rgb, + RgbImage, + RgbaImage, +}; + +use burn_dino::model::{ + dino::{ + DinoVisionTransformer, + 
DinoVisionTransformerConfig, + }, + pca::PcaTransform, +}; + +pub mod platform; + + +fn normalize( + input: Tensor, + device: &B::Device, +) -> Tensor { + let mean: Tensor = Tensor::from_floats([0.485, 0.456, 0.406], device); + let std: Tensor = Tensor::from_floats([0.229, 0.224, 0.225], device); + + input + .permute([0, 2, 3, 1]) + .sub(mean.unsqueeze()) + .div(std.unsqueeze()) + .permute([0, 3, 1, 2]) +} + +fn preprocess_image( + image: RgbImage, + config: &DinoVisionTransformerConfig, + device: &B::Device, +) -> Tensor { + let img = DynamicImage::ImageRgb8(image) + .resize_exact(config.image_size as u32, config.image_size as u32, image::imageops::FilterType::Triangle) + .to_rgb32f(); + + let samples = img.as_flat_samples(); + let floats: &[f32] = samples.as_slice(); + + let input: Tensor = Tensor::from_floats( + floats, + device, + ); + + let input = input.reshape([1, config.image_size, config.image_size, config.input_channels]) + .permute([0, 3, 1, 2]); + + normalize(input, device) +} + + +async fn to_image( + image: Tensor, + upsample_height: u32, + upsample_width: u32, +) -> RgbaImage { + let height = image.shape().dims[1]; + let width = image.shape().dims[2]; + + let image = image.to_data_async().await.to_vec::().unwrap(); + let image = ImageBuffer::, Vec>::from_raw( + width as u32, + height as u32, + image, + ).unwrap(); + + DynamicImage::ImageRgb32F(image) + .resize_exact(upsample_width, upsample_height, image::imageops::FilterType::Triangle) + .to_rgba8() +} + + +// TODO: benchmark process_frame +pub async fn process_frame( + input: RgbImage, + dino_config: DinoVisionTransformerConfig, + dino_model: Arc>>, + pca_model: Arc>>, + device: B::Device, +) -> Vec { + let input_tensor: Tensor = preprocess_image(input, &dino_config, &device); + + let dino_features = { + let model = dino_model.lock().unwrap(); + model.forward(input_tensor.clone(), None).x_norm_patchtokens + }; + + let batch = dino_features.shape().dims[0]; + let elements = 
dino_features.shape().dims[1]; + let embedding_dim = dino_features.shape().dims[2]; + let n_samples = batch * elements; + let spatial_size = elements.isqrt(); + + let x = dino_features.reshape([n_samples, embedding_dim]); + + let mut pca_features = { + let pca_transform = pca_model.lock().unwrap(); + pca_transform.forward(x.clone()) + }; + + // pca min-max scaling + for i in 0..3 { + let slice = pca_features.clone().slice([0..n_samples, i..i+1]); + let slice_min = slice.clone().min(); + let slice_max = slice.clone().max(); + let scaled = slice + .sub(slice_min.clone().unsqueeze()) + .div((slice_max - slice_min).unsqueeze()); + + pca_features = pca_features.slice_assign( + [0..n_samples, i..i+1], + scaled, + ); + } + + let pca_features = pca_features.reshape([batch, spatial_size, spatial_size, 3]); + + let pca_features = to_image( + pca_features, + dino_config.image_size as u32, + dino_config.image_size as u32, + ).await; + + pca_features.into_raw() +} diff --git a/crates/bevy_burn_dino/src/main.rs b/crates/bevy_burn_dino/src/main.rs index 81f3dc9..c47512c 100644 --- a/crates/bevy_burn_dino/src/main.rs +++ b/crates/bevy_burn_dino/src/main.rs @@ -37,18 +37,12 @@ use bevy_args::{ Deserialize, Parser, Serialize, + ValueEnum, }; use burn::{ prelude::*, backend::wgpu::{init_async, AutoGraphicsApi, Wgpu}, }; -use image::{ - DynamicImage, - ImageBuffer, - Rgb, - RgbImage, - RgbaImage, -}; use burn_dino::model::{ dino::{ @@ -60,8 +54,42 @@ use burn_dino::model::{ }, }; +use bevy_burn_dino::{ + platform::camera::{ + receive_image, + self, + }, + process_frame, +}; -// TODO: support multiple PCA heads with args/inspector switching + +#[derive( + Debug, + Default, + Clone, + PartialEq, + Serialize, + Deserialize, + Reflect, + ValueEnum, +)] +pub enum PcaType { + Adaptive, // TODO: window adaptive pca + #[default] + Face, + Person, +} + +impl PcaType { + #[allow(dead_code)] + const fn pca_weights_mpk(&self) -> &'static str { + match self { + PcaType::Adaptive => 
"adaptive_pca.mpk", + PcaType::Face => "face_pca.mpk", + PcaType::Person => "person_pca.mpk", + } + } +} #[derive( @@ -74,11 +102,8 @@ use burn_dino::model::{ Reflect, )] #[reflect(Resource)] -#[command(about = "burn_dino import", version, long_about = None)] -pub struct DinoImportConfig { - #[arg(long, default_value = "true")] - pub pca_only: bool, - +#[command(about = "bevy_burn_dino", version, long_about = None)] +pub struct BevyBurnDinoConfig { #[arg(long, default_value = "true")] pub press_esc_to_close: bool, @@ -90,16 +115,19 @@ pub struct DinoImportConfig { #[arg(long, default_value = "518")] pub inference_width: usize, + + #[arg(long, value_enum, default_value_t = PcaType::Face)] + pub pca_type: PcaType, } -impl Default for DinoImportConfig { +impl Default for BevyBurnDinoConfig { fn default() -> Self { Self { - pca_only: true, press_esc_to_close: true, show_fps: true, // TODO: display inference fps (UI fps is decoupled via async compute pool) inference_height: 518, inference_width: 518, + pca_type: PcaType::default(), } } } @@ -120,9 +148,11 @@ mod io { PcaTransform, PcaTransformConfig }, }; + use super::PcaType; static DINO_STATE_ENCODED: &[u8] = include_bytes!("../../../assets/models/dinov2.mpk"); - static PCA_STATE_ENCODED: &[u8] = include_bytes!("../../../assets/models/person_pca.mpk"); + static FACE_PCA_STATE_ENCODED: &[u8] = include_bytes!("../../../assets/models/face_pca.mpk"); + static PERSON_PCA_STATE_ENCODED: &[u8] = include_bytes!("../../../assets/models/person_pca.mpk"); pub async fn load_model( config: &DinoVisionTransformerConfig, @@ -138,10 +168,17 @@ mod io { pub async fn load_pca_model( config: &PcaTransformConfig, + pca_type: PcaType, device: &B::Device, ) -> PcaTransform { + let data = match pca_type { + PcaType::Adaptive => unimplemented!(), + PcaType::Face => FACE_PCA_STATE_ENCODED, + PcaType::Person => PERSON_PCA_STATE_ENCODED, + }; + let record = NamedMpkBytesRecorder::::default() - .load(PCA_STATE_ENCODED.to_vec(), 
&Default::default()) + .load(data.to_vec(), &Default::default()) .expect("failed to decode state"); let model= config.init(device); @@ -208,6 +245,7 @@ mod io { pub async fn load_pca_model( config: &PcaTransformConfig, + pca_type: PcaType, device: &B::Device, ) -> PcaTransform { let opts = RequestInit::new(); @@ -215,7 +253,7 @@ opts.set_mode(RequestMode::Cors); let request = Request::new_with_str_and_init( - "./assets/models/person_pca.mpk", + &format!("./assets/models/{}", pca_type.pca_weights_mpk()), &opts, ).unwrap(); @@ -239,236 +277,6 @@ } - -fn normalize( - input: Tensor, - device: &B::Device, -) -> Tensor { - let mean: Tensor = Tensor::from_floats([0.485, 0.456, 0.406], device); - let std: Tensor = Tensor::from_floats([0.229, 0.224, 0.225], device); - - input - .permute([0, 2, 3, 1]) - .sub(mean.unsqueeze()) - .div(std.unsqueeze()) - .permute([0, 3, 1, 2]) -} - -fn preprocess_image( - image: RgbImage, - config: &DinoVisionTransformerConfig, - device: &B::Device, -) -> Tensor { - let img = DynamicImage::ImageRgb8(image) - .resize_exact(config.image_size as u32, config.image_size as u32, image::imageops::FilterType::Triangle) - .to_rgb32f(); - - let samples = img.as_flat_samples(); - let floats: &[f32] = samples.as_slice(); - - let input: Tensor = Tensor::from_floats( - floats, - device, - ); - - let input = input.reshape([1, config.image_size, config.image_size, config.input_channels]) - .permute([0, 3, 1, 2]); - - normalize(input, device) -} - - -async fn to_image( - image: Tensor, - upsample_height: u32, - upsample_width: u32, -) -> RgbaImage { - let height = image.shape().dims[1]; - let width = image.shape().dims[2]; - - let image = image.to_data_async().await.to_vec::().unwrap(); - let image = ImageBuffer::, Vec>::from_raw( - width as u32, - height as u32, - image, - ).unwrap(); - - DynamicImage::ImageRgb32F(image) - .resize_exact(upsample_width, upsample_height, image::imageops::FilterType::Triangle) - .to_rgba8() -} - - -// TODO: benchmark 
process_frame -async fn process_frame( - input: RgbImage, - dino_config: DinoVisionTransformerConfig, - dino_model: Arc>>, - pca_model: Arc>>, - device: B::Device, -) -> Vec { - let input_tensor: Tensor = preprocess_image(input, &dino_config, &device); - - let dino_features = { - let model = dino_model.lock().unwrap(); - model.forward(input_tensor.clone(), None).x_norm_patchtokens - }; - - let batch = dino_features.shape().dims[0]; - let elements = dino_features.shape().dims[1]; - let embedding_dim = dino_features.shape().dims[2]; - let n_samples = batch * elements; - let spatial_size = elements.isqrt(); - - let x = dino_features.reshape([n_samples, embedding_dim]); - - let mut pca_features = { - let pca_transform = pca_model.lock().unwrap(); - pca_transform.forward(x.clone()) - }; - - // pca min-max scaling - for i in 0..3 { - let slice = pca_features.clone().slice([0..n_samples, i..i+1]); - let slice_min = slice.clone().min(); - let slice_max = slice.clone().max(); - let scaled = slice - .sub(slice_min.clone().unsqueeze()) - .div((slice_max - slice_min).unsqueeze()); - - pca_features = pca_features.slice_assign( - [0..n_samples, i..i+1], - scaled, - ); - } - - let pca_features = pca_features.reshape([batch, spatial_size, spatial_size, 3]); - - let pca_features = to_image( - pca_features, - dino_config.image_size as u32, - dino_config.image_size as u32, - ).await; - - pca_features.into_raw() -} - - -#[cfg(feature = "native")] -mod native { - use std::sync::{ - Arc, - Mutex, - mpsc::{ - self, - Sender, - SyncSender, - Receiver, - TryRecvError, - }, - }; - - use image::RgbImage; - use nokhwa::{ - nokhwa_initialize, - query, - CallbackCamera, - pixel_format::RgbFormat, - utils::{ - ApiBackend, - RequestedFormat, - RequestedFormatType, - }, - }; - use once_cell::sync::OnceCell; - - pub static SAMPLE_RECEIVER: OnceCell>>> = OnceCell::new(); - pub static SAMPLE_SENDER: OnceCell> = OnceCell::new(); - - pub static APP_RUN_RECEIVER: OnceCell>>> = OnceCell::new(); - pub 
static APP_RUN_SENDER: OnceCell> = OnceCell::new(); - - pub fn native_camera_thread() { - let ( - sample_sender, - sample_receiver, - ) = mpsc::sync_channel(1); - SAMPLE_RECEIVER.set(Arc::new(Mutex::new(sample_receiver))).unwrap(); - SAMPLE_SENDER.set(sample_sender).unwrap(); - - let ( - app_run_sender, - app_run_receiver, - ) = mpsc::channel(); - APP_RUN_RECEIVER.set(Arc::new(Mutex::new(app_run_receiver))).unwrap(); - APP_RUN_SENDER.set(app_run_sender).unwrap(); - - nokhwa_initialize(|granted| { - if !granted { - panic!("failed to initialize camera"); - } - }); - - let devices = query(ApiBackend::Auto).unwrap(); - let index = devices.first().unwrap().index(); - - let format = RequestedFormat::new::(RequestedFormatType::None); - let mut camera = CallbackCamera::new(index.clone(), format, |buffer| { - let image = buffer.decode_image::().unwrap(); - let sender = SAMPLE_SENDER.get().unwrap(); - sender.send(image).unwrap(); - }).unwrap(); - - camera.open_stream().unwrap(); - - loop { - camera.poll_frame().unwrap(); - - let receiver = APP_RUN_RECEIVER.get().unwrap(); - match receiver.lock().unwrap().try_recv() { - Ok(_) => break, - Err(TryRecvError::Empty) => continue, - Err(TryRecvError::Disconnected) => break, - }; - } - - camera.stop_stream().unwrap(); - } -} - - -#[cfg(feature = "web")] -mod web { - use std::cell::RefCell; - - use image::{ - DynamicImage, - RgbImage, - RgbaImage, - }; - use wasm_bindgen::prelude::*; - - thread_local! 
{ - pub static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); - } - - #[wasm_bindgen] - pub fn frame_input(pixel_data: &[u8], width: u32, height: u32) { - let rgba_image = RgbaImage::from_raw(width, height, pixel_data.to_vec()) - .expect("failed to create RgbImage"); - - // TODO: perform video element -> burn's webgpu texture conversion directly - // TODO: perform this conversion from tensors - let dynamic_image = DynamicImage::ImageRgba8(rgba_image); - let rgb_image: RgbImage = dynamic_image.to_rgb8(); - - SAMPLE_RECEIVER.with(|receiver| { - *receiver.borrow_mut() = Some(rgb_image); - }); - } -} - - #[derive(Resource)] struct DinoModel { config: DinoVisionTransformerConfig, @@ -491,28 +299,6 @@ struct PcaFeatures { #[derive(Component)] struct ProcessImage(Task); -#[cfg(feature = "native")] -fn receive_image() -> Option { - let receiver = native::SAMPLE_RECEIVER.get().unwrap(); - let mut last_image = None; - - { - let receiver = receiver.lock().unwrap(); - while let Ok(image) = receiver.try_recv() { - last_image = Some(image); - } - } - - last_image -} - -#[cfg(feature = "web")] -fn receive_image() -> Option { - web::SAMPLE_RECEIVER.with(|receiver| { - receiver.borrow_mut().take() - }) -} - fn process_frames( mut commands: Commands, dino_model: Res>, @@ -628,9 +414,7 @@ fn setup_ui( commands.spawn(Camera2d); } -pub fn viewer_app() -> App { - let args = parse_args::(); - +pub fn viewer_app(args: BevyBurnDinoConfig) -> App { let mut app = App::new(); app.insert_resource(args.clone()); @@ -770,6 +554,9 @@ fn fps_update_system( async fn run_app() { log("running app..."); + let args = parse_args::(); + log(&format!("{:?}", args)); + let device = Default::default(); init_async::(&device, Default::default()).await; @@ -782,15 +569,20 @@ async fn run_app() { log("dino model loaded"); + // TODO: support adaptive PCA let pca_config = PcaTransformConfig::new( config.embedding_dimension, 3, ); - let pca_transform = io::load_pca_model::(&pca_config, &device).await; + let 
pca_transform = io::load_pca_model::( + &pca_config, + args.pca_type.clone(), + &device, + ).await; log("pca model loaded"); - let mut app = viewer_app(); + let mut app = viewer_app(args); app.init_resource::(); app.insert_resource(DinoModel { @@ -834,7 +626,7 @@ pub fn log(_msg: &str) { fn main() { #[cfg(feature = "native")] { - std::thread::spawn(native::native_camera_thread); + std::thread::spawn(camera::native_camera_thread); futures::executor::block_on(run_app()); } diff --git a/crates/bevy_burn_dino/src/platform.rs b/crates/bevy_burn_dino/src/platform.rs new file mode 100644 index 0000000..7c60245 --- /dev/null +++ b/crates/bevy_burn_dino/src/platform.rs @@ -0,0 +1,133 @@ + +#[cfg(feature = "native")] +pub mod camera { + use std::sync::{ + Arc, + Mutex, + mpsc::{ + self, + Sender, + SyncSender, + Receiver, + TryRecvError, + }, + }; + + use image::RgbImage; + use nokhwa::{ + nokhwa_initialize, + query, + CallbackCamera, + pixel_format::RgbFormat, + utils::{ + ApiBackend, + RequestedFormat, + RequestedFormatType, + }, + }; + use once_cell::sync::OnceCell; + + pub static SAMPLE_RECEIVER: OnceCell>>> = OnceCell::new(); + pub static SAMPLE_SENDER: OnceCell> = OnceCell::new(); + + pub static APP_RUN_RECEIVER: OnceCell>>> = OnceCell::new(); + pub static APP_RUN_SENDER: OnceCell> = OnceCell::new(); + + pub fn native_camera_thread() { + let ( + sample_sender, + sample_receiver, + ) = mpsc::sync_channel(1); + SAMPLE_RECEIVER.set(Arc::new(Mutex::new(sample_receiver))).unwrap(); + SAMPLE_SENDER.set(sample_sender).unwrap(); + + let ( + app_run_sender, + app_run_receiver, + ) = mpsc::channel(); + APP_RUN_RECEIVER.set(Arc::new(Mutex::new(app_run_receiver))).unwrap(); + APP_RUN_SENDER.set(app_run_sender).unwrap(); + + nokhwa_initialize(|granted| { + if !granted { + panic!("failed to initialize camera"); + } + }); + + let devices = query(ApiBackend::Auto).unwrap(); + let index = devices.first().unwrap().index(); + + let format = 
RequestedFormat::new::(RequestedFormatType::None); + let mut camera = CallbackCamera::new(index.clone(), format, |buffer| { + let image = buffer.decode_image::().unwrap(); + let sender = SAMPLE_SENDER.get().unwrap(); + sender.send(image).unwrap(); + }).unwrap(); + + camera.open_stream().unwrap(); + + loop { + camera.poll_frame().unwrap(); + + let receiver = APP_RUN_RECEIVER.get().unwrap(); + match receiver.lock().unwrap().try_recv() { + Ok(_) => break, + Err(TryRecvError::Empty) => continue, + Err(TryRecvError::Disconnected) => break, + }; + } + + camera.stop_stream().unwrap(); + } + + pub fn receive_image() -> Option { + let receiver = SAMPLE_RECEIVER.get().unwrap(); + let mut last_image = None; + + { + let receiver = receiver.lock().unwrap(); + while let Ok(image) = receiver.try_recv() { + last_image = Some(image); + } + } + + last_image + } +} + +#[cfg(feature = "web")] +pub mod camera { + use std::cell::RefCell; + + use image::{ + DynamicImage, + RgbImage, + RgbaImage, + }; + use wasm_bindgen::prelude::*; + + thread_local! 
{ + pub static SAMPLE_RECEIVER: RefCell> = RefCell::new(None); + } + + #[wasm_bindgen] + pub fn frame_input(pixel_data: &[u8], width: u32, height: u32) { + let rgba_image = RgbaImage::from_raw(width, height, pixel_data.to_vec()) + .expect("failed to create RgbImage"); + + // TODO: perform video element -> burn's webgpu texture conversion directly + let dynamic_image = DynamicImage::ImageRgba8(rgba_image); + let rgb_image: RgbImage = dynamic_image.to_rgb8(); + + SAMPLE_RECEIVER.with(|receiver| { + *receiver.borrow_mut() = Some(rgb_image); + }); + } + + pub fn receive_image() -> Option { + SAMPLE_RECEIVER.with(|receiver| { + receiver.borrow_mut().take() + }) + } +} + diff --git a/example/pca.rs b/example/pca.rs index a01a99a..f540ad4 100644 --- a/example/pca.rs +++ b/example/pca.rs @@ -24,7 +24,7 @@ use burn_dino::model::{ static DINO_STATE_ENCODED: &[u8] = include_bytes!("../assets/models/dinov2.mpk"); -static PCA_STATE_ENCODED: &[u8] = include_bytes!("../assets/models/pca.mpk"); +static PCA_STATE_ENCODED: &[u8] = include_bytes!("../assets/models/face_pca.mpk"); static INPUT_IMAGE_0: &[u8] = include_bytes!("../assets/images/dino_0.png"); static INPUT_IMAGE_1: &[u8] = include_bytes!("../assets/images/dino_1.png"); diff --git a/src/model/pca.rs b/src/model/pca.rs index d0456c1..d6b8a76 100644 --- a/src/model/pca.rs +++ b/src/model/pca.rs @@ -9,9 +9,6 @@ use burn::{ pub struct PcaTransformConfig { pub input_dim: usize, pub output_dim: usize, - - #[config(default = "Initializer::Constant{value:1e-5}")] - pub initializer: Initializer, } impl Default for PcaTransformConfig { @@ -27,9 +24,43 @@ impl PcaTransformConfig { } + +// mod linalg { +// use burn::{ +// prelude::*, +// backend::ndarray::{NdArray, NdArrayTensor}, +// tensor::TensorPrimitive, +// }; +// use ndarray::{ArrayBase, Dim, IxDynImpl, OwnedRepr}; + +// pub fn tensor_to_array( +// tensor: Tensor, +// ) -> ArrayBase, Dim> { +// let arr = Tensor::::from_data(tensor.into_data(), &Default::default()); +// let 
primitive: NdArrayTensor = arr.into_primitive().tensor(); +// primitive.array.to_owned() +// } + +// pub fn array_to_tensor( +// array: ArrayBase, Dim>, +// device: &B::Device, +// ) -> Tensor { +// let primitive: NdArrayTensor = NdArrayTensor::new(array.into()); +// let arr = Tensor::::from_primitive(TensorPrimitive::Float(primitive)); +// Tensor::::from_data(arr.into_data(), device) +// } + +// pub fn svd( +// x: ArrayBase, Dim>, +// ) -> ArrayBase, Dim> { +// let (u, s, vt) = x.svd(true, true).unwrap(); +// u.dot(&s.diag()).dot(&vt) +// } +// } + + #[derive(Module, Debug)] pub struct PcaTransform { - // pub auxillary_features: Param>, pub components: Param>, pub mean: Param>, } @@ -39,26 +70,28 @@ impl PcaTransform { device: &B::Device, config: &PcaTransformConfig, ) -> Self { - // let auxillary_features = config.initializer.init([config.batch_size - 1, config.input_dim], device); - let components = config.initializer.init([config.output_dim, config.input_dim], device); - let mean = config.initializer.init([1, config.input_dim], device); + let components = Initializer::Ones.init([config.output_dim, config.input_dim], device); + let mean = Initializer::Zeros.init([1, config.input_dim], device); Self { - // auxillary_features, components, mean, } } - pub fn forward(&self, x: Tensor) -> Tensor { - // let input_batch = Tensor::cat( - // vec![x, self.auxillary_features.val()], - // 0, - // ); - + pub fn forward( + &self, + x: Tensor, + ) -> Tensor { let transformed = x.matmul(self.components.val().transpose()); transformed - self.mean.val().matmul(self.components.val().transpose()) - - // TODO: remove the auxillary features } + + // pub fn rolling_fit( + // &mut self, + // x: Tensor, + // threshold: f32, + // ) { + + // } } diff --git a/tool/benchmark.rs b/tool/benchmark.rs index 613322b..36c9fd1 100644 --- a/tool/benchmark.rs +++ b/tool/benchmark.rs @@ -40,7 +40,7 @@ fn inference_benchmark(c: &mut Criterion) { let model = config.init(&device); let input: Tensor = 
Tensor::zeros([1, config.input_channels, config.image_size, config.image_size], &device); - b.iter(|| model.forward(input.clone(), None)); + b.iter(|| model.forward(input.clone(), None).x_norm_patchtokens.to_data()); }, ); } diff --git a/tool/import.rs b/tool/import.rs index 477fe81..c1b38ba 100644 --- a/tool/import.rs +++ b/tool/import.rs @@ -80,7 +80,7 @@ fn main() { // import safetensors -> mpk for PCA components, check if weights exist - let pca_weights = "./assets/models/dino_pca.pth"; + let pca_weights = "./assets/models/face_pca.pth"; let load_args = LoadArgs::new(pca_weights.into()) .with_debug_print(); diff --git a/tool/pca.py b/tool/pca.py index 3ab8b47..7f3b51d 100644 --- a/tool/pca.py +++ b/tool/pca.py @@ -32,12 +32,14 @@ def forward(self, x): ), ]) -prefix = 'person' +prefix = 'face' images = [ - f'./assets/images/{prefix}_0.png', - f'./assets/images/{prefix}_1.png', - f'./assets/images/{prefix}_2.png', - f'./assets/images/{prefix}_3.png', + f'./assets/images/{prefix}_0.webp', + f'./assets/images/{prefix}_1.webp', + f'./assets/images/{prefix}_2.webp', + f'./assets/images/{prefix}_3.webp', + f'./assets/images/{prefix}_4.webp', + f'./assets/images/{prefix}_5.webp', ] inputs = []