This repository contains the PyTorch implementation of SPTMod used for the DAFx25 pre-print paper "Empirical Results for Adjusting Truncated Backpropagation Through Time while Training Neural Audio Effects", authored by Y. Bourdin, P. Legrand and F. Roche.
The paper is available here.
Go to https://ybourdin.github.io/sptmod/ to listen to sound examples or take a look at the listening test.
Figure 1. SPTMod, on the right, consists of two paths called the modulation and audio paths. The modulation path is a series of modulation blocks which compute modulation tensors (the
Figure 2a. Composition of the ModBlock
Figure 2b. Composition of an SPN block
Say we want a cumulative TBPTT length
Because we do not use padding for the first sub-sequence, the corresponding input length of this sub-sequence is $L_0^{in}$, which is computed by the `Model.set_target_length()` method.
- A list of sequences of length $L_0^{in} + L_1 + \dots + L_{N-1}$ must be extracted from the training dataset using a sliding window.
- Then, before each training epoch:
  - The list of sequences is shuffled.
  - The sequences are further split into sub-sequences.
  - We iterate through the sub-sequences in the way described in the diagram below:
Figure 3. Diagram of the lengths of intermediary tensors for consecutive sequence batches in our TBPTT-based approach, with 3 sub-sequences. In the first iteration, no padding is applied, so the input length incorporates the number of samples required by the temporal operations. A large additional length is used by the SPN to initialize the states of recurrent layers. In subsequent iterations, states and caches are retained, but their gradients are detached from the computational graph.
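The data preparation steps listed above can be sketched in plain Python. This is only an illustration under stated assumptions: the helper names `extract_sequences` and `split_sequence`, the hop size and the toy lengths are hypothetical and are not taken from the repository, which works on PyTorch tensors.

```python
import random

def extract_sequences(signal, seq_len, hop):
    """Slide a window of `seq_len` samples over `signal` with step `hop`."""
    last_start = len(signal) - seq_len
    return [signal[s:s + seq_len] for s in range(0, last_start + 1, hop)]

def split_sequence(seq, sub_lengths):
    """Cut one sequence into consecutive sub-sequences of the given lengths."""
    assert sum(sub_lengths) == len(seq)
    out, start = [], 0
    for n in sub_lengths:
        out.append(seq[start:start + n])
        start += n
    return out

# Toy values (assumptions): L_0^in = 6 and L_1 = L_2 = 2, hence N = 3 sub-sequences.
sub_lengths = [6, 2, 2]
signal = list(range(30))  # stands in for an audio signal
sequences = extract_sequences(signal, seq_len=sum(sub_lengths), hop=10)

# Before each training epoch: shuffle the sequences, then split each one.
random.seed(0)
random.shuffle(sequences)
epoch = [split_sequence(seq, sub_lengths) for seq in sequences]
```

The first sub-sequence of each sequence is longer than the others because it carries the extra input samples needed when no padding is applied.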
In the code, at every iteration (an iteration corresponding to one backpropagation pass, i.e. one sub-sequence), we need to:
- Check if the number of the current iteration is divisible by $N$.
  - If yes, then the first sub-sequence of a group is to be processed:
    - Compute the index intervals to slice the tensors given to different parts of the model, here the SPN, modulation path and audio path; and compute the cropping sizes of the cropping layers. This is the goal of the `Model.set_target_length()` method.
    - Reset the model states with `Model.reset_states()`.
    - Call `Model.forward()` with `paddingmode = CachedPadding1d.NoPadding` and `use_spn = True`.
  - If not:
    - Detach the model states with `Model.detach_states()`.
    - Call `Model.forward()` with `paddingmode = CachedPadding1d.CachedPadding` and `use_spn = False`.
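The branching described above can be sketched with a stand-in model that only records which methods are called. Everything here is an assumption made for illustration: `StubModel`, `PaddingMode`, `tbptt_step` and all argument lists are hypothetical, and the real `Model` methods in the repository may take different arguments. The loss computation and backward pass are left out.

```python
from enum import Enum

class PaddingMode(Enum):
    """Stand-in for CachedPadding1d.NoPadding / CachedPadding1d.CachedPadding."""
    NoPadding = 0
    CachedPadding = 1

class StubModel:
    """Records calls; stands in for the repository's Model (signatures assumed)."""
    def __init__(self):
        self.calls = []
    def set_target_length(self, target_length):
        self.calls.append(("set_target_length", target_length))
    def reset_states(self):
        self.calls.append(("reset_states",))
    def detach_states(self):
        self.calls.append(("detach_states",))
    def forward(self, x, paddingmode, use_spn):
        self.calls.append(("forward", paddingmode.name, use_spn))
        return x

def tbptt_step(model, iteration, n_sub, x, target_length):
    """Run one iteration, i.e. one sub-sequence, following the steps above."""
    if iteration % n_sub == 0:
        # First sub-sequence of a group: recompute slicing, start from fresh states.
        model.set_target_length(target_length)
        model.reset_states()
        return model.forward(x, paddingmode=PaddingMode.NoPadding, use_spn=True)
    # Later sub-sequences: keep states and caches, but cut their gradient history.
    model.detach_states()
    return model.forward(x, paddingmode=PaddingMode.CachedPadding, use_spn=False)

model = StubModel()
for it in range(6):  # two groups of N = 3 sub-sequences
    tbptt_step(model, it, n_sub=3, x=[0.0], target_length=4)
forward_calls = [c for c in model.calls if c[0] == "forward"]
```

With $N = 3$, iterations 0 and 3 take the first branch and the other four iterations take the second.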
The architecture of SPTMod comprises two paths containing blocks that consume some time samples (their output lengths are shorter than their input lengths) and that impose conditions on the lengths of some intermediary tensors, as represented in this diagram:
- $t_m(n)$, resp. $t_a(n)$, is the absolute time index of the start of the sequence received by the $(n+1)^{\text{th}}$ modulation block, resp. audio block.
- $L$ is the target output length.
- $t_a(N)$ is the absolute time index of the start of the resulting sequence. We set $t_a(N)=0$. The absolute time index of its end is thus $L$.
- $P$ is the pooling size of the modulation blocks.
- The condition (A) is that the tensor length before a pooling layer, i.e. $L - (t_m(n) + \sigma_m(n))$, must be a multiple of $P$.
- The condition (B) is that the lengths of the inputs to an add/multiply operator are equal.
- $\sigma_m(n)$, $P$ and $\sigma_a(n)$ are samples consumed by operations such as convolutions or pooling-upsampling layers.
- $c_{mm}(n)$, $c_{ma}(n)$, $c_a(n)$ are cropping sizes added to satisfy conditions (A) and (B); they must be non-negative.
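As a small illustration of the two conditions stated above, the checks below evaluate them on toy numbers. The function names and the values are hypothetical. Start indices are assumed to be negative here, since $t_a(N)=0$ marks the start of the output and earlier tensors are longer.

```python
def satisfies_condition_a(L, t_m, sigma_m, P):
    """Condition (A): the length before pooling, L - (t_m + sigma_m), is a multiple of P."""
    return (L - (t_m + sigma_m)) % P == 0

def satisfies_condition_b(len_left, len_right):
    """Condition (B): both inputs of an add/multiply operator have the same length."""
    return len_left == len_right

# Toy values (assumptions): L = 16, sigma_m = 2, P = 4.
ok = satisfies_condition_a(L=16, t_m=-10, sigma_m=2, P=4)   # length 16 - (-8) = 24
bad = satisfies_condition_a(L=16, t_m=-9, sigma_m=2, P=4)   # length 16 - (-7) = 23
same = satisfies_condition_b(24, 24)
```

In the second case the length is one sample short of a multiple of $P$, which is the kind of mismatch the cropping sizes are there to absorb.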
The relationship between consecutive time indices in the modulation path is:
Applying a telescoping sum from
From this, we can express
Condition (A):
We can then express
To find an expression for
Thus,
This equation shows that
For
The relationship between consecutive time indices in the audio path is:
Applying a telescoping sum from
Thus:
Condition (B):
Substitute the expression for
Next, substitute
From this, we get the expression for
For
Assumption on
With
And the condition for
To ensure both
The compensation terms are then calculated (iteratively for
The initial time indices are:
(Note that