softmax for fp16 with fp32 accumulator #14098
@ptrendx @DickJC123 I'd appreciate your review on this. Thanks.
log_softmax should output its result in fp32 - it is very easy to overflow there.
I don't think we can change the return type of log_softmax in a minor release.
I will add an optional dtype argument for this.
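To illustrate the idea discussed above, here is a minimal NumPy sketch of a log_softmax that keeps the input type by default and only widens the output when a dtype is requested. The function below is a hypothetical stand-in written for this example, not the MXNet operator, and its signature is an assumption.

```python
import numpy as np

def log_softmax(x, axis=-1, dtype=None):
    """Hypothetical NumPy stand-in, not the MXNet operator.

    The output keeps x.dtype unless `dtype` is given, so existing
    callers see no change in return type.
    """
    out_dtype = x.dtype if dtype is None else np.dtype(dtype)
    # Do the arithmetic in at least float32, whatever the input type is.
    acc = x.astype(np.promote_types(x.dtype, np.float32))
    shifted = acc - acc.max(axis=axis, keepdims=True)
    log_sum = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    return (shifted - log_sum).astype(out_dtype)

x = np.array([60000.0, -60000.0], dtype=np.float16)
with np.errstate(over="ignore"):         # the fp16 cast below overflows
    out16 = log_softmax(x)               # second entry does not fit in fp16
out32 = log_softmax(x, dtype="float32")  # second entry is about -120000
```

Both inputs are representable in fp16, but their log-softmax difference is not, which is the kind of overflow an fp32 output avoids.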
@mxnet-label-bot add [Operator, pr-awaiting-review]
Nice change. Are we actually testing that an accumulator of higher precision gives us better accuracy?
@larroy thanks. Accuracy would be tied to a model and a dataset, and is thus harder to reason about. The more immediate effect is that it's harder to overflow when accumulating e^x.
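A small NumPy sketch of the overflow effect, assuming a plain running sum of e^x for illustration. It does not mirror the actual kernel in this PR, and `sum_exp` is a hypothetical helper. Each term fits in fp16, but the running total passes the fp16 maximum of 65504.

```python
import numpy as np

def sum_exp(x, acc_dtype):
    """Sum e^x over a 1-D array, keeping the running total in acc_dtype."""
    acc = acc_dtype(0)
    with np.errstate(over="ignore"):  # silence the expected overflow warning
        for v in x:
            # Each term is computed in the input type, then added in acc_dtype.
            acc = acc_dtype(acc + acc_dtype(np.exp(v)))
    return acc

x = np.full(64, 8.0, dtype=np.float16)  # e^8 is about 2981, fine in fp16
total_fp16 = sum_exp(x, np.float16)     # fp16 accumulator overflows to inf
total_fp32 = sum_exp(x, np.float32)     # fp32 accumulator stays finite
```

With the fp16 accumulator, dividing by the total would turn every softmax output into zero; the fp32 accumulator keeps the denominator usable.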
I intend to change the PR so that GradUseInOut is only used when dtype is specified and differs from the input type.
bdc0d6d to 704b0e7
LGTM. @ptrendx any other concerns? Would AMP skip the cast if softmax accumulation for fp16 inputs is done in fp32?
* softmax for fp16 with fp32 accumulator
* return AType in kernel
* add dtype
* kernel
* grad use in-out only when dtype override
* simplify infer type
* address comments
Description
softmax for fp16 with fp32 accumulator
Checklist
Essentials
Changes