+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions src/optim/adam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use std::{boxed::Box, marker::PhantomData};
/// lr: 1e-2,
/// betas: [0.5, 0.25],
/// eps: 1e-6,
/// weight_decay: Some(WeightDecay::Decoupled(1e-2)),
/// });
/// ```
///
Expand All @@ -51,6 +52,7 @@ pub struct Adam<M> {
/// lr: 1e-2,
/// betas: [0.1, 0.2],
/// eps: 1e-6,
/// weight_decay: Some(WeightDecay::L2(1e-1)),
/// };
/// ```
#[derive(Debug, Clone, Copy)]
Expand All @@ -63,6 +65,9 @@ pub struct AdamConfig {

/// Epsilon for numerical stability. Defaults to `1e-8`.
pub eps: f32,

/// Optional weight decay. Defaults to `None`.
pub weight_decay: Option<WeightDecay>,
}

impl Default for AdamConfig {
Expand All @@ -71,6 +76,7 @@ impl Default for AdamConfig {
lr: 1e-3,
betas: [0.9, 0.999],
eps: 1e-8,
weight_decay: None,
}
}
}
Expand Down Expand Up @@ -104,13 +110,23 @@ impl<M> GradientProvider for Adam<M> {
let mut g_t = self.gradients.remove(p)?;
let m_t = self.moment1.mut_gradient(p);
let v_t = self.moment2.mut_gradient(p);
if let Some(WeightDecay::L2(wd)) = self.cfg.weight_decay {
P::Device::foreach_mr(g_t.as_mut(), p.data(), &mut |g, p_el| {
*g += wd * p_el;
});
}
P::Device::foreach_mmm(g_t.as_mut(), m_t, v_t, &mut |g, m, v| {
*m = *m * self.cfg.betas[0] + *g * (1.0 - self.cfg.betas[0]);
*v = *v * self.cfg.betas[1] + g.powi(2) * (1.0 - self.cfg.betas[1]);
let m_hat = *m * (1.0 - self.cfg.betas[0].powi(self.t)).recip();
let v_hat = *v * (1.0 - self.cfg.betas[1].powi(self.t)).recip();
*g = self.cfg.lr * m_hat / (v_hat.sqrt() + self.cfg.eps)
});
if let Some(WeightDecay::Decoupled(wd)) = self.cfg.weight_decay {
P::Device::foreach_mr(g_t.as_mut(), p.data(), &mut |g, p_el| {
*g += wd * self.cfg.lr * p_el;
});
}
Some(g_t)
}
}
Expand Down Expand Up @@ -162,6 +178,7 @@ mod tests {
lr: 1e-3,
betas: [0.5, 0.25],
eps: 1e-8,
weight_decay: None,
});
let mut t: Tensor1D<5> = Tensor1D::ones();
let rate = Tensor1D::new([1e-4, 1e-3, 1e-2, 1e-1, 1e-0]);
Expand Down Expand Up @@ -199,6 +216,7 @@ mod tests {
lr: 1e-3,
betas: [0.9, 0.999],
eps: 1e-8,
weight_decay: None,
});

let py = model.forward(x.trace());
Expand All @@ -225,4 +243,62 @@ mod tests {
let g = backward(y.mean());
opt.update(&mut model, g).expect_err("");
}

#[test]
// Verifies Adam with classic L2 weight decay: the decay term (wd * param) is
// folded into the raw gradient BEFORE the first/second moment updates, so it
// is rescaled by the adaptive denominator like any other gradient component.
fn test_adam_l2_decay() {
let mut opt: Adam<Tensor1D<5>> = Adam::new(AdamConfig {
betas: [0.5, 0.25],
weight_decay: Some(WeightDecay::L2(1.0)),
// lr/eps fall back to AdamConfig::default() (lr = 1e-3, eps = 1e-8)
..Default::default()
});
let mut t: Tensor1D<5> = tensor([-0.5, -0.25, 0.1, 0.6, 1.0]);
// Expected parameter values after each of 10 update steps.
// NOTE(review): these look like golden values captured from a reference
// implementation (presumably PyTorch's Adam with weight_decay) — confirm.
#[rustfmt::skip]
let expected = [
[-0.499, -0.249, 0.099, 0.59900004, 0.999],
[-0.49799952, -0.24797276, 0.09799955, 0.5979998, 0.9979998],
[-0.49699846, -0.24689871, 0.09699859, 0.5969993, 0.99699926],
[-0.49599692,-0.24575013,0.095997185,0.5959985,0.99599856],
[-0.49499503,-0.24448763,0.094995454,0.5949976,0.9949977],
[-0.4939929, -0.24382699, 0.09399351, 0.59399647, 0.9939967],
[-0.49299058, -0.24413459, 0.09299142, 0.5929953, 0.9929956],
[-0.49198818, -0.24478404, 0.09198925, 0.59199405, 0.9919945],
[-0.49098572, -0.24561276, 0.09098703, 0.5909928, 0.9909934],
[-0.48998323, -0.24548599, 0.08998477, 0.58999157, 0.9899922],
];

// Loss is mean(exp(t)^2); after each step the updated parameters must match
// the corresponding golden row exactly (bitwise f32 equality).
for e in expected.iter() {
let gradients = t.trace().exp().square().mean().backward();
opt.update(&mut t, gradients).expect("");
assert_eq!(t.data(), e);
}
}

#[test]
// Verifies Adam with decoupled weight decay (AdamW-style): the decay term
// (wd * lr * param) is applied to the update AFTER the moment/bias-correction
// math, so it is NOT rescaled by the adaptive denominator.
fn test_adam_decoupled_decay() {
let mut opt: Adam<Tensor1D<5>> = Adam::new(AdamConfig {
betas: [0.5, 0.25],
weight_decay: Some(WeightDecay::Decoupled(1.0)),
// lr/eps fall back to AdamConfig::default() (lr = 1e-3, eps = 1e-8)
..Default::default()
});
let mut t: Tensor1D<5> = tensor([-0.5, -0.25, 0.1, 0.6, 1.0]);
// Expected parameter values after each of 10 update steps.
// NOTE(review): these look like golden values captured from a reference
// implementation (presumably PyTorch's AdamW) — confirm.
#[rustfmt::skip]
let expected = [
[-0.5005, -0.25075, 0.098900005, 0.5984, 0.998],
[-0.5009996, -0.25149944, 0.09780081, 0.59680116, 0.9960015],
[-0.50149894, -0.25224838, 0.09670238, 0.59520346, 0.9940043],
[-0.5019978, -0.25299674, 0.09560476, 0.59360695, 0.9920086],
[-0.50249636, -0.2537445, 0.09450804, 0.5920117, 0.99001455],
[-0.5029944, -0.25449163, 0.09341227, 0.59041786, 0.98802227],
[-0.50349206, -0.25523806, 0.092317514, 0.58882546, 0.9860318],
[-0.5039892, -0.25598377, 0.0912238, 0.5872346, 0.9840432],
[-0.5044859, -0.25672877, 0.09013115, 0.5856453, 0.98205656],
[-0.50498205, -0.25747302, 0.08903958, 0.58405757, 0.9800719],
];

// Loss is mean(exp(t)^2); after each step the updated parameters must match
// the corresponding golden row exactly (bitwise f32 equality).
for e in expected.iter() {
let gradients = t.trace().exp().square().mean().backward();
opt.update(&mut t, gradients).expect("");
assert_eq!(t.data(), e);
}
}
}
2 changes: 2 additions & 0 deletions src/optim/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ mod adam;
mod optimizer;
mod rmsprop;
mod sgd;
mod weight_decay;

pub use adam::*;
pub use optimizer::*;
pub use rmsprop::*;
pub use sgd::*;
pub use weight_decay::*;
11 changes: 0 additions & 11 deletions src/optim/sgd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,6 @@ pub enum Momentum {
Nesterov(f32),
}

/// WeightDecay used for [Sgd]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightDecay {
/// Weight decay applied to the gradients before any momentum updates. Equivalent to L2 regularization.
L2(f32),

/// Weight decay applied after any momentum updates, without modifying the gradients.
/// See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
Decoupled(f32),
}

impl<M> Default for Sgd<M> {
/// See [SgdConfig]
fn default() -> Self {
Expand Down
10 changes: 10 additions & 0 deletions src/optim/weight_decay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/// Weight-decay regularization strategy, shared by the optimizers in this
/// module (moved out of `sgd.rs` so both [Sgd] and [Adam] can use it).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightDecay {
/// Weight decay applied to the gradients before any momentum updates. Equivalent to L2 regularization.
L2(f32),

/// Weight decay applied after any momentum updates, without modifying the gradients.
/// See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
Decoupled(f32),
}