Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixed bug with catboost and groups #1383

Merged

Conversation

dannycg1996
Copy link
Collaborator

@dannycg1996 dannycg1996 commented Dec 4, 2024

Why are these changes needed?

Currently an error is raised when we attempt to use a group-based split (such as GroupKFold or StratifiedGroupKFold) with CatBoost. This issue is caused by the fit() method on CatBoostEstimator, in particular this part:

            model.fit(
                X_tr,
                y_tr,
                cat_features=cat_features,
                eval_set=eval_set,
                callbacks=CatBoostEstimator._callbacks(
                    start_time, deadline, free_mem_ratio if use_best_model else None
                ),
                **kwargs,
            )

When groups is not None, kwargs contains the groups. CatBoost fit methods don't accept groups, which causes the issue.

As I understand it, groups in FLAML (and generally) should be used when splitting data, not when fitting data. Basically, CatBoost does not need groups to be passed in, and likely never will.

I've been through the other FLAML estimators, and they all seem to remove groups from kwargs before calling model.fit. For example, in the _fit() method on the BaseEstimator class defined on flaml/automl/model.py, we have the following code:

        if "groups" in kwargs:
            kwargs = kwargs.copy()
            groups = kwargs.pop("groups")
            if self._task == "rank":
                kwargs["group"] = group_counts(groups)

This code removes groups from the kwargs dictionary, before model.fit() is called.

I've mimicked this 'removal of groups from kwargs' in my code. I've also updated the two tests of group-related splits to include CatBoost as an estimator. Before my changes, the addition of catboost causes these tests to fail, but after removing the groups from kwargs, they pass :)

I welcome any feedback!

Related issue number

Closes #304

Checks

Copy link
Collaborator

@thinkall thinkall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a trade-off in user experience. The current code raises an error when CatBoost is used with data containing groups, whereas your PR proposes to silently ignore the group information.

@dannycg1996
Copy link
Collaborator Author

dannycg1996 commented Dec 16, 2024

This is a trade-off in user experience. The current code raises an error when CatBoost is used with data containing groups, whereas your PR proposes to silently ignore the group information.

Hi @thinkall,

I'm confused by this. I would argue that my PR doesn't propose to silently ignore the group information - it is simply proposing to not pass groups into fit(). This PR doesn't seek to treat CatBoost differently to any other estimator in FLAML - it's simply stripping groups out of the kwargs before passing the kwargs to model.fit(). This is the exact same step which every other FLAML estimator I've tested does. I see no reason why we need to pass a groups argument into the .fit() method for CatBoost and CatBoost only.

To emphasis my point, I first created an example script which reproduces the reported CatBoost error:

from flaml import AutoML
from sklearn import datasets
import numpy as np

dic_data = datasets.load_iris(as_frame=True)  # numpy arrays
iris_data = dic_data["frame"]  # pandas dataframe data + target
rng = np.random.default_rng(42)
iris_data["cluster"] = rng.integers(
    low=0, high=5, size=iris_data.shape[0]
)
print(iris_data["cluster"])
automl = AutoML()
automl_settings = {
    "max_iter":2,
    "metric": 'accuracy',
    "task": 'classification',
    "log_file_name": "catboost_error.log",
    "log_type": "all",
    "estimator_list": ['catboost'],
    "eval_method": "cv",
    "split_type":"group",
    "groups": iris_data['cluster']
}
x_train = iris_data[["sepal length (cm)","sepal width (cm)", "petal length (cm)","petal width (cm)"]].to_numpy()
y_train = iris_data['target']

The error (as expected):
image

I then re-ran the script with "estimator_list": ['lgbm']. This initially runs without errors, as expected. However, if I go into BaseEstimator._fit(), on `flaml/automl/model.py and comment out this the kwarg-popping code mentioned above:

        if "groups" in kwargs:
            kwargs = kwargs.copy()
            groups = kwargs.pop("groups")
            if self._task == "rank":
                kwargs["group"] = group_counts(groups)

Then rerunning the AutoML code gives this error:

image

Likewise if I run the same code with RandomForestEstimator as my only estimator:
image

Basically as I understand it, CatBoostEstimator is the only FLAML estimator (that I use at least) which does not at any point call BaseEstimator._fit(). Therefore CatBoostEstimator is the only FLAML model which doesn't pop the groups keyword out of the kwargs before calling model.fit(). I've investigated this pretty thoroughly now, and I'm fairly confident that no other estimator tries to pass groups to model.fit(). I'm not sure why CatBoostEstimator should be the exception.

Apologies if I'm missing something obvious.
Anyway thanks for reading, I hope that makes sense - please let me know your thoughts!

Copy link
Collaborator

@thinkall thinkall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much, @dannycg1996 !

@thinkall thinkall merged commit 42d1dcf into microsoft:main Dec 17, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

CatBoost Fails with Keyword 'groups'
2 participants