fix: Fixed bug with catboost and groups #1383
This is a trade-off in user experience. The current code raises an error when CatBoost is used with data containing groups, whereas your PR proposes to silently ignore the group information.
Hi @thinkall, I'm confused by this. I would argue that my PR doesn't propose to silently ignore the group information - it is simply proposing to not pass groups into To emphasize my point, I first created an example script which reproduces the reported CatBoost error:
I then re-ran the script with
Then rerunning the AutoML code gives this error: Likewise if I run the same code with Basically as I understand it, Apologies if I'm missing something obvious.
Thank you so much, @dannycg1996 !
Why are these changes needed?
Currently an error is raised when we attempt to use a group-based split (such as `GroupKFold` or `StratifiedGroupKFold`) with CatBoost. This issue is caused by the `fit()` method on `CatBoostEstimator`, in particular this part:
When `groups` is not None, kwargs contains the groups. CatBoost fit methods don't accept groups, which causes the issue.
As I understand it, groups in FLAML (and generally) should be used when splitting data, not when fitting data. Basically, CatBoost does not need groups to be passed in, and likely never will.
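To illustrate the point that groups are consumed while splitting rather than while fitting, here is a minimal pure-Python sketch. The `group_k_fold` helper is hypothetical and uses a simplified round-robin assignment; it is not scikit-learn's `GroupKFold` algorithm or FLAML's code. It only shows that the group labels are needed to build the folds and nowhere else.

```python
from collections import defaultdict


def group_k_fold(groups, n_splits):
    """Yield (train_idx, test_idx) pairs in which no group spans both sides.

    Hypothetical helper for illustration only.
    """
    # Collect the row indices that belong to each group label.
    by_group = defaultdict(list)
    for i, g in enumerate(groups):
        by_group[g].append(i)

    # Assign whole groups to folds round-robin, largest groups first.
    folds = [[] for _ in range(n_splits)]
    ordered = sorted(by_group, key=lambda g: (-len(by_group[g]), str(g)))
    for pos, g in enumerate(ordered):
        folds[pos % n_splits].extend(by_group[g])

    # Each fold serves once as the test set; the rest form the training set.
    for k in range(n_splits):
        test = sorted(folds[k])
        train = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train, test


groups = ["a", "a", "b", "b", "c", "c"]
splits = list(group_k_fold(groups, n_splits=3))

for train_idx, test_idx in splits:
    train_groups = {groups[i] for i in train_idx}
    test_groups = {groups[i] for i in test_idx}
    # The group labels matter only here; a model would be fitted on the
    # rows selected by train_idx without ever seeing the labels.
    assert train_groups.isdisjoint(test_groups)
```

Once the indices are produced, the estimator only needs the selected rows, which is why dropping `groups` before the fit call does not discard any information the split relied on.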
I've been through the other FLAML estimators, and they all seem to remove groups from kwargs before calling `model.fit`. For example, in the `_fit()` method on the `BaseEstimator` class defined on `flaml/automl/model.py`, we have the following code:
This code removes `groups` from the kwargs dictionary, before `model.fit()` is called.
I've mimicked this 'removal of groups from kwargs' in my code. I've also updated the two tests of group-related splits to include CatBoost as an estimator. Before my changes, the addition of catboost causes these tests to fail, but after removing the groups from kwargs, they pass :)
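The 'removal of groups from kwargs' pattern can be sketched roughly as follows. This is not the actual FLAML or CatBoost code: `StrictModel`, `fit_passing_everything`, and `fit_dropping_groups` are hypothetical names, and the sketch assumes only that the underlying model's `fit` has a fixed signature with no `groups` parameter.

```python
class StrictModel:
    """Hypothetical stand-in for a model whose fit() has a fixed signature."""

    def fit(self, X, y, sample_weight=None):
        # Record the number of training rows so the fit can be checked later.
        self.n_rows_ = len(X)
        return self


def fit_passing_everything(model, X, y, **kwargs):
    # Forwards kwargs unchanged, so a stray "groups" entry reaches model.fit.
    return model.fit(X, y, **kwargs)


def fit_dropping_groups(model, X, y, **kwargs):
    # The splitter has already used the groups; discard them before fitting.
    kwargs.pop("groups", None)
    return model.fit(X, y, **kwargs)


X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 1, 0, 1]
groups = [0, 0, 1, 1]

# Without the removal step, the unexpected keyword raises a TypeError.
try:
    fit_passing_everything(StrictModel(), X, y, groups=groups)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# With the removal step, the same call succeeds.
fitted = fit_dropping_groups(StrictModel(), X, y, groups=groups)
```

Using `pop` with a default of `None` keeps the wrapper safe to call whether or not `groups` was supplied.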
I welcome any feedback!
Related issue number
Closes #304
Checks