
Optimizing softmax & log_softmax #660


Merged: merged 1 commit on Apr 4, 2023
58 changes: 53 additions & 5 deletions src/tensor_ops/log_softmax.rs
@@ -1,4 +1,4 @@
use super::{BroadcastTo, Device, LogSumExpTo, TrySub};
use super::*;
use crate::{shapes::*, tensor::*};

/// `log(softmax(t))` in a numerically stable way across `Ax`. Does `t - logsumexp(t)` under the hood.
@@ -42,15 +42,63 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
where
S: ReduceShape<Ax>,
{
let logsumexp = self.retaped::<T>().try_logsumexp::<_, Ax>()?;
let logsumexp = logsumexp.try_broadcast_like(self.shape())?;
self.try_sub(logsumexp)
/*
# Notes on this reduction

log_softmax is equivalent to:
`t - t.logsumexp()`

logsumexp can be inlined to:
`t - ((t - t.max()).exp().sum().ln() + t.max())`

we can apply the subtraction in the following way:
`t - (t - t.max()).exp().sum().ln() - t.max()`
`t - t.max() - (t - t.max()).exp().sum().ln()`

Notice there is a repeated expression here of `t - t.max()`.
So we can re-use this calculation. Let's denote it tm:
`tm - tm.exp().sum().ln()`
*/
let shape = *self.shape();
let (t, tape) = self.split_tape();
let max = t.clone().try_max::<_, Ax>()?;
let tm = {
// Do this calculation off of the tape
let keep_id = t.id;
let mut t = t.try_sub(max.try_broadcast_like::<_, Ax>(&shape)?)?;
t.id = keep_id;
t.put_tape(tape)
};
let logsumexp = tm.retaped::<T>().try_exp()?.try_sum::<_, Ax>()?.try_ln()?;
tm.try_sub(logsumexp.try_broadcast_like(&shape)?)
}
}

#[cfg(test)]
mod tests {
use crate::{shapes::Axis, tensor::*, tensor_ops::*, tests::*};
use crate::{shapes::*, tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_log_softmax_equivalence() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank4<8, 16, 32, 64>, TestDtype, _> = dev.sample_normal();
let p = t.leaky_trace().log_softmax::<Axis<3>>();
let p_truth = t.leaky_trace() - t.leaky_trace().logsumexp::<_, Axis<3>>().broadcast();
// we can't create an array as it will overflow the stack
for (p_i, pt_i) in p.as_vec().iter().zip(p_truth.as_vec().iter()) {
assert!((p_i - pt_i).abs() <= TestDtype::DEFAULT_TOLERANCE);
}
let g = p.square().mean().backward();
let g_truth = p_truth.square().mean().backward();
for (g_i, gt_i) in g
.get(&t)
.as_vec()
.iter()
.zip(g_truth.get(&t).as_vec().iter())
{
assert!((g_i - gt_i).abs() <= TestDtype::DEFAULT_TOLERANCE);
}
}

#[test]
fn test_log_softmax_1d() {
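For reference, below is a minimal standalone sketch (plain `f64` slices, standard library only; not code from this PR or from dfdx, and the helper names are just for illustration) that checks the reduction described in the notes above: the reduced form `tm - tm.exp().sum().ln()` with `tm = t - t.max()` matches the naive `t - logsumexp(t)`.

```rust
// Standalone sketch, not part of the PR: verify the log_softmax reduction
// on a plain slice using only the standard library.

fn naive_log_softmax(t: &[f64]) -> Vec<f64> {
    // log_softmax(t) = t - logsumexp(t), with logsumexp computed stably as
    // ln(sum(exp(t - max))) + max.
    let max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let logsumexp = t.iter().map(|x| (x - max).exp()).sum::<f64>().ln() + max;
    t.iter().map(|x| x - logsumexp).collect()
}

fn reduced_log_softmax(t: &[f64]) -> Vec<f64> {
    // Reduced form from the notes: tm - tm.exp().sum().ln(), where tm = t - max.
    let max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let tm: Vec<f64> = t.iter().map(|x| x - max).collect();
    let lse = tm.iter().map(|x| x.exp()).sum::<f64>().ln();
    tm.iter().map(|x| x - lse).collect()
}

fn main() {
    let t = [2.0, -1.0, 0.5, 3.0];
    let naive = naive_log_softmax(&t);
    let reduced = reduced_log_softmax(&t);
    for (a, b) in naive.iter().zip(reduced.iter()) {
        assert!((a - b).abs() < 1e-12);
    }
    println!("log_softmax reduction matches the naive form");
}
```

As the notes in the diff explain, the gain in the PR itself is that the single `t - t.max()` intermediate is re-used for both the sum-exp and the final subtraction, instead of computing a full `logsumexp` and then subtracting it afterwards.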
69 changes: 67 additions & 2 deletions src/tensor_ops/softmax.rs
@@ -1,4 +1,4 @@
use super::Device;
use super::*;
use crate::{shapes::*, tensor::*};

/// Computes the [softmax function](https://en.wikipedia.org/wiki/Softmax_function) across
@@ -37,14 +37,79 @@ impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Tensor<S, E, D, T> {
where
S: ReduceShape<Ax>,
{
self.try_log_softmax::<Ax>()?.try_exp()
/*
# Notes on this reduction

Softmax is equivalent to:
`t.log_softmax().exp()`

which when given the log_softmax reductions is equivalent to:
`(t - t.logsumexp()).exp()`

logsumexp can be inlined to:
`(t - ((t - t.max()).exp().sum().ln() + t.max())).exp()`

we can apply the subtraction in the following way:
`(t - (t - t.max()).exp().sum().ln() - t.max()).exp()`
`(t - t.max() - (t - t.max()).exp().sum().ln()).exp()`

Notice there is a repeated expression here of `t - t.max()`.
So we can re-use this calculation. Let's denote this expression tm:
`(tm - tm.exp().sum().ln()).exp()`

Another reduction applies the identity `e^(x - y) = e^x / e^y`:
`tm.exp() / tm.exp().sum().ln().exp()`

First we can re-use the `tm.exp()` calculation - let's call it tme:
`tme / tme.sum().ln().exp()`

And finally we know that `t.ln().exp()` is equivalent to `t`, i.e. the `ln`
and `exp` cancel each other out:
`tme / tme.sum()`
*/
let shape = *self.shape();
let (t, tape) = self.split_tape();
let max = t.clone().try_max::<_, Ax>()?;
let t = {
// in place subtraction of max since we don't want to record this
// on the auto diff graph.
let keep_id = t.id;
let mut t = t.try_sub(max.try_broadcast_like::<_, Ax>(&shape)?)?;
t.id = keep_id;
t
};
let t_exp = t.put_tape(tape).try_exp()?;
let t_expsum = t_exp.retaped::<T>().try_sum::<_, Ax>()?;
t_exp.try_div(t_expsum.try_broadcast_like(&shape)?)
}
}

#[cfg(test)]
mod tests {
use crate::{shapes::*, tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_softmax_equivalence() {
let dev: TestDevice = Default::default();
let t: Tensor<Rank4<8, 16, 32, 64>, TestDtype, _> = dev.sample_normal();
let p = t.leaky_trace().softmax::<Axis<3>>();
let p_truth = t.leaky_trace().log_softmax::<Axis<3>>().exp();
// we can't create an array as it will overflow the stack
for (p_i, pt_i) in p.as_vec().iter().zip(p_truth.as_vec().iter()) {
assert!((p_i - pt_i).abs() <= TestDtype::DEFAULT_TOLERANCE);
}
let g = p.square().mean().backward();
let g_truth = p_truth.square().mean().backward();
for (g_i, gt_i) in g
.get(&t)
.as_vec()
.iter()
.zip(g_truth.get(&t).as_vec().iter())
{
assert!((g_i - gt_i).abs() <= TestDtype::DEFAULT_TOLERANCE);
}
}

#[test]
fn test_softmax_1d() {
let dev: TestDevice = Default::default();
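Similarly, a minimal standalone sketch (again plain `f64` slices, not dfdx code, helper names illustrative only) that checks the softmax reduction `tme / tme.sum()` with `tme = (t - t.max()).exp()` against the naive form, plus the shift invariance `softmax(t) == softmax(t - c)`:

```rust
// Standalone sketch, not part of the PR: verify the softmax reduction and the
// shift invariance of softmax.

fn naive_softmax(t: &[f64]) -> Vec<f64> {
    // softmax(t) = exp(t - logsumexp(t))
    let max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let logsumexp = t.iter().map(|x| (x - max).exp()).sum::<f64>().ln() + max;
    t.iter().map(|x| (x - logsumexp).exp()).collect()
}

fn reduced_softmax(t: &[f64]) -> Vec<f64> {
    // Reduced form from the notes: tme / tme.sum(), where tme = (t - max).exp().
    let max = t.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let tme: Vec<f64> = t.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = tme.iter().sum();
    tme.iter().map(|x| x / sum).collect()
}

fn main() {
    let t = [2.0, -1.0, 0.5, 3.0];
    let shifted: Vec<f64> = t.iter().map(|x| x - 3.0).collect();
    let naive = naive_softmax(&t);
    let reduced = reduced_softmax(&t);
    let reduced_shifted = reduced_softmax(&shifted);
    for i in 0..t.len() {
        // the reduced form matches the naive softmax
        assert!((naive[i] - reduced[i]).abs() < 1e-12);
        // subtracting a constant from every input leaves softmax unchanged
        assert!((reduced[i] - reduced_shifted[i]).abs() < 1e-12);
    }
    println!("softmax reduction matches and is shift-invariant");
}
```

Because both softmax and log_softmax are invariant to subtracting a constant from the input, the gradient contribution through the max term cancels out, which is presumably what makes it sound for the PR to keep that subtraction off the autodiff tape (the `keep_id` blocks in both files).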