diff --git a/README.md b/README.md index 9c0a646..04efc2f 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,20 @@ burn [depth pro](https://github.com/apple/ml-depth-pro) model inference ## usage ```rust -use burn_depth::model::depth_pro::DepthPro; +use burn::prelude::*; +use burn_depth::{InferenceBackend, model::depth_pro::DepthPro}; -let model = DepthPro::::load("assets/model/depth_pro.mpk")?; -let depth = model.forward(input); +// NdArray backend (alternatively: burn::backend::Cuda, burn::backend::Cpu) +let device = <InferenceBackend as Backend>::Device::default(); + +let model = DepthPro::<InferenceBackend>::load(&device, "assets/model/depth_pro.mpk")?; + +// Image tensor with shape [1, 3, H, W] (batch, channels, height, width) +let input: Tensor<InferenceBackend, 4> = Tensor::zeros([1, 3, 512, 512], &device); + +let result = model.infer(input, None); +// result.depth: Tensor with shape [1, H, W] +// result.focallength_px: Tensor with shape [1] ```