+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::collections::HashMap;
use std::{boxed::Box, vec::Vec};

use crate::arrays::HasArrayType;
use crate::arrays::{HasArrayData, HasArrayType};
use crate::devices::{AllocateZeros, HasDevice};
use crate::unique_id::{HasUniqueId, UniqueId};

Expand Down Expand Up @@ -264,7 +264,7 @@ pub trait GradientProvider {
/// based on the associated data!
fn gradient<P>(&mut self, p: &P) -> Option<Box<P::Array>>
where
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice;
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice + HasArrayData;
}

/// Represents something that can be updated with [GradientProvider].
Expand Down
3 changes: 2 additions & 1 deletion src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ mod npz_impls;

#[cfg(test)]
mod tests {
use crate::arrays::{HasArrayData, HasArrayType};
use crate::gradients::{GradientProvider, Gradients};
use crate::unique_id::HasUniqueId;
use std::boxed::Box;
Expand All @@ -136,7 +137,7 @@ mod tests {
impl GradientProvider for SimpleGradients {
fn gradient<P>(&mut self, p: &P) -> Option<Box<P::Array>>
where
P: HasUniqueId + crate::arrays::HasArrayType<Dtype = f32> + crate::devices::HasDevice,
P: HasUniqueId + HasArrayType<Dtype = f32> + crate::devices::HasDevice + HasArrayData,
{
self.0.remove(p)
}
Expand Down
2 changes: 1 addition & 1 deletion src/optim/adam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl<M> Adam<M> {
impl<M> GradientProvider for Adam<M> {
fn gradient<P>(&mut self, p: &P) -> Option<Box<P::Array>>
where
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice,
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice + HasArrayData,
{
let mut g_t = self.gradients.remove(p)?;
let m_t = self.moment1.mut_gradient(p);
Expand Down
2 changes: 1 addition & 1 deletion src/optim/rmsprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<M> RMSprop<M> {
impl<M> GradientProvider for RMSprop<M> {
fn gradient<P>(&mut self, p: &P) -> Option<Box<P::Array>>
where
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice,
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice + HasArrayData,
{
let mut g_t = self.gradients.remove(p)?;

Expand Down
2 changes: 1 addition & 1 deletion src/optim/sgd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl<M> Sgd<M> {
impl<M> GradientProvider for Sgd<M> {
fn gradient<P>(&mut self, p: &P) -> Option<Box<P::Array>>
where
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice,
P: HasUniqueId + HasArrayType<Dtype = f32> + HasDevice + HasArrayData,
{
let mut g_t = self.gradients.remove(p)?;
match self.cfg.momentum {
Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载