ViliamVadocz
Contributor

Works towards #306.
Depends on coreylowman/cudarc#60.

I am creating this as a draft to get feedback on the structure before implementing other operations (although this should be quite straightforward now).

Let me know what you think.
I will add documentation once implementation is finalized.

@ViliamVadocz
Contributor Author

Some things to consider and choices I would like feedback on:

  • Should Unit require PartialEq? At the moment I enforce it only on the implementation for EqKernelOp, as an additional bound.
  • I used a generic to specify the type of operation. I know that other kernel operations (such as UnaryKernel and BinaryKernel) pass in an empty struct. Should this also be done for comparison operations?
  • I placed the comparison operations in a sub-directory, cmp. I did this so that I could separate the comparison-specific kernel support code from the code for other operations, which require Dtype. I think this is a good solution, but perhaps a different structure is preferred?
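
To make the first two questions concrete, here is a minimal sketch of both styles over plain slices. Every name in it (CmpOp, CmpOpValue, EqOp, cmp_by_type, cmp_by_value) is hypothetical and chosen for illustration; it is not the code in this PR and does not use dfdx types.

```rust
// Hypothetical sketch; trait and type names are illustrative, not the PR's code.

// Style A: the op is a type parameter and is never instantiated.
pub trait CmpOp<E> {
    fn apply(lhs: E, rhs: E) -> bool;
}

pub struct EqOp;

// The PartialEq bound sits on this impl only, not on a shared element trait.
impl<E: PartialEq> CmpOp<E> for EqOp {
    fn apply(lhs: E, rhs: E) -> bool {
        lhs == rhs
    }
}

pub fn cmp_by_type<Op: CmpOp<E>, E: Copy>(lhs: &[E], rhs: &[E]) -> Vec<bool> {
    assert_eq!(lhs.len(), rhs.len(), "inputs must have the same length");
    lhs.iter().zip(rhs).map(|(&l, &r)| Op::apply(l, r)).collect()
}

// Style B: the op is passed as a value, which only pays off if it carries state.
pub trait CmpOpValue<E> {
    fn apply(&self, lhs: E, rhs: E) -> bool;
}

impl<E: PartialEq> CmpOpValue<E> for EqOp {
    fn apply(&self, lhs: E, rhs: E) -> bool {
        lhs == rhs
    }
}

pub fn cmp_by_value<Op: CmpOpValue<E>, E: Copy>(op: Op, lhs: &[E], rhs: &[E]) -> Vec<bool> {
    assert_eq!(lhs.len(), rhs.len(), "inputs must have the same length");
    lhs.iter().zip(rhs).map(|(&l, &r)| op.apply(l, r)).collect()
}
```

With style A a call looks like `cmp_by_type::<EqOp, i32>(&a, &b)`; with style B it looks like `cmp_by_value(EqOp, &a, &b)`. The two behave identically for a stateless op, so the choice is mostly about consistency with the other kernels.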

@nkoppel
Contributor

nkoppel commented Jan 21, 2023

> I used a generic to specify the type of operation. I know that other kernel operations (such as UnaryKernel and BinaryKernel) pass in an empty struct. Should this also be done for comparison operations?

It should be fine not to. I added these struct parameters to the binary operations in #346 only because of huber_error's delta parameter, and comparison operations should not have this kind of parameter.

You should probably implement unary versions of these operations, because this is consistent with the rest of the api and will probably be the more common use case.

While tensors of booleans are currently useless with the rest of the api, I think it is best to keep the output as tensors of booleans and to create operations for them. These could include inversion (!), and (&), or (|), xor (^), and masking operations for use with other tensors.
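
As a rough illustration of the boolean operations listed above, the following sketch implements them over plain slices. The function names (not, and, or, xor, masked_fill) and the slice-based signatures are assumptions made for this example; they are not existing dfdx functions.

```rust
// Hypothetical helpers over plain slices; none of these are dfdx functions.

/// Applies `f` to each pair of elements; panics if the lengths differ.
fn zip_with<A: Copy, B: Copy, C>(a: &[A], b: &[B], f: impl Fn(A, B) -> C) -> Vec<C> {
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    a.iter().zip(b).map(|(&x, &y)| f(x, y)).collect()
}

/// Element-wise inversion (`!`).
pub fn not(mask: &[bool]) -> Vec<bool> {
    mask.iter().map(|&b| !b).collect()
}

/// Element-wise and (`&`).
pub fn and(a: &[bool], b: &[bool]) -> Vec<bool> {
    zip_with(a, b, |x, y| x & y)
}

/// Element-wise or (`|`).
pub fn or(a: &[bool], b: &[bool]) -> Vec<bool> {
    zip_with(a, b, |x, y| x | y)
}

/// Element-wise xor (`^`).
pub fn xor(a: &[bool], b: &[bool]) -> Vec<bool> {
    zip_with(a, b, |x, y| x ^ y)
}

/// Keeps `values[i]` where `mask[i]` is true and substitutes `fill` elsewhere.
pub fn masked_fill(values: &[f32], mask: &[bool], fill: f32) -> Vec<f32> {
    zip_with(values, mask, |v, keep| if keep { v } else { fill })
}
```

The masking helper is the piece that connects boolean tensors back to numeric ones, which is what would make comparison outputs usable with the rest of the API.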

@ViliamVadocz
Contributor Author

> You should probably implement unary versions of these operations, because this is consistent with the rest of the api and will probably be the more common use case.

I don't understand how comparison could be unary. Did you mean something like comparing to a scalar? Or perhaps you meant that the comparison functions should be methods on the tensor type? (They are.)

@nkoppel
Contributor

nkoppel commented Jan 22, 2023

Sorry for the confusion, I meant that you should add operations that act like ScalarAdd, comparing each value within the tensor with a single number, like <tensor>.eq(0.0).
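
A scalar comparison of this kind can be sketched as follows. The free functions scalar_eq and scalar_lt and their slice-based signatures are hypothetical, used here only to show the shape of the operation; they are not the methods from this PR.

```rust
// Hypothetical free functions over slices; not the tensor methods from this PR.

/// Compares every element with one scalar for equality.
pub fn scalar_eq<E: PartialEq + Copy>(values: &[E], scalar: E) -> Vec<bool> {
    values.iter().map(|&v| v == scalar).collect()
}

/// Reports, per element, whether it is strictly less than the scalar.
pub fn scalar_lt<E: PartialOrd + Copy>(values: &[E], scalar: E) -> Vec<bool> {
    values.iter().map(|&v| v < scalar).collect()
}
```

The output has the same length as the input, so the result is still a tensor of booleans rather than a single bool.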

Owner

@coreylowman coreylowman left a comment

Looks like a great start! I like the addition of CmpKernel/CmpOpCpuKernel/CmpOpCudaKernel for reducing code. Since you added the above and the actual kernel implementations look pretty small, I think combining them into the same file so the structure is the same as the other ops should be good:

- tensor_ops/
    - cmp/
        - mod.rs
        - cuda_kernels.rs
        - cpu_kernels.rs
        - cmp.cu

where they would contain the following:

  1. mod.rs - all the tensor methods/functions & tests
  2. cpu_kernels.rs
    1. The trait definition for CmpOpCpuKernel
    2. The trait impl for impl<Op: CmpOpCpuKernel<E>, E: Unit> CmpKernel<Op, E> for Cpu {
    3. all the impl CmpOpCpuKernel<E> for EqKernelOp for all the ops (Eq/Lt, etc)
  3. cuda_kernels.rs - similar to cpu kernels but for the cuda stuff
  4. cmp.cu - all the forward functions for each op
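
For reference, the three pieces suggested for cpu_kernels.rs could fit together roughly as below. This is a simplified sketch under stated assumptions: it works on slices instead of tensors, the method names func and forward are guesses, Unit and Cpu are local stand-ins, and LtKernelOp is an assumed name for the Lt op.

```rust
// Simplified stand-ins: the real traits operate on tensors, not slices.
pub trait Unit: Copy {}
impl Unit for f32 {}
impl Unit for i32 {}

pub struct Cpu;

// 1. The per-op trait: one scalar comparison per operation.
pub trait CmpOpCpuKernel<E: Unit> {
    fn func(lhs: E, rhs: E) -> bool;
}

// The device-level trait that tensor methods would call into.
pub trait CmpKernel<Op, E: Unit> {
    fn forward(&self, lhs: &[E], rhs: &[E]) -> Vec<bool>;
}

// 2. One blanket impl covers every op that implements the per-op trait.
impl<Op: CmpOpCpuKernel<E>, E: Unit> CmpKernel<Op, E> for Cpu {
    fn forward(&self, lhs: &[E], rhs: &[E]) -> Vec<bool> {
        assert_eq!(lhs.len(), rhs.len(), "inputs must have the same length");
        lhs.iter().zip(rhs).map(|(&l, &r)| Op::func(l, r)).collect()
    }
}

// 3. The per-op impls, each with only the bound it needs.
pub struct EqKernelOp;
pub struct LtKernelOp; // assumed name for the Lt op

impl<E: Unit + PartialEq> CmpOpCpuKernel<E> for EqKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        lhs == rhs
    }
}

impl<E: Unit + PartialOrd> CmpOpCpuKernel<E> for LtKernelOp {
    fn func(lhs: E, rhs: E) -> bool {
        lhs < rhs
    }
}
```

Adding another op in this layout means adding one marker struct and one small impl, which is why the kernels stay short enough to share a file.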

@ViliamVadocz ViliamVadocz marked this pull request as ready for review January 30, 2023 02:35
@ViliamVadocz
Contributor Author

I added the rest of the comparison operations as well as scalar operations. I was unsure about what to name the scalar comparisons, but I went with scalar_op where op is replaced by one of eq, ne, ge, etc.

I was wondering whether I should also implement PartialEq and PartialOrd for Tensor so that it would be possible to use the actual operators like <, etc.

@nkoppel
Contributor

nkoppel commented Jan 30, 2023

The <, ==, and other comparison operations can, unfortunately, only output single booleans. However, I think it's a good idea to implement PartialEq and PartialOrd for tensors. The cuda implementation for PartialEq should boil down to a reduction, but the kernel for PartialOrd will need to prioritize earlier elements in the tensors, which will require something like a custom atomic operation.
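
A CPU-side sketch of the semantics described here, over plain slices: equality is an "all elements equal" reduction, and ordering is decided by the first pair of elements that differs. Reading "prioritize earlier elements" as lexicographic ordering is an interpretation, and the function names tensor_eq and tensor_partial_cmp are hypothetical. The CUDA atomic mentioned above is not modelled.

```rust
use std::cmp::Ordering;

/// Equality as a reduction: true only if every pair of elements is equal.
pub fn tensor_eq<E: PartialEq>(lhs: &[E], rhs: &[E]) -> bool {
    lhs.len() == rhs.len() && lhs.iter().zip(rhs).all(|(l, r)| l == r)
}

/// Lexicographic ordering: the earliest non-equal pair decides the result.
/// Returns `None` if that pair is incomparable (for example, a NaN).
pub fn tensor_partial_cmp<E: PartialOrd>(lhs: &[E], rhs: &[E]) -> Option<Ordering> {
    for (l, r) in lhs.iter().zip(rhs) {
        match l.partial_cmp(r) {
            Some(Ordering::Equal) => continue,
            other => return other,
        }
    }
    // All shared positions are equal, so fall back to comparing lengths.
    lhs.len().partial_cmp(&rhs.len())
}
```

The sequential loop makes the difficulty visible: equality can be reduced in any order, while the ordering result depends on which differing index comes first, so a parallel kernel has to track the minimum such index.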

@coreylowman coreylowman self-requested a review January 31, 2023 14:05
Owner

@coreylowman coreylowman left a comment

This looks great! Structure is perfect, and a lot of good stuff in here. Only a couple of things to add and then it's good to go 🚀

@ViliamVadocz
Contributor Author

ViliamVadocz commented Feb 3, 2023

Please note that the CUDA code relies on bool implementing ValidAsZeroBits. This was already merged in cudarc (coreylowman/cudarc#60), but a new version of the crate has not been published yet.

If you try compiling this branch with test-cuda, you will get errors unless you also change the cudarc dependency to use the main branch from git. It is probably best to hold off on merging this until a new version of cudarc is published.

}

#[test]
#[should_panic]
Owner

Nice test 😀

Owner

@coreylowman coreylowman left a comment

Will merge once cudarc releases support for bools

@coreylowman coreylowman merged commit 79d581f into coreylowman:main Feb 9, 2023
@coreylowman
Owner

Thanks for the contribution!

@ViliamVadocz ViliamVadocz deleted the comparison-ops branch February 9, 2023 16:06