132 changes: 90 additions & 42 deletions src/nn/transformer/decoder.rs
@@ -1,5 +1,72 @@
use crate::prelude::*;
use rand::Rng;
use std::io::{Read, Seek, Write};
use zip::{result::ZipResult, ZipArchive, ZipWriter};

/// **Requires Nightly** A transformer decoder.
///
/// Generics
/// - `MODEL_DIM`: The size of query/key/value tensors. Given to [MultiHeadAttention].
/// - `NUM_HEADS`: The number of heads in [MultiHeadAttention].
/// - `FF_DIM`: The size of the hidden layer in
/// the feedforward network in [TransformerDecoderBlock].
/// - `NUM_LAYERS`: The number of [TransformerDecoderBlock] to use.
/// TODO: Doctests
#[derive(Clone, Debug, Default)]
pub struct TransformerDecoder<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
const FF_DIM: usize,
const NUM_LAYERS: usize,
>(pub Repeated<TransformerDecoderBlock<MODEL_DIM, NUM_HEADS, FF_DIM>, NUM_LAYERS>);

impl<const M: usize, const H: usize, const F: usize, const L: usize> ResetParams
for TransformerDecoder<M, H, F, L>
{
fn reset_params<R: Rng>(&mut self, rng: &mut R) {
self.0.reset_params(rng);
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize> CanUpdateWithGradients
for TransformerDecoder<M, H, F, L>
{
fn update<G: GradientProvider>(&mut self, grads: &mut G, unused: &mut UnusedTensors) {
self.0.update(grads, unused);
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize, Tgt, Mem> Module<(Tgt, Mem)>
for TransformerDecoder<M, H, F, L>
where
Mem: Tensor<NoTape = Mem>,
TransformerDecoderBlock<M, H, F>: Module<(Tgt, Mem), Output = Tgt>,
{
type Output = Tgt;

fn forward(&self, (mut x, mem): (Tgt, Mem)) -> Self::Output {
for block in self.0.modules.iter() {
x = block.forward((x, mem.duplicate()));
}
x
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize> SaveToNpz
for TransformerDecoder<M, H, F, L>
{
fn write<W: Write + Seek>(&self, pre: &str, w: &mut ZipWriter<W>) -> ZipResult<()> {
self.0.write(pre, w)
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize> LoadFromNpz
for TransformerDecoder<M, H, F, L>
{
fn read<R: Read + Seek>(&mut self, pre: &str, r: &mut ZipArchive<R>) -> Result<(), NpzError> {
self.0.read(pre, r)
}
}

/// **Requires Nightly** A transformer decoder block. Different from the normal transformer block,
/// as this block's attention also accepts an additional sequence from the encoder.
@@ -16,7 +83,7 @@ use rand::Rng;
/// )
/// ```
/// TODO: Doctests
#[derive(Default, Debug)]
#[derive(Clone, Default, Debug)]
pub struct TransformerDecoderBlock<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
@@ -89,52 +156,33 @@ where
}
}

/// **Requires Nightly** A transformer decoder.
///
/// Generics
/// - `MODEL_DIM`: The size of query/key/value tensors. Given to [MultiHeadAttention].
/// - `NUM_HEADS`: The number of heads in [MultiHeadAttention].
/// - `FF_DIM`: The size of the hidden layer in
/// the feedforward network in [TransformerDecoderBlock].
/// - `NUM_LAYERS`: The number of [TransformerDecoderBlock] to use.
/// TODO: Doctests
#[derive(Debug, Default)]
pub struct TransformerDecoder<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
const FF_DIM: usize,
const NUM_LAYERS: usize,
>(pub Repeated<TransformerDecoderBlock<MODEL_DIM, NUM_HEADS, FF_DIM>, NUM_LAYERS>);

impl<const M: usize, const H: usize, const F: usize, const L: usize> ResetParams
for TransformerDecoder<M, H, F, L>
{
fn reset_params<R: Rng>(&mut self, rng: &mut R) {
self.0.reset_params(rng);
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize> CanUpdateWithGradients
for TransformerDecoder<M, H, F, L>
impl<const M: usize, const H: usize, const F: usize> SaveToNpz
for TransformerDecoderBlock<M, H, F>
{
fn update<G: GradientProvider>(&mut self, grads: &mut G, unused: &mut UnusedTensors) {
self.0.update(grads, unused);
fn write<W: Write + Seek>(&self, pre: &str, w: &mut ZipWriter<W>) -> ZipResult<()> {
self.self_attn.write(&format!("{pre}self_attn."), w)?;
self.norm1.write(&format!("{pre}norm1."), w)?;
self.mh_attn.write(&format!("{pre}mh_attn."), w)?;
self.norm2.write(&format!("{pre}norm2."), w)?;
self.ff.0 .0.write(&format!("{pre}linear1."), w)?;
self.ff.0 .2.write(&format!("{pre}linear2."), w)?;
self.norm3.write(&format!("{pre}norm3."), w)?;
Ok(())
}
}

impl<const M: usize, const H: usize, const F: usize, const L: usize, Tgt, Mem> Module<(Tgt, Mem)>
for TransformerDecoder<M, H, F, L>
where
Mem: Tensor<NoTape = Mem>,
TransformerDecoderBlock<M, H, F>: Module<(Tgt, Mem), Output = Tgt>,
impl<const M: usize, const H: usize, const F: usize> LoadFromNpz
for TransformerDecoderBlock<M, H, F>
{
type Output = Tgt;

fn forward(&self, (mut x, mem): (Tgt, Mem)) -> Self::Output {
for block in self.0.modules.iter() {
x = block.forward((x, mem.duplicate()));
}
x
fn read<R: Read + Seek>(&mut self, pre: &str, r: &mut ZipArchive<R>) -> Result<(), NpzError> {
self.self_attn.read(&format!("{pre}self_attn."), r)?;
self.norm1.read(&format!("{pre}norm1."), r)?;
self.mh_attn.read(&format!("{pre}mh_attn."), r)?;
self.norm2.read(&format!("{pre}norm2."), r)?;
self.ff.0 .0.read(&format!("{pre}linear1."), r)?;
self.ff.0 .2.read(&format!("{pre}linear2."), r)?;
self.norm3.read(&format!("{pre}norm3."), r)?;
Ok(())
}
}

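As a reference for the missing doctests, here is a minimal usage sketch of `TransformerDecoder` and the `SaveToNpz`/`LoadFromNpz` impls added in this file. It assumes the crate prelude (`dfdx::prelude::*`) re-exports the items used in the tests; the dimensions, file path, and `main` wrapper are illustrative only, not part of this PR.

```rust
use dfdx::prelude::*;
use rand::thread_rng;

fn main() {
    let mut rng = thread_rng();

    // Hypothetical sizes: MODEL_DIM=16, NUM_HEADS=4, FF_DIM=8, NUM_LAYERS=2.
    let mut decoder: TransformerDecoder<16, 4, 8, 2> = Default::default();
    decoder.reset_params(&mut rng);

    // `tgt` is the decoder input; `mem` is the encoder output ("memory").
    // The two sequence lengths do not have to match.
    let tgt: Tensor3D<4, 6, 16> = TensorCreator::randn(&mut rng);
    let mem: Tensor3D<4, 12, 16> = TensorCreator::randn(&mut rng);
    let _out: Tensor3D<4, 6, 16> = decoder.forward((tgt, mem));

    // Round-trip the weights through .npz using the impls added above.
    decoder.save("decoder.npz").expect("failed to save");
    let mut restored: TransformerDecoder<16, 4, 8, 2> = Default::default();
    restored.load("decoder.npz").expect("failed to load");
}
```

Within each block, tensors are written under the `self_attn.`, `norm1.`, `mh_attn.`, `norm2.`, `linear1.`, `linear2.`, and `norm3.` prefixes defined by the `write`/`read` impls above.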
60 changes: 44 additions & 16 deletions src/nn/transformer/encoder.rs
@@ -1,4 +1,22 @@
use crate::prelude::*;
use std::io::{Read, Seek, Write};
use zip::{result::ZipResult, ZipArchive, ZipWriter};

/// **Requires Nightly** A transformer encoder.
///
/// Generics
/// - `MODEL_DIM`: The size of query/key/value tensors. Given to [MultiHeadAttention].
/// - `NUM_HEADS`: The number of heads in [MultiHeadAttention].
/// - `FF_DIM`: The size of the hidden layer in
/// the feedforward network in [TransformerEncoderBlock].
/// - `NUM_LAYERS`: The number of [TransformerEncoderBlock] to use.
/// TODO: Doctests
pub type TransformerEncoder<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
const FF_DIM: usize,
const NUM_LAYERS: usize,
> = Repeated<TransformerEncoderBlock<MODEL_DIM, NUM_HEADS, FF_DIM>, NUM_LAYERS>;

/// **Requires Nightly** A single transformer encoder block
///
@@ -14,7 +32,7 @@ use crate::prelude::*;
/// )
/// ```
/// TODO: Doctests
#[derive(Debug, Default)]
#[derive(Clone, Debug, Default)]
pub struct TransformerEncoderBlock<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
@@ -74,21 +92,31 @@ where
}
}

/// **Requires Nightly** A transformer encoder.
///
/// Generics
/// - `MODEL_DIM`: The size of query/key/value tensors. Given to [MultiHeadAttention].
/// - `NUM_HEADS`: The number of heads in [MultiHeadAttention].
/// - `FF_DIM`: The size of the hidden layer in
/// the feedforward network in [TransformerEncoderBlock].
/// - `NUM_LAYERS`: The number of [TransformerEncoderBlock] to use.
/// TODO: Doctests
pub type TransformerEncoder<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
const FF_DIM: usize,
const NUM_LAYERS: usize,
> = Repeated<TransformerEncoderBlock<MODEL_DIM, NUM_HEADS, FF_DIM>, NUM_LAYERS>;
impl<const M: usize, const H: usize, const F: usize> SaveToNpz
for TransformerEncoderBlock<M, H, F>
{
fn write<W: Write + Seek>(&self, pre: &str, w: &mut ZipWriter<W>) -> ZipResult<()> {
self.self_attn.write(&format!("{pre}self_attn."), w)?;
self.norm1.write(&format!("{pre}norm1."), w)?;
self.norm2.write(&format!("{pre}norm2."), w)?;
self.ff.0 .0.write(&format!("{pre}linear1."), w)?;
self.ff.0 .2.write(&format!("{pre}linear2."), w)?;
Ok(())
}
}

impl<const M: usize, const H: usize, const F: usize> LoadFromNpz
for TransformerEncoderBlock<M, H, F>
{
fn read<R: Read + Seek>(&mut self, pre: &str, r: &mut ZipArchive<R>) -> Result<(), NpzError> {
self.self_attn.read(&format!("{pre}self_attn."), r)?;
self.norm1.read(&format!("{pre}norm1."), r)?;
self.norm2.read(&format!("{pre}norm2."), r)?;
self.ff.0 .0.read(&format!("{pre}linear1."), r)?;
self.ff.0 .2.read(&format!("{pre}linear2."), r)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
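For symmetry, a sketch of how the encoder side might be used. `TransformerEncoder` is only a type alias over `Repeated`, so construction and serialization go through `Repeated` rather than the block-level impls shown here; same assumptions as before (prelude re-exports, illustrative sizes and path).

```rust
use dfdx::prelude::*;
use rand::thread_rng;

fn main() {
    let mut rng = thread_rng();

    // Hypothetical sizes: MODEL_DIM=16, NUM_HEADS=4, FF_DIM=8, NUM_LAYERS=3.
    let mut encoder: TransformerEncoder<16, 4, 8, 3> = Default::default();
    encoder.reset_params(&mut rng);

    // Each TransformerEncoderBlock preserves the input shape, so the stack does too.
    let src: Tensor3D<4, 12, 16> = TensorCreator::randn(&mut rng);
    let _out: Tensor3D<4, 12, 16> = encoder.forward(src);

    // Repeated's serialization relies on the block-level SaveToNpz/LoadFromNpz added above.
    encoder.save("encoder.npz").expect("failed to save");
    let mut restored: TransformerEncoder<16, 4, 8, 3> = Default::default();
    restored.load("encoder.npz").expect("failed to load");
}
```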
45 changes: 43 additions & 2 deletions src/nn/transformer/mha.rs
@@ -203,8 +203,9 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::assert_close;
use rand::{rngs::StdRng, SeedableRng};
use crate::{nn::tests::SimpleGradients, tests::assert_close};
use rand::{rngs::StdRng, thread_rng, SeedableRng};
use tempfile::NamedTempFile;

#[test]
fn test_mha_unbatched() {
@@ -296,4 +297,44 @@ mod tests {
],
);
}

#[test]
fn test_backward_updates_all() {
let mut rng = thread_rng();

let mut mha: MultiHeadAttention<12, 4> = Default::default();
mha.reset_params(&mut rng);

let q: Tensor3D<2, 3, 12> = TensorCreator::randn(&mut rng);
let k: Tensor3D<2, 4, 12> = TensorCreator::randn(&mut rng);
let v: Tensor3D<2, 4, 12> = TensorCreator::randn(&mut rng);
let y: Tensor3D<2, 3, 12, _> = mha.forward((q.trace(), k, v));

let mut g = SimpleGradients(y.mean().backward());
let mut unused = Default::default();
mha.update(&mut g, &mut unused);
assert!(unused.is_empty());
}

#[test]
fn test_save_and_load() {
let mut rng = thread_rng();

let mut saved: MultiHeadAttention<12, 4> = Default::default();
saved.reset_params(&mut rng);

let file = NamedTempFile::new().expect("failed to create tempfile");
saved.save(file.path()).expect("");

let mut loaded: MultiHeadAttention<12, 4> = Default::default();
loaded.load(file.path()).expect("");

let q: Tensor3D<2, 3, 12> = TensorCreator::randn(&mut rng);
let k: Tensor3D<2, 4, 12> = TensorCreator::randn(&mut rng);
let v: Tensor3D<2, 4, 12> = TensorCreator::randn(&mut rng);
let y1: Tensor3D<2, 3, 12, _> = saved.forward((q.clone(), k.clone(), v.clone()));
let y2: Tensor3D<2, 3, 12, _> = loaded.forward((q.clone(), k.clone(), v.clone()));

assert_eq!(y1.data(), y2.data());
}
}
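The new tests above double as usage documentation; translated out of the test harness, the save/load round trip might look like the sketch below (prelude re-exports and the checkpoint path are assumptions, not part of this PR).

```rust
use dfdx::prelude::*;
use rand::thread_rng;

fn main() {
    let mut rng = thread_rng();

    // Same sizes as the tests above: model dim 12, 4 heads.
    let mut mha: MultiHeadAttention<12, 4> = Default::default();
    mha.reset_params(&mut rng);
    mha.save("mha.npz").expect("failed to save");

    let mut restored: MultiHeadAttention<12, 4> = Default::default();
    restored.load("mha.npz").expect("failed to load");

    // Cross-attention style call: queries may have a different sequence
    // length (3) than the keys/values (4).
    let q: Tensor3D<2, 3, 12> = TensorCreator::randn(&mut rng);
    let k: Tensor3D<2, 4, 12> = TensorCreator::randn(&mut rng);
    let v: Tensor3D<2, 4, 12> = TensorCreator::randn(&mut rng);
    let _y: Tensor3D<2, 3, 12, _> = restored.forward((q, k, v));
}
```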
49 changes: 47 additions & 2 deletions src/nn/transformer/transformer.rs
@@ -1,4 +1,6 @@
use crate::prelude::*;
use std::io::{Read, Seek, Write};
use zip::{result::ZipResult, ZipArchive, ZipWriter};

/// **Requires Nightly** Transformer architecture as described in
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
@@ -23,7 +25,7 @@ use crate::prelude::*;
/// batch_first=True,
/// )
/// ```
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct Transformer<
const MODEL_DIM: usize,
const NUM_HEADS: usize,
@@ -72,11 +74,32 @@ where
}
}

impl<const M: usize, const H: usize, const E: usize, const D: usize, const F: usize> SaveToNpz
for Transformer<M, H, E, D, F>
{
fn write<W: Write + Seek>(&self, pre: &str, w: &mut ZipWriter<W>) -> ZipResult<()> {
self.encoder.write(&format!("{pre}encoder."), w)?;
self.decoder.write(&format!("{pre}decoder."), w)?;
Ok(())
}
}

impl<const M: usize, const H: usize, const E: usize, const D: usize, const F: usize> LoadFromNpz
for Transformer<M, H, E, D, F>
{
fn read<R: Read + Seek>(&mut self, pre: &str, r: &mut ZipArchive<R>) -> Result<(), NpzError> {
self.encoder.read(&format!("{pre}encoder."), r)?;
self.decoder.read(&format!("{pre}decoder."), r)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::nn::tests::SimpleGradients;
use rand::{rngs::StdRng, SeedableRng};
use rand::{rngs::StdRng, thread_rng, SeedableRng};
use tempfile::NamedTempFile;

#[test]
fn test_forward() {
@@ -112,4 +135,26 @@ mod tests {

assert!(unused.is_empty());
}

#[test]
fn test_save_load() {
let mut rng = thread_rng();

let mut saved: Transformer<16, 4, 3, 4, 8> = Default::default();
saved.reset_params(&mut rng);

let file = NamedTempFile::new().expect("failed to create tempfile");
saved.save(file.path()).expect("");

let mut loaded: Transformer<16, 4, 3, 4, 8> = Default::default();
loaded.load(file.path()).expect("");

let src: Tensor3D<4, 12, 16> = TensorCreator::randn(&mut rng);
let tgt: Tensor3D<4, 6, 16> = TensorCreator::randn(&mut rng);

let y1 = saved.forward((src.clone(), tgt.clone()));
let y2 = loaded.forward((src.clone(), tgt.clone()));

assert_eq!(y1.data(), y2.data());
}
}
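Finally, an end-to-end sketch mirroring `test_save_load`, including the new `Clone` derive. Judging from the parameter names `M, H, E, D, F` in the impls above, the generics are presumably `<MODEL_DIM, NUM_HEADS, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, FF_DIM>`; that reading, the sizes, and the file path are assumptions here.

```rust
use dfdx::prelude::*;
use rand::thread_rng;

fn main() {
    let mut rng = thread_rng();

    let mut model: Transformer<16, 4, 3, 4, 8> = Default::default();
    model.reset_params(&mut rng);

    // The new Clone derive makes a weight snapshot a plain .clone().
    let _snapshot = model.clone();

    // src feeds the encoder, tgt feeds the decoder; sequence lengths may differ.
    let src: Tensor3D<4, 12, 16> = TensorCreator::randn(&mut rng);
    let tgt: Tensor3D<4, 6, 16> = TensorCreator::randn(&mut rng);
    let _out = model.forward((src, tgt));

    // Weights round-trip through .npz under the "encoder." / "decoder." prefixes added above.
    model.save("transformer.npz").expect("failed to save");
    let mut restored: Transformer<16, 4, 3, 4, 8> = Default::default();
    restored.load("transformer.npz").expect("failed to load");
}
```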