Enabled feature_importances_ for our ForestDML and ForestDRLearner estimators #306

Merged · 17 commits · Nov 9, 2020
Changes from 5 commits
econml/dml.py (2 changes: 1 addition & 1 deletion)
@@ -867,7 +867,7 @@ def __init__(self,
random_state=random_state)


-class ForestDMLCateEstimator(NonParamDMLCateEstimator, ForestModelFinalCateEstimatorMixin):
+class ForestDMLCateEstimator(ForestModelFinalCateEstimatorMixin, NonParamDMLCateEstimator):
""" Instance of NonParamDMLCateEstimator with a
:class:`~econml.sklearn_extensions.ensemble.SubsampledHonestForest`
as a final model, so as to enable non-parametric inference.
econml/drlearner.py (2 changes: 1 addition & 1 deletion)
@@ -950,7 +950,7 @@ def fitted_models_final(self):
return super().model_final.models_cate


-class ForestDRLearner(DRLearner, ForestModelFinalCateEstimatorDiscreteMixin):
+class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
""" Instance of DRLearner with a :class:`~econml.sklearn_extensions.ensemble.SubsampledHonestForest`
as a final model, so as to enable non-parametric inference.

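Both reorderings move the mixin ahead of the main base class: Python resolves attributes left to right along the MRO, so attributes the mixin defines (such as the feature_importances_ property this PR enables) are found before any attribute of the same name reachable through the estimator base. A minimal sketch with hypothetical classes:

    class Base:
        feature_importances_ = None        # stand-in for an attribute on the base

    class Mixin:
        @property
        def feature_importances_(self):    # what the mixin actually provides
            return "from mixin"

    class Wrong(Base, Mixin):              # old base order: Base shadows the mixin
        pass

    class Right(Mixin, Base):              # new order, as in this PR
        pass

    assert Wrong().feature_importances_ is None
    assert Right().feature_importances_ == "from mixin"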
econml/sklearn_extensions/ensemble.py (48 changes: 44 additions & 4 deletions)
@@ -1,7 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

""" Subsampled honest forest extension to scikit-learn's forest methods.
""" Subsampled honest forest extension to scikit-learn's forest methods. Contains pieces of code from
scikit-learn's random forest implementation.

TODO. Currently the node.impurity entry of every node is the impurity based on the split half-sample and not
the estimation half-sample. This slightly affects the feature_importance_ calcualtion as the impurity is based
on the split half-sample, but the weighted_n_node_samples is based on the estimation half-sample. Identify
whether there is a fast way to also re-calculate impurities, even if it means restricting only to the MSE
criterion.
"""

import numpy as np
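For context on the TODO above: impurity-based (MDI) importance credits the split feature at each node with its weighted impurity decrease. A sketch of the standard per-split contribution (the function name is illustrative); the comments mark where the two half-samples enter:

    def split_contribution(imp_node, imp_left, imp_right,
                           n_node, n_left, n_right, n_total):
        # Standard MDI term for one split. In SubsampledHonestForest the
        # imp_* values currently come from the split half-sample, while the
        # n_* weights (weighted_n_node_samples) come from the estimation
        # half-sample -- the mismatch the TODO describes.
        return (n_node / n_total) * (
            imp_node
            - (n_left / n_node) * imp_left
            - (n_right / n_node) * imp_right)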
@@ -538,9 +545,7 @@ def fit(self, X, y, sample_weight=None, sample_var=None):
t, self, X, y, sample_weight, s_inds[i], i, len(trees),
verbose=self.verbose)
for i, t in enumerate(trees))
-trees = [t[0] for t in res]
-numerators = [t[1] for t in res]
-denominators = [t[2] for t in res]
+trees, numerators, denominators = zip(*res)
# Collect newly grown trees
self.estimators_.extend(trees)
self.numerators_.extend(numerators)
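The one-line replacement above is the usual zip(*...) transpose; for example:

    res = [("tree0", 1.0, 2.0), ("tree1", 3.0, 4.0)]
    trees, numerators, denominators = zip(*res)
    # trees == ("tree0", "tree1"): zip(*...) turns a list of triples into
    # three tuples, which the subsequent extend() calls accept just as
    # well as lists.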
@@ -756,3 +761,38 @@ def predict_interval(self, X, alpha=.1, normal=True):
lower_pred = scipy.stats.norm.ppf(
alpha / 2, loc=y_point_pred, scale=pred_stderr)
return lower_pred, upper_pred

@property
def feature_importances_(self):
"""
The impurity-based feature importances.

The higher, the more important the feature.
The importance of a feature is computed as the (normalized)
total reduction of the criterion brought by that feature. It is also
known as the Gini importance.

Returns
-------
feature_importances_ : ndarray of shape (n_features,)
The values of this array sum to 1, unless all trees are single node
trees consisting of only the root node, in which case it will be an
array of zeros.
"""
check_is_fitted(self)

def unnormalized_importances(tree):
return tree.tree_.compute_feature_importances(normalize=False)

all_importances = Parallel(n_jobs=self.n_jobs,
**_joblib_parallel_args(prefer='threads'))(
delayed(unnormalized_importances)(tree)
for tree in self.estimators_ if tree.tree_.node_count > 1)

if not all_importances:
return np.zeros(self.n_features_, dtype=np.float64)

all_importances = np.mean(all_importances,
axis=0, dtype=np.float64)
all_importances = np.clip(all_importances, 0, np.inf)
return all_importances / np.sum(all_importances)
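A minimal usage sketch of the new property; it assumes only the SubsampledHonestForest API visible in this diff (fit(X, y) and feature_importances_), with illustrative data and hyperparameters:

    import numpy as np
    from econml.sklearn_extensions.ensemble import SubsampledHonestForest

    rng = np.random.RandomState(0)
    X = rng.normal(size=(2000, 5))
    y = X[:, 0] + 0.1 * rng.normal(size=2000)  # only feature 0 drives y

    forest = SubsampledHonestForest(n_estimators=100, random_state=0)
    forest.fit(X, y)

    imp = forest.feature_importances_  # nonnegative, sums to 1
    assert np.isclose(imp.sum(), 1.0)  # holds once at least one tree splits

Note that single-node trees are skipped when averaging, and the all-zeros fallback is returned only when no tree made any split.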