这是indexloc提供的服务,不要输入任何密码
Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "burn_dino"
description = "burn dinov2 model inference and training"
version = "0.3.1"
edition = "2021"
edition = "2024"
authors = ["mosure <mitchell@mosure.me>"]
license = "MIT OR Apache-2.0"
keywords = [
Expand Down Expand Up @@ -44,44 +44,50 @@ import = ["bevy_args", "burn-import", "clap", "serde"]

[dependencies]
# bevy_args = { version = "1.6", optional = true }
bevy_args = { git = "https://github.com/mosure/bevy_args.git", branch = "burn", optional = true }
burn-import = { version = "0.15", features = ["pytorch"], optional = true }
burn-import = { version = "0.16", features = ["pytorch"], optional = true }
clap = { version = "4.5", features = ["derive"], optional = true }
ndarray = "0.16"
# ndarray = "0.16"
serde = { version = "1.0", optional = true }


[dependencies.bevy_args]
git = "https://github.com/mosure/bevy_args.git"
branch = "burn"
optional = true


[dependencies.burn]
version = "0.15"
default-features = false
version = "0.16"
default-features = true
features = [
# "autotune",
# "fusion",
"fusion",
"ndarray",
"std",
"wgpu",
]

[dependencies.burn-wgpu]
version = "0.15"
default-features = false
version = "0.16"
default-features = true
features = [
"fusion",
# "fusion",
"std",
# "template",
]

[dependencies.cubecl]
version = "0.3"
default-features = false
version = "0.4"
default-features = true
features = [
"linalg",
"std",
# "template",
]

[dependencies.cubecl-runtime]
version = "0.3"
default-features = false
version = "0.4"
default-features = true
features = [
"std",
"channel-mpsc",
Expand Down Expand Up @@ -112,7 +118,7 @@ criterion = { version = "0.5", features = ["html_reports"] }
futures-intrusive = { version = "0.5.0" }
image = { version = "0.25", default-features = false, features = ["png"] }
pollster = { version = "0.4.0" }
safetensors = "0.4"
safetensors = "0.5"


[profile.dev.package."*"]
Expand Down
Binary file modified assets/models/dinov2.mpk
Binary file not shown.
11 changes: 6 additions & 5 deletions crates/bevy_burn_dino/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ serde = "1.0"

# TODO: ideally, bevy and burn synchronize wgpu versions upstream
[dependencies.bevy]
# version = "0.14"
# version = "0.16"
git = "https://github.com/mosure/bevy.git"
branch = "burn"
rev = "669d139c13f6b44652f38131a5c9d20ca54e024d"
default-features = false
features = [
"bevy_asset",
Expand All @@ -60,11 +60,12 @@ features = [
]

[dependencies.burn]
version = "0.15"
default-features = false
version = "0.16"
default-features = true
features = [
# "autotune",
# "fusion",
"fusion",
"ndarray",
"std",
# "template",
"wgpu",
Expand Down
14 changes: 8 additions & 6 deletions crates/bevy_burn_dino/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use bevy::{
DiagnosticPath,
Diagnostics,
DiagnosticsStore,
FrameTimeDiagnosticsPlugin,
// FrameTimeDiagnosticsPlugin,
RegisterDiagnostic,
},
ecs::{system::SystemState, world::CommandQueue},
Expand All @@ -25,12 +25,14 @@ use bevy::{
WgpuSettings,
},
RenderPlugin,
}, tasks::{
},
tasks::{
block_on,
futures_lite::future,
AsyncComputeTaskPool,
Task,
},
ui::widget::NodeImageMode,
};
use bevy_args::{
parse_args,
Expand All @@ -41,7 +43,7 @@ use bevy_args::{
};
use burn::{
prelude::*,
backend::wgpu::{init_async, AutoGraphicsApi, Wgpu},
backend::wgpu::{init_setup_async, AutoGraphicsApi, Wgpu},
};

use burn_dino::model::{
Expand Down Expand Up @@ -401,7 +403,7 @@ fn setup_ui(
..default()
})
.with_children(|builder| {
builder.spawn(UiImage {
builder.spawn(ImageNode {
image: pca_image.image.clone(),
image_mode: NodeImageMode::Stretch,
..default()
Expand Down Expand Up @@ -480,7 +482,7 @@ pub fn viewer_app(args: BevyBurnDinoConfig) -> App {
}

if args.show_fps {
app.add_plugins(FrameTimeDiagnosticsPlugin);
// app.add_plugins(FrameTimeDiagnosticsPlugin::default());
app.register_diagnostic(Diagnostic::new(INFERENCE_FPS));
app.add_systems(Startup, fps_display_setup);
app.add_systems(Update, fps_update_system);
Expand Down Expand Up @@ -555,7 +557,7 @@ async fn run_app() {
log(&format!("{:?}", args));

let device = Default::default();
init_async::<AutoGraphicsApi>(&device, Default::default()).await;
init_setup_async::<AutoGraphicsApi>(&device, Default::default()).await;

log("device created");

Expand Down
92 changes: 54 additions & 38 deletions example/tsne.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use image::{
RgbImage,
};
use bhtsne::tSNE;
use ndarray::Array2;

use burn_dino::model::dino::{
DinoVisionTransformer,
Expand All @@ -34,7 +33,7 @@ pub fn load_model<B: Backend>(
.load(STATE_ENCODED.to_vec(), &Default::default())
.expect("failed to decode state");

let model= config.init(device);
let model = config.init(device);
model.load_record(record)
}

Expand All @@ -53,14 +52,19 @@ fn normalize<B: Backend>(
.permute([0, 3, 1, 2])
}


pub fn load_image<B: Backend>(
bytes: &[u8],
config: &DinoVisionTransformerConfig,
device: &B::Device,
) -> Tensor<B, 4> {
let img = load_from_memory_with_format(bytes, ImageFormat::Png)
.unwrap()
.resize_exact(config.image_size as u32, config.image_size as u32, image::imageops::FilterType::Lanczos3);
.resize_exact(
config.image_size as u32,
config.image_size as u32,
image::imageops::FilterType::Lanczos3,
);

let img = match img {
DynamicImage::ImageRgb8(img) => img,
Expand All @@ -72,12 +76,13 @@ pub fn load_image<B: Backend>(
.flat_map(|p| p.0.iter().map(|&c| c as f32 / 255.0))
.collect();

let input: Tensor<B, 1> = Tensor::from_floats(
img_data.as_slice(),
device,
);

let input = input.reshape([1, config.input_channels, config.image_size, config.image_size]);
let input: Tensor<B, 1> = Tensor::from_floats(img_data.as_slice(), device);
let input = input.reshape([
1,
config.input_channels,
config.image_size,
config.image_size,
]);

normalize(input, device)
}
Expand All @@ -90,7 +95,12 @@ fn main() {
};
let dino = load_model(&config, &device);

let input_pngs = vec![INPUT_IMAGE_0, INPUT_IMAGE_1, INPUT_IMAGE_2, INPUT_IMAGE_3];
let input_pngs = vec![
INPUT_IMAGE_0,
INPUT_IMAGE_1,
INPUT_IMAGE_2,
INPUT_IMAGE_3,
];

let mut input_tensors = Vec::new();
for input in input_pngs {
Expand All @@ -105,57 +115,63 @@ fn main() {
let elements = output.shape().dims[1];
let features = output.shape().dims[2];
let n_samples = batch * elements;

let spatial_size = elements.isqrt();

let x = output.reshape([n_samples, features]);
let binding = x.to_data()
.to_vec::<f32>()
.unwrap();
let data: Vec<&[f32]> = binding
.chunks(config.embedding_dimension)
.collect();
let data: Vec<&[f32]> = binding.chunks(config.embedding_dimension).collect();

let tsne_features = tSNE::new(&data)
let mut tsne_features = tSNE::new(&data)
.embedding_dim(3)
.perplexity(10.0)
.epochs(1000)
.barnes_hut(0.5, |sample_a, sample_b| {
sample_a.iter()
sample_a
.iter()
.zip(sample_b.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt()
})
.embedding();
let mut tsne_features = Array2::from_shape_vec((n_samples, 3), tsne_features).unwrap();

for mut col in tsne_features.columns_mut() {
let min = col.fold(f32::INFINITY, |a, &b| a.min(b));
let max = col.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let range = max - min;
col.mapv_inplace(|x| (x - min) / range);
let num_dims = 3;
for d in 0..num_dims {
let mut min_val = f32::INFINITY;
let mut max_val = f32::NEG_INFINITY;
for i in 0..n_samples {
let idx = i * num_dims + d;
let value = tsne_features[idx];
if value < min_val {
min_val = value;
}
if value > max_val {
max_val = value;
}
}
let range = if max_val - min_val == 0.0 { 1.0 } else { max_val - min_val };
for i in 0..n_samples {
let idx = i * num_dims + d;
tsne_features[idx] = (tsne_features[idx] - min_val) / range;
}
}

let tsne_features = tsne_features.to_shape([batch, spatial_size, spatial_size, 3]).unwrap();

for (i, img) in tsne_features.outer_iter().enumerate() {
let collected: Vec<u8> = img.iter()
.map(|&x| (x * 255.0)
.max(0.0)
.min(255.0) as u8
).collect();
let img = RgbImage::from_raw(
spatial_size as u32,
spatial_size as u32,
collected,
)
.unwrap();
for b in 0..batch {
let start = b * spatial_size * spatial_size * num_dims;
let end = start + spatial_size * spatial_size * num_dims;
let mut collected: Vec<u8> = Vec::with_capacity((spatial_size * spatial_size * 3) as usize);
for &value in &tsne_features[start..end] {
let pixel = (value * 255.0).max(0.0).min(255.0) as u8;
collected.push(pixel);
}

let output_directory = std::path::Path::new("output/tsne");
std::fs::create_dir_all(output_directory).unwrap();

let output_path = output_directory.join(format!("{}.png", i));
let output_path = output_directory.join(format!("{}.png", b));
let img = RgbImage::from_raw(spatial_size as u32, spatial_size as u32, collected)
.unwrap();
img.save(output_path).unwrap();
}
}