src/nn/dropout.rs: 2 additions & 2 deletions

@@ -44,7 +44,7 @@ use super::*;
 /// let grads = dropout.alloc_grads();
 /// let x: Tensor<Rank2<2, 5>, f32, _> = dev.ones();
 /// let r = dropout.forward_mut(x.trace(grads));
-/// assert_eq!(r.array(), [[2.0, 2.0, 2.0, 0.0, 0.0], [2.0, 2.0, 0.0, 0.0, 2.0]]);
+/// assert_eq!(r.array(), [[2.0, 0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0, 2.0]]);
 /// ```
 #[derive(Clone, Debug, Default)]
 pub struct DropoutOneIn<const N: usize>;
@@ -115,7 +115,7 @@ impl<const N: usize, S: Shape, E: Dtype, D: Device<E>> ModuleMut<Tensor<S, E, D,
 /// let grads = dropout.alloc_grads();
 /// let x: Tensor<Rank2<2, 5>, f32, _> = dev.ones();
 /// let r = dropout.forward_mut(x.trace(grads));
-/// assert_eq!(r.array(), [[2.0, 2.0, 2.0, 0.0, 0.0], [2.0, 2.0, 0.0, 0.0, 2.0]]);
+/// assert_eq!(r.array(), [[2.0, 0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0, 2.0]]);
 /// ```
 #[derive(Clone, Debug)]
 pub struct Dropout {
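Note on the regenerated doctest arrays above: drawing a `Bernoulli(p)` boolean per element consumes the seeded `StdRng` stream differently than sampling a uniform float and comparing it to `p`, so the same seed produces a different drop pattern. A minimal standalone sketch of the two schemes (not crate code; assumes the `rand`/`rand_distr` versions this PR already imports):

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::{Bernoulli, Distribution, Standard};

fn main() {
    let p = 0.5f64;

    // Old scheme: draw a uniform float per element and compare it to p.
    let mut rng = StdRng::seed_from_u64(0);
    let old_mask: Vec<bool> = (0..5).map(|_| rng.sample::<f64, _>(Standard) < p).collect();

    // New scheme: draw a Bernoulli(p) boolean per element.
    let mut rng = StdRng::seed_from_u64(0);
    let bern = Bernoulli::new(p).unwrap();
    let new_mask: Vec<bool> = (0..5).map(|_| bern.sample(&mut rng)).collect();

    // The masks will generally disagree even with identical seeds, which is
    // why every hard-coded expected array in these doctests had to change.
    println!("old: {old_mask:?}");
    println!("new: {new_mask:?}");
}
```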
src/tensor_ops/dropout/cpu_kernel.rs: 11 additions & 14 deletions

@@ -4,19 +4,17 @@ use crate::{
 };
 
 use num_traits::Float;
-use rand::{rngs::StdRng, Rng, SeedableRng};
-use rand_distr::{Distribution, Standard};
+use rand::{rngs::StdRng, SeedableRng};
+use rand_distr::{Bernoulli, Distribution};
 
-impl<E: Float + Dtype> super::DropoutKernel<E> for Cpu
-where
-    Standard: Distribution<E>,
-{
+impl<E: Float + Dtype> super::DropoutKernel<E> for Cpu {
     fn forward<S: Shape>(
         &self,
-        op: super::DropoutKernelOp<E>,
+        op: super::DropoutKernelOp,
         inp: &Tensor<S, E, Self>,
     ) -> Result<Tensor<S, E, Self>, Self::Err> {
         let mut rng = StdRng::seed_from_u64(op.seed);
+        let dist = Bernoulli::new(op.prob).unwrap();
         let mut out = Tensor {
             id: unique_id(),
             data: inp.data.clone(),
@@ -26,32 +24,31 @@ where
             tape: Default::default(),
         };
         for x in out.buf_iter_mut() {
-            let val: E = rng.sample(Standard);
-            *x = if val < op.prob {
+            *x = if dist.sample(&mut rng) {
                 E::zero()
             } else {
-                *x / (E::one() - op.prob)
+                *x / E::from_f64(1.0 - op.prob).unwrap()
            };
        }
        Ok(out)
    }
 
     fn backward<S: Shape>(
         &self,
-        op: super::DropoutKernelOp<E>,
+        op: super::DropoutKernelOp,
         inp: &Tensor<S, E, Self>,
         grad_inp: &mut Self::Vec<E>,
         grad_out: &Self::Vec<E>,
     ) -> Result<(), Self::Err> {
         let mut rng = StdRng::seed_from_u64(op.seed);
+        let dist = Bernoulli::new(op.prob).unwrap();
         debug_assert_eq!(grad_inp.len(), grad_out.len());
         debug_assert_eq!(inp.data.len(), grad_out.len());
         for (i, data_i) in grad_inp.iter_mut().enumerate() {
-            let val: E = rng.sample(Standard);
-            *data_i += if val < op.prob {
+            *data_i += if dist.sample(&mut rng) {
                 E::zero()
             } else {
-                (E::one() - op.prob).recip()
+                E::from_f64((1.0 - op.prob).recip()).unwrap()
             } * grad_out[i];
         }
         Ok(())
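For reference, a self-contained sketch of the inverted-dropout recipe this kernel's forward loop implements (the helper name is ours, not dfdx's): drop each element with probability `p`, scale survivors by `1/(1 - p)` so the expected activation is unchanged. Note `Bernoulli::new(1.0)` is valid and always samples `true`, so the divide-by-zero branch is never reached at `p = 1.0`.

```rust
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Bernoulli, Distribution};

/// Inverted dropout over a slice: zero each element with probability `prob`
/// and scale the survivors by 1/(1 - prob), mirroring the kernel's forward loop.
fn dropout_inplace(data: &mut [f32], prob: f64, seed: u64) {
    let mut rng = StdRng::seed_from_u64(seed);
    let dist = Bernoulli::new(prob).unwrap();
    let scale = (1.0 - prob).recip() as f32;
    for x in data.iter_mut() {
        *x = if dist.sample(&mut rng) { 0.0 } else { *x * scale };
    }
}

fn main() {
    let mut data = vec![1.0f32; 100_000];
    dropout_inplace(&mut data, 0.25, 42);
    // The 1/(1 - p) rescaling keeps the mean near the pre-dropout value of 1.0.
    let mean = data.iter().sum::<f32>() / data.len() as f32;
    assert!((mean - 1.0).abs() < 0.05);
}
```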
src/tensor_ops/dropout/cuda_kernel.rs: 19 additions & 15 deletions

@@ -7,8 +7,8 @@ use std::vec::Vec;
 
 use cudarc::driver::{DeviceSlice, LaunchAsync};
 
-use rand::{rngs::StdRng, Rng, SeedableRng};
-use rand_distr::{Distribution, Standard};
+use rand::{rngs::StdRng, SeedableRng};
+use rand_distr::{Bernoulli, Distribution};
 
 const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/dropout.ptx"));
 
@@ -36,20 +36,22 @@ impl HasCudaKernel<f64> for Cuda {
 impl<E: Dtype> super::DropoutKernel<E> for Cuda
 where
     Self: HasCudaKernel<E>,
-    Standard: Distribution<E>,
 {
     fn forward<S: Shape>(
         &self,
-        op: super::DropoutKernelOp<E>,
+        op: super::DropoutKernelOp,
         inp: &Tensor<S, E, Self>,
     ) -> Result<Tensor<S, E, Self>, Self::Err> {
-        let noise = {
+        let mask = {
             let mut rng = StdRng::seed_from_u64(op.seed);
-            let mut noise: Vec<E> = Vec::with_capacity(inp.data.len());
-            noise.resize_with(inp.data.len(), || rng.sample(Standard));
-            self.dev.htod_copy(noise)
+            let dist = Bernoulli::new(op.prob).unwrap();
+            let mut mask: Vec<bool> = Vec::with_capacity(inp.data.len());
+            mask.resize_with(inp.data.len(), || dist.sample(&mut rng));
+            self.dev.htod_copy(mask)
         }?;
 
+        let prob = E::from_f64(op.prob).unwrap();
+
         if !self.dev.has_func(Self::MOD, Self::FNS[0]) {
             self.dev.load_ptx(PTX_SRC.into(), Self::MOD, Self::FNS)?;
         }
@@ -59,27 +61,29 @@ where
 
         let fwd_fn = self.dev.get_func(Self::MOD, Self::FNS[0]).unwrap();
         let cfg = launch_cfg::<128>(numel as u32);
-        let params = (op.prob, numel, inp.data.as_ref(), &noise, &mut storage);
+        let params = (prob, numel, inp.data.as_ref(), &mask, &mut storage);
         unsafe { fwd_fn.launch(cfg, params) }?;
         Ok(self.build_tensor(inp.shape, inp.strides, storage))
     }
     fn backward<S: Shape>(
         &self,
-        op: super::DropoutKernelOp<E>,
+        op: super::DropoutKernelOp,
         inp: &Tensor<S, E, Self>,
         grad_inp: &mut Self::Vec<E>,
         grad_out: &Self::Vec<E>,
     ) -> Result<(), Self::Err> {
-        let noise = {
+        let mask = {
             let mut rng = StdRng::seed_from_u64(op.seed);
-            let mut noise: Vec<E> = Vec::with_capacity(inp.data.len());
-            noise.resize_with(inp.data.len(), || rng.sample(Standard));
-            self.dev.htod_copy(noise)
+            let dist = Bernoulli::new(op.prob).unwrap();
+            let mut mask: Vec<bool> = Vec::with_capacity(inp.data.len());
+            mask.resize_with(inp.data.len(), || dist.sample(&mut rng));
+            self.dev.htod_copy(mask)
         }?;
+        let prob = E::from_f64(op.prob).unwrap();
         let bwd_fn = self.dev.get_func(Self::MOD, Self::FNS[1]).unwrap();
         let numel = inp.data.len();
         let cfg = launch_cfg::<128>(numel as u32);
-        let params = (op.prob, numel, &noise, grad_inp, grad_out);
+        let params = (prob, numel, &mask, grad_inp, grad_out);
         unsafe { bwd_fn.launch(cfg, params) }?;
         Ok(())
     }
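Both paths rebuild the mask from `op.seed` rather than caching it between the forward and backward passes; a seeded `StdRng` makes that reconstruction exact. A small sketch of the property (helper name ours, not the crate's):

```rust
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Bernoulli, Distribution};

// Rebuild a dropout mask from its (seed, prob) pair, as both the forward
// and backward CUDA paths do before copying it host-to-device.
fn regen_mask(seed: u64, prob: f64, n: usize) -> Vec<bool> {
    let mut rng = StdRng::seed_from_u64(seed);
    let dist = Bernoulli::new(prob).unwrap();
    (0..n).map(|_| dist.sample(&mut rng)).collect()
}

fn main() {
    // Same seed, same mask: recomputing is as good as caching, at the cost
    // of sampling (and uploading) the mask twice.
    assert_eq!(regen_mask(7, 0.3, 16), regen_mask(7, 0.3, 16));
}
```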
src/tensor_ops/dropout/dropout.cu: 4 additions & 4 deletions

@@ -5,7 +5,7 @@ extern "C" __global__ void FWD( \
     const TYPENAME prob, \
     const size_t numel, \
     const TYPENAME *inp, \
-    const TYPENAME *noise, \
+    const bool *mask, \
     TYPENAME *out \
 ) { \
     unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; \
@@ -14,13 +14,13 @@ extern "C" __global__ void FWD( \
     } \
     TYPENAME zero = 0.0; \
     TYPENAME one = 1.0; \
-    TYPENAME scalar = (noise[i] < prob) ? zero : (one / (one - prob)); \
+    TYPENAME scalar = mask[i] ? zero : (one / (one - prob)); \
     out[i] = inp[i] * scalar; \
 } \
 extern "C" __global__ void BWD( \
     const TYPENAME prob, \
     const size_t numel, \
-    const TYPENAME *noise, \
+    const bool *mask, \
     TYPENAME *grad_inp, \
     const TYPENAME *grad_out \
 ) { \
@@ -30,7 +30,7 @@ extern "C" __global__ void BWD( \
     } \
     TYPENAME zero = 0.0; \
     TYPENAME one = 1.0; \
-    grad_inp[i] += (noise[i] < prob) ? zero : (grad_out[i] / (one - prob)); \
+    grad_inp[i] += mask[i] ? zero : (grad_out[i] / (one - prob)); \
 }
 
 DROPOUT(__half, dropout_fwd_f16, dropout_bwd_f16);
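A host-side rendition of the kernel's per-element rule, sketched in Rust rather than device code: the `bool` mask entry now selects between dropping and rescaling, replacing the old `noise[i] < prob` comparison, and it also shrinks the uploaded buffer to one byte per element versus four or eight for a float noise vector.

```rust
// Per-element scalar the FWD kernel computes; BWD multiplies grad_out by
// the same factor, so d(out[i])/d(inp[i]) == fwd_scalar(mask[i], prob).
fn fwd_scalar(mask_i: bool, prob: f32) -> f32 {
    if mask_i { 0.0 } else { 1.0 / (1.0 - prob) }
}

fn main() {
    let (inp, mask, prob) = ([1.0f32, 2.0, 3.0], [false, true, false], 0.5);
    let out: Vec<f32> = inp
        .iter()
        .zip(mask)
        .map(|(x, m)| x * fwd_scalar(m, prob))
        .collect();
    assert_eq!(out, vec![2.0, 0.0, 6.0]);
}
```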
src/tensor_ops/dropout/mod.rs: 25 additions & 21 deletions

@@ -10,20 +10,20 @@ use crate::{
 
 #[repr(C)]
 #[derive(Debug, Clone, Copy)]
-pub struct DropoutKernelOp<F> {
+pub struct DropoutKernelOp {
     pub seed: u64,
-    pub prob: F,
+    pub prob: f64,
 }
 
 pub trait DropoutKernel<E: Dtype>: DeviceStorage {
     fn forward<S: Shape>(
         &self,
-        op: DropoutKernelOp<E>,
+        op: DropoutKernelOp,
         inp: &Tensor<S, E, Self>,
     ) -> Result<Tensor<S, E, Self>, Self::Err>;
     fn backward<S: Shape>(
         &self,
-        op: DropoutKernelOp<E>,
+        op: DropoutKernelOp,
         inp: &Tensor<S, E, Self>,
         grad_inp: &mut Self::Vec<E>,
         grad_out: &Self::Vec<E>,
@@ -40,7 +40,7 @@ pub trait DropoutKernel<E: Dtype>: DeviceStorage {
 /// # let dev: Cpu = Default::default();
 /// let t = dev.tensor([1.0f32, 2.0, 3.0, 4.0]);
 /// let r = t.dropout(0.5);
-/// assert_eq!(r.array(), [2.0, 4.0, 6.0, 0.0]);
+/// assert_eq!(r.array(), [2.0, 0.0, 6.0, 0.0]);
 /// ```
 ///
 /// ### Implementation details:
@@ -64,7 +64,7 @@ impl<S: Shape, E: Dtype, D: DropoutKernel<E>, T: Tape<E, D>> Tensor<S, E, D, T>
     /// See [dropout]
     pub fn try_dropout(self, prob: impl Into<f64>) -> Result<Self, D::Err> {
         let seed = self.device.random_u64();
-        let prob = E::from_f64(prob.into()).unwrap();
+        let prob = prob.into();
         let op = DropoutKernelOp { seed, prob };
         let (inp, mut tape) = self.split_tape();
         let out = inp.device.forward(op, &inp)?;
@@ -87,44 +87,48 @@ mod tests {
     #[test]
     fn test_dropout_all_0d() {
         let dev: TestDevice = Default::default();
-        let t: Tensor<_, f32, _> = dev.tensor(3.0);
+        let t = dev.tensor([3.0, 4.0, 5.0]).to_dtype::<TestDtype>();
         let r = t.leaky_trace().dropout(1.0);
-        assert_close_to_literal!(r, 0.0);
-        let g = r.backward();
-        assert_close_to_literal!(g.get(&t), 0.0);
+        assert_close_to_literal!(r, [0.0, 0.0, 0.0]);
+        let g = r.sum().backward();
+        assert_close_to_literal!(g.get(&t), [0.0, 0.0, 0.0]);
     }
 
     #[test]
     fn test_dropout_none_0d() {
         let dev: TestDevice = Default::default();
-        let t: Tensor<_, f32, _> = dev.tensor(3.0);
+        let t = dev.tensor([3.0, 4.0, 5.0]).to_dtype::<TestDtype>();
         let r = t.leaky_trace().dropout(0.0);
-        assert_close_to_literal!(r, 3.0);
-        let g = r.backward();
-        assert_close_to_literal!(g.get(&t), 1.0);
+        assert_close_to_literal!(r, [3.0, 4.0, 5.0]);
+        let g = r.sum().backward();
+        assert_close_to_literal!(g.get(&t), [1.0, 1.0, 1.0]);
     }
 
     #[test]
     fn test_dropout_1d_with_non_positive_values() {
-        let dev: TestDevice = Default::default();
-        let t: Tensor<_, f32, _> = dev.tensor([0.0, 2.0, -3.0, -4.0, 0.0]);
+        let dev: TestDevice = TestDevice::seed_from_u64(1);
+        let t = dev
+            .tensor([-0.0, 2.0, -3.0, -4.0, -0.0])
+            .to_dtype::<TestDtype>();
         let r = t.leaky_trace().dropout(0.5);
-        assert_close_to_literal!(r, [0.0, 4.0, -6.0, 0.0, 0.0]);
+        assert_close_to_literal!(r, [0.0, 4.0, -6.0, -8.0, 0.0]);
         let g = r.mean().backward();
-        assert_close_to_literal!(g.get(&t), [0.4, 0.4, 0.4, 0.0, 0.0]);
+        assert_close_to_literal!(g.get(&t), [0.4, 0.4, 0.4, 0.4, 0.0]);
    }
 
     #[test]
     fn test_dropout_2d() {
         let dev: TestDevice = Default::default();
-        let t: Tensor<_, f32, _> = dev.tensor([[0.05, 0.1, -0.2], [0.3, -0.4, 0.5]]);
+        let t = dev
+            .tensor([[0.05, 0.1, -0.2], [0.3, -0.4, 0.5]])
+            .to_dtype::<TestDtype>();
         let r = t.leaky_trace().dropout(0.6);
-        assert_close_to_literal!(r, [[0.125, 0.25, -0.5], [0.0, 0.0, 1.25]]);
+        assert_close_to_literal!(r, [[0.125, 0.0, -0.5], [0.0, -1.0, 0.0]]);
         // NOTE: .exp() so we ensure result grad is used properly
         let g = r.exp().mean().backward();
         assert_close_to_literal!(
             g.get(&t),
-            [[0.47214523, 0.5350107, 0.2527211], [0.0, 0.0, 1.4543099]]
+            [[0.4721452, 0.0, 0.25272113], [0.0, 0.1532831, 0.0]]
         );
     }
 }
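A worked check of the gradient values asserted in `test_dropout_1d_with_non_positive_values` (standalone arithmetic, not dfdx code): `r.mean()` contributes `1/n` per element, and each kept element is scaled by `1/(1 - p)`, so kept entries get `(1/n) * 1/(1 - p)` while dropped entries get 0.

```rust
fn main() {
    // n = 5 elements, p = 0.5: each kept element's gradient through
    // mean(dropout(t)) is (1/n) * 1/(1 - p) = 0.2 * 2.0 = 0.4.
    let (n, p) = (5.0f64, 0.5f64);
    let kept = (1.0 / n) * (1.0 - p).recip();
    assert!((kept - 0.4).abs() < 1e-12);
    // Dropped elements never reach the output, hence the 0.0 entry in
    // the asserted gradient [0.4, 0.4, 0.4, 0.4, 0.0].
}
```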