
Conversation

coreylowman (Owner) commented Aug 30, 2022

Closes #158
Closes #122
Closes #170
Closes #171
Closes #172

This turned into a way bigger refactor than I was expecting at first. Summary of changes:

  1. Repeated now uses Vec instead of array (Repeated should use vec instead of array for smaller memory footprint #172); a one-struct sketch appears after this list
  2. MultiHeadAttention changes
    1. Removed N param
    2. Make K & V default to M
    3. Rename M -> EMBED_DIM, H -> NUM_HEADS, K -> K_DIM, V -> V_DIM (see the struct sketch after this list)
    4. impl SaveToNpz and LoadFromNpz
    5. Unify the same-seq-length impls and the different-seq-length impls into one. Now there are only 2 impls: 1 for batched and 1 for unbatched. Both impls support different sequence lengths, which by definition covers the same-length case. (Can MHA Module impls be combined? #170)
    6. Both Module impls now take a 3-tuple of (q, k, v) (Can MHA Module impls be combined? #170)
    7. Add appropriate calls to permute_axes (MultiHeadAttention needs to permute before reshape & after final matmul #158)
    8. Fix the attention scaling value (MultiHeadAttention dividing by wrong scalar value #171); a forward-pass sketch covering 6-8 appears after this list
  3. TransformerEncoderBlock/TransformerDecoderBlock
    1. Both of these now impl Module generically, so there is only 1 Module impl instead of 2 (a standalone sketch of this where-clause pattern follows the comment below)
    2. Updated to use new generic params
    3. Updated to use 3-tuple input to MHA
  4. TransformerEncoder/TransformerDecoder
    1. Both of these now impl Module generically, so there is only 1 Module impl instead of 2
    2. Updated to use new generic params
    3. Updated to use 3-tuple input to MHA
  5. Testing changes (for MHA, TransformerEncoderBlock, and TransformerDecoderBlock)
    1. They all randomly initialize all of their parameters and their input values
    2. Only the expected output is hard-coded, and that value is generated by passing the exact same random parameter values through the corresponding PyTorch model (see the test sketch after this list)
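
On item 1, a minimal sketch of the new storage (the `modules` field name is my assumption; the merged code may differ):

```rust
// Sketch of Repeated after #172: the N layers live in a heap-allocated Vec
// instead of an inline `[T; N]`, so the struct itself stays small.
pub struct Repeated<T, const N: usize> {
    pub modules: Vec<T>, // invariant: always holds exactly N modules
}
```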
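For item 2, roughly what the renamed generics and the K/V defaults look like (a sketch; the Linear field layout and import path are assumptions, not a verbatim copy of the merged code):

```rust
use dfdx::nn::Linear; // assumed import path

pub struct MultiHeadAttention<
    const EMBED_DIM: usize,
    const NUM_HEADS: usize,
    const K_DIM: usize = EMBED_DIM, // K defaults to the embedding dim
    const V_DIM: usize = EMBED_DIM, // V defaults to the embedding dim
> {
    pub w_q: Linear<EMBED_DIM, K_DIM>,
    pub w_k: Linear<EMBED_DIM, K_DIM>,
    pub w_v: Linear<EMBED_DIM, V_DIM>,
    pub w_o: Linear<V_DIM, EMBED_DIM>,
}
```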
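For items 6-8, a self-contained sketch of the corrected scaling from #171, with the permute-before-reshape flow from #158 noted in comments (op names come from the PR text; exact tensor signatures are assumptions):

```rust
const EMBED_DIM: usize = 8;
const NUM_HEADS: usize = 2;
const K_DIM: usize = EMBED_DIM;

fn main() {
    // Shape flow per #158: q/k/v projections give [S, K_DIM]; reshape to
    // [S, NUM_HEADS, HEAD_DIM], permute_axes to [NUM_HEADS, S, HEAD_DIM]
    // before the batched matmul, then permute back before the final
    // reshape and the output projection with w_o.
    let head_dim = K_DIM / NUM_HEADS;
    // #171 fix: scale attention logits by sqrt of the *per-head* key dim.
    let scalar = 1.0 / (head_dim as f64).sqrt();
    println!("logits scaled by 1/sqrt({head_dim}) = {scalar}");
}
```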
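And a sketch of the test pattern in item 5; `randomize`, `randn_tensor`, and `assert_close` are hypothetical stand-ins for whatever helpers the test suite actually uses:

```rust
use rand::{rngs::StdRng, SeedableRng};

#[test]
fn test_mha_matches_pytorch() {
    let mut rng = StdRng::seed_from_u64(0);
    let mut mha: MultiHeadAttention<8, 2> = Default::default();
    mha.randomize(&mut rng); // hypothetical: fill all params with random values
    let x = randn_tensor(&mut rng); // hypothetical: random input
    let y = mha.forward((x.clone(), x.clone(), x)); // 3-tuple q/k/v, self-attention
    // EXPECTED was generated offline by loading the identical weights into
    // the corresponding PyTorch module and recording its output.
    assert_close(&y, &EXPECTED); // hypothetical comparison helper
}
```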

coreylowman (Owner) commented:

FYI @jafioti: the addition of permute_axes made the transformers implementation look a lot closer to other ones I've seen, and some additions of Module<> into where clauses helped reduce the number of impls.
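
To show what that where-clause trick buys, here is a standalone sketch of the pattern with a toy Residual wrapper (not dfdx's actual encoder code):

```rust
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

struct Residual<L>(L);

// One generic impl covers batched and unbatched inputs alike: any input
// type the inner layer maps to itself is accepted, so no duplicate impls
// per input shape are needed.
impl<Input, L> Module<Input> for Residual<L>
where
    Input: Clone + std::ops::Add<Output = Input>,
    L: Module<Input, Output = Input>,
{
    type Output = Input;
    fn forward(&self, x: Input) -> Input {
        x.clone() + self.0.forward(x)
    }
}
```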

coreylowman merged commit b5208ce into main on Aug 31, 2022, and deleted the 158-transformers-permute branch.