+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cargo-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
args: --features test-f64
args: --features test-f64,safetensors
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ cblas-sys = { version = "0.1.4", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
cudarc = { version = "0.8.0", default-features = false, optional = true }
num-traits = { version = "0.2.15", default-features = false }
safetensors = { version = "0.3", default-features = false, optional = true }
memmap2 = { version = "0.5", default-features = false, optional = true }

[dev-dependencies]
tempfile = "3.3.0"
Expand All @@ -50,6 +52,7 @@ threaded-cpu = ["std", "matrixmultiply/threading"]
fast-alloc = ["std"]
nightly = []
numpy = ["dep:zip", "std"]
safetensors = ["dep:safetensors", "std", "dep:memmap2"]
cblas = ["dep:cblas-sys", "dep:libc"]
intel-mkl = ["cblas"]
cuda = ["dep:cudarc", "dep:glob"]
Expand All @@ -71,4 +74,4 @@ harness = false

[[bench]]
name = "softmax"
harness = false
harness = false
44 changes: 44 additions & 0 deletions examples/safetensors-save-load.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//! Demonstrates how to save and load arrays with safetensors

/// Demonstrates a save/load round trip through a `.safetensors` file, then loads a
/// tensor from a pre-existing (Hugging Face) safetensors checkpoint.
///
/// The round trip writes `model.safetensors` into the current directory, reads it back into a
/// freshly built model, and checks that the first layer's weights match.
///
/// # Panics
///
/// This function panics in any of these cases:
/// - saving or loading `model.safetensors` fails;
/// - `gpt2.safetensors` is missing, cannot be memory-mapped, or is not a valid safetensors file;
/// - `wte.weight` cannot be loaded into the layer.
#[cfg(feature = "safetensors")]
fn main() {
    use ::safetensors::SafeTensors;
    use dfdx::{
        prelude::*,
        tensor::{AsArray, Cpu},
    };
    use memmap2::MmapOptions;
    let dev: Cpu = Default::default();

    // Basic usage: round-trip a small model through a `.safetensors` file.
    type Model = (Linear<5, 10>, Linear<10, 5>);
    let model = dev.build_module::<Model, f32>();
    model
        .save_safetensors("model.safetensors")
        .expect("Failed to save model");

    let mut model2 = dev.build_module::<Model, f32>();
    model2
        .load_safetensors("model.safetensors")
        .expect("Failed to load model");

    assert_eq!(model.0.weight.array(), model2.0.weight.array());

    // ADVANCED USAGE to load pre-existing models

    // wget -O gpt2.safetensors https://huggingface.co/gpt2/resolve/main/model.safetensors

    // GPT-2 (small) embeds its 50257-token vocabulary in 768 dimensions, so `wte.weight` is
    // 50257 x 768. dfdx lays out a `Linear<I, O>` weight as O x I, so `Linear<768, 50257>`
    // is the layer whose weight has exactly that shape.
    let mut gpt2 = dev.build_module::<Linear<768, 50257>, f32>();
    let filename = "gpt2.safetensors";
    let f = std::fs::File::open(filename).expect("Couldn't read file, have you downloaded gpt2 ? `wget -O gpt2.safetensors https://huggingface.co/gpt2/resolve/main/model.safetensors`");
    // SAFETY: a memory map is only sound while no other process modifies or truncates the
    // underlying file. This example assumes nothing else writes `gpt2.safetensors` while
    // it is mapped.
    let buffer = unsafe { MmapOptions::new().map(&f).expect("Could not mmap") };
    let tensors = SafeTensors::deserialize(&buffer).expect("Couldn't read safetensors file");

    // Copy the token-embedding matrix into the layer's weight.
    gpt2.weight
        .load_safetensor(&tensors, "wte.weight")
        .expect("Could not load tensor");
}

/// Fallback entry point, compiled only when the `safetensors` feature is disabled.
///
/// # Panics
///
/// Always panics with a hint telling the user which feature to enable.
#[cfg(not(feature = "safetensors"))]
fn main() {
    // Kept as a named constant so the user-facing hint is easy to find and reuse.
    const MISSING_FEATURE_HINT: &str = "Use the 'safetensors' feature to run this example";
    panic!("{}", MISSING_FEATURE_HINT);
}
21 changes: 21 additions & 0 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@
//! state_dict = {k: torch.from_numpy(v) for k, v in np.load("dfdx-model.npz").items()}
//! mlp.load_state_dict(state_dict)
//! ```
//!
//! The `safetensors` feature provides the same functionality using the
//! [safetensors](https://github.com/huggingface/safetensors) format.
//! Call [SaveToSafetensors::save_safetensors()] and [LoadFromSafetensors::load_safetensors()];
//! all modules provided here implement these traits, including tuples. They save to and load
//! from `.safetensors` files, which use a flat layout with a JSON header, allowing for very
//! fast loads (with memory mapping).
//!
//! The format is also fairly portable. For example, you can use
//! [transformers](https://github.com/huggingface/transformers) to export an existing model:
//!
//! ```python
//! from transformers import pipeline
//!
//! pipe = pipeline(model="gpt2")
//! pipe.save_pretrained("my_local", safe_serialization=True)
//! # This creates `my_local/model.safetensors`, which can now be loaded with dfdx.
//! ```

mod num_params;
mod reset_params;
Expand All @@ -128,13 +145,17 @@ mod pool2d;
mod pool_global;
mod repeated;
mod residual;
#[cfg(feature = "safetensors")]
mod safetensors;
mod split_into;
mod transformer;
mod unbiased_linear;
mod zero_grads;

pub use module::*;

#[cfg(feature = "safetensors")]
pub use crate::nn::safetensors::{LoadFromSafetensors, SaveToSafetensors};
pub use ema::ModelEMA;
#[cfg(feature = "numpy")]
pub use npz::{LoadFromNpz, SaveToNpz};
Expand Down
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载