-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1333] Estimator and Fit API #14629
Conversation
Would it be possible to make these tests part of the training test suite instead of introducing new jobs? |
@marcoabreu are you referring to training test suite here? we want to only test this on nightly. Please point to the correct test suite if I m wrong. Thanks! |
c2e2f80
to
d76234b
Compare
Suggest to move it to contrib as there is some feedback from Mu on the dev@. We could also gather feedback from the users to see what other changes are required. Could you please break all the backlog items into Jira tasks and paste the master ticket to this PR ? Any contributor interested to further contribute to this could pick up those tasks. |
@nswamy done, all JIRA tickets has detailed description and reference to the feedback (either from cwiki or dev list discussion) |
import numpy as np | ||
|
||
|
||
class EventHandler(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please break this down into different event classes and use the same approach as gluon's forward hook through weakref. this has the benefit of:
- people can mix what they need into a unified handler through multi-inheritence, without the unnecessary
pass
calls. - a handler can be detached at will.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i added a method to check if a handlers has implemented train_begin
ect to avoid unnecessary pass
calls. The time it takes should be the same as using multi-inheritence and do a bunch of isinstance()
at the beginning. The benefit is user can override any event call without inherit multiple class.
Still looking into how to use forward hook to register different input args
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution. see above. more comments to come.
Thanks for your contributions @roywei. @mxnet-label-bot Add [Gluon, pr-awaiting-review] |
Thanks @szha for the feedback! I m noting the summary of offlien discussion here:
|
I am not sure I understand the concerns, What is the problem with 1) ? For 2), the user can create custom event Handler taking objects whatever it needs to keep track of.
Given that the fit API is targeted at novice users, I think 3) is going to make it unnecessarily cumbersome. What is the benefit of using the Forward Hook approach? |
wouldn't it be easier for the user to write the training loop if they want more control instead of having the loop split 6 or more methods or hooks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blocking the PR until my questions are answered.
I would like to understand why and how Sheng's proposal is better than the current design which was discussed offline and surfaced on dev@ months ago.
Last minute requests to fundamentally change the design should have very strong reasons.
@nswamy @szha I have addressed the concerns on callback in this doc: https://cwiki.apache.org/confluence/display/MXNET/Callback+Design+for+Fit+Loop Please help take a look, thanks! |
@eric-haibin-lin @pinaraws @szha Thanks for the review, address comments here: https://github.com/apache/incubator-mxnet/pull/14885/files |
My comments are addressed. Great work!! |
* base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests
#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI
…API (#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric
* add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers
* Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code
* added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context
…upport for Gluon fit() API (#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py
* improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests
* move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci
* address comments * add comment * check available context * fix bug * change cpu check
* address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
@szha @eric-haibin-lin CI finally passed, validation/miscellaneous job status returned not correctly (passed instead of pending). Could you help merge if looks good? |
Nice work! Great job upholding the quality even at the cost of several iterations. Well done! |
This reverts commit 9f451fb.
* [MXNet-1334][Fit API]base class for estimator and eventhandler (apache#14346) * base class for estimator and eventhandler * add license * add event handlers * fix pylint * improve arg check * fix pylint * add unit tests * Fixed issue where the estimator was printing beyond the dataset size … (apache#14464) * Fixed issue where the estimator was printing beyond the dataset size for the last batch * Added comments * Nudge to CI * [MXNet-1349][Fit API]Add validation support and unit tests for fit() API (apache#14442) * added estimator unittests * add more tests for estimator * added validation logic * added error handlers, unittests * improve val stats * fix pylint * fix pylint * update unit test * fix tests * fix tests * updated metrics, val logic * trigger ci * trigger ci * update metric, batch_fn error handler * update context logic, add default metric * [MXNet-1340][Fit API]Update train stats (apache#14494) * add train history * update history * update test * avoid calling empty methods * remove train history object * fix pylint * add unit test * fix test * update categorize handlers * [MXNet-1375][Fit API]Added RNN integration test for fit() API (apache#14547) * Added RNN integration test for fit() API * Addressed review comments: change in JenkinFile, tmp directory, ctx with condense if/else, renamed imports * CPU test doesn't require nvidiadocker container * Modified the structure by removing the redundant code * [MXNet-1343][Fit API]Add CNN integration test for fit() API (apache#14405) * added cnn intg tests for fit api * updated cnn intg tests * added functions for nightly test * updated runtime_function * updated intg tests * updated init, datapath, refs * added validation data * update cpu test * refactor code * updated context * [MXNET-1344, 1346][FIT API] Retrieve Batch size and Logging verbose support for Gluon fit() API (apache#14587) * Retrieve Batch size and Logging verbose support for Gluon fit() API * NIT changes * Addressed review comments: shifted the batch size code to a separate method, sentence correction * Modified unittest * removed redundant parameter * Resolve CI test failure * only support DataLoader for now, future PRs will include DataIter to DataLoader converter * Get the number of samples from shape attribute instead of length due to low space complexity * Simplified batch size retrieval code * removed batch_size parameter from fit() method and fixed the tests * Verbose exception handling * Assigning constant to a verbose * Modified exception message * Resolved undefined class reference * Addressed review comments: Modified verbose level names, docs, variable names * Update estimator.py * move estimator to contrib (apache#14633) * move to gluon contrib (apache#14635) * [Fit API] improve event handlers (apache#14685) * improve event handlers * update tests * passing weakref of estimator * fix unit test * fix test * fix pylint * fix test * fix pylint * move default metric logic * combine nightly tests * [MXNET-1396][Fit-API] Update default handler logic (apache#14765) * move to nightly for binaries * update default handler * fix pylint * trigger ci * trigger ci * [Fit API] update estimator (apache#14849) * address comments * add comment * check available context * fix bug * change cpu check * [Fit-API] Adress PR comments (apache#14885) * address comments * update checkpoint * test symbol save * address comments * add resume * update doc and resume checkpoint * update docs * trigger ci * trigger ci
) This reverts commit 9f451fb.
Description
This PR introduce an Estimator class in contrib with easy fit method to help beginners with model training.
It's been developed on a branch, and we hope to merge it to contrib and get feedback for first iteration.
Design: https://cwiki.apache.org/confluence/display/MXNET/Gluon+Fit+API+-+Tech+Design
JIRA epics: https://issues.apache.org/jira/browse/MXNET-1333
Dev list discussion: https://lists.apache.org/thread.html/13e3dee0fc9dd8e45b6616f97d282096a1ee67cde78a93dada295577@%3Cdev.mxnet.apache.org%3E
Feedbacks: currently all feedbacks are captured in cwiki comment section. We have created JIRA issues for each feedback and will continue to work on it
Follow up PRs:
We currently have the following PRs to address feedback, will create more and track using JIRA issue
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments