Merge remote-tracking branch 'origin/master' into gbt
jppgks committed Jun 28, 2022 · 2 parents 67334c7 + f654e82 · commit ca7db50
Showing 14 changed files with 260 additions and 38 deletions.
2 changes: 1 addition & 1 deletion ludwig/api.py
@@ -802,7 +802,7 @@ def predict(
            skip_save_unprocessed_output=skip_save_unprocessed_output or not self.backend.is_coordinator(),
        )
        converted_postproc_predictions = convert_predictions(
-            postproc_predictions, self.model.output_features, return_type=return_type
+            postproc_predictions, self.model.output_features, return_type=return_type, backend=self.backend
        )

        if self.backend.is_coordinator():
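
The practical effect of the change above: with return_type="dict", convert_predictions now receives self.backend, so predictions produced on a distributed (Dask) backend convert cleanly. A hedged usage sketch, assuming a trained LudwigModel and an input DataFrame df:

    # Sketch only; predict() returns (predictions, output_directory).
    predictions, _ = model.predict(dataset=df, return_type="dict")
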
12 changes: 6 additions & 6 deletions ludwig/data/postprocessing.py
@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
import os
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
@@ -104,17 +104,17 @@ def convert_dict_to_df(predictions: Dict[str, Dict[str, Union[List[Any], torch.T
    return pd.DataFrame.from_dict(output)


-def convert_predictions(predictions, output_features, return_type="dict"):
+def convert_predictions(
+    predictions, output_features, return_type="dict", backend: Optional["Backend"] = None  # noqa: F821
+):
    convert_fn = get_from_registry(return_type, conversion_registry)
-    return convert_fn(
-        predictions,
-        output_features,
-    )
+    return convert_fn(predictions, output_features, backend)


def convert_to_df(
    predictions,
    output_features,
+    backend: Optional["Backend"] = None,  # noqa: F821
):
    return predictions

6 changes: 5 additions & 1 deletion ludwig/data/utils.py
@@ -1,13 +1,15 @@
-from typing import Any, Dict
+from typing import Any, Dict, Optional

import numpy as np

+from ludwig.utils.dataframe_utils import is_dask_df
from ludwig.utils.types import DataFrame


def convert_to_dict(
    predictions: DataFrame,
    output_features: Dict[str, Any],
+    backend: Optional["Backend"] = None,  # noqa: F821
):
    """Convert predictions from DataFrame format to a dictionary."""
    output = {}
@@ -18,6 +20,8 @@ def convert_to_dict(
        subgroup = key[len(of_name) + 1 :]

        values = predictions[key]
+        if is_dask_df(values, backend):
+            values = values.compute()
        try:
            values = np.stack(values.to_numpy())
        except ValueError:
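
As a hedged illustration of why the new Dask check matters: np.stack cannot consume a lazy Dask series directly, so the series is materialized to a pandas Series first. The example data below is hypothetical:

    import dask.dataframe as dd
    import numpy as np
    import pandas as pd

    pdf = pd.DataFrame({"label_probabilities": [[0.1, 0.9], [0.8, 0.2]]})
    ddf = dd.from_pandas(pdf, npartitions=2)

    values = ddf["label_probabilities"]
    values = values.compute()              # lazy Dask series -> concrete pandas Series
    stacked = np.stack(values.to_numpy())  # (2, 2) ndarray; np.stack would fail on the lazy series
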
Empty file added ludwig/explain/__init__.py
Empty file.
171 changes: 171 additions & 0 deletions ludwig/explain/captum.py
@@ -0,0 +1,171 @@
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from captum.attr import IntegratedGradients
from torch.autograd import Variable

from ludwig.api import LudwigModel
from ludwig.constants import BINARY, CATEGORY, TYPE
from ludwig.data.preprocessing import preprocess_for_prediction
from ludwig.explain.util import get_feature_name, get_pred_col, prepare_data
from ludwig.models.ecd import ECD
from ludwig.utils.torch_utils import get_torch_device

DEVICE = get_torch_device()


class WrapperModule(torch.nn.Module):
    """Model used by the explainer to generate predictions.

    Unlike Ludwig's ECD class, this wrapper takes individual args as inputs to the forward function. We derive the
    order of these args from the order of the input_feature keys in ECD, which is guaranteed to be consistent
    (Python dictionaries preserve insertion order), so we can map back to the input feature dictionary as a second
    step within this wrapper.
    """

    def __init__(self, model: ECD, target: str):
        super().__init__()
        self.model = model
        self.target = target

    def forward(self, *args):
        preds = self.predict_from_encoded(*args)
        return get_pred_col(preds, self.target).cpu()

    def predict_from_encoded(self, *args):
        # Add back the dictionary structure so it conforms to ECD format.
        encoded_inputs = {}
        for k, v in zip(self.model.input_features.keys(), args):
            encoded_inputs[k] = {"encoder_output": v.to(DEVICE)}

        # Run the combiner and decoder separately since we already encoded the input.
        combined_outputs = self.model.combine(encoded_inputs)
        outputs = self.model.decode(combined_outputs, None, None)

        # At this point we only have the raw logits. To make explainability work, we also need the probabilities
        # and predictions, so derive them here.
        predictions = {}
        for of_name in self.model.output_features:
            predictions[of_name] = self.model.output_features[of_name].predictions(outputs, of_name)
        return predictions


def get_input_tensors(model: LudwigModel, input_set: pd.DataFrame) -> List[Variable]:
    # Convert raw input data into preprocessed tensor data.
    dataset, _ = preprocess_for_prediction(
        model.config,
        dataset=input_set,
        training_set_metadata=model.training_set_metadata,
        data_format="auto",
        split="full",
        include_outputs=False,
        backend=model.backend,
        callbacks=model.callbacks,
    )

    # Convert the dataset into a dict of tensors, and split each tensor into batches to control GPU memory usage.
    inputs = {
        name: torch.from_numpy(dataset.dataset[feature.proc_column]).split(model.config["trainer"]["batch_size"])
        for name, feature in model.model.input_features.items()
    }

    # Dict of lists to list of dicts.
    input_batches = [dict(zip(inputs, t)) for t in zip(*inputs.values())]

    # Encode the inputs into embedding space. This is necessary to ensure differentiability. Otherwise, category
    # and other features that pass through an embedding will not be explainable via gradient-based methods.
    output_batches = []
    for batch in input_batches:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        output = model.model.encode(batch)

        # Extract the output tensor, discarding additional state used for sequence decoding.
        output = {k: v["encoder_output"].detach().cpu() for k, v in output.items()}
        output_batches.append(output)

    # List of dicts to dict of lists.
    encoded_inputs = {k: torch.cat([d[k] for d in output_batches]) for k in output_batches[0]}

    # Wrap the outputs in Variables so torch will track the gradient.
    # TODO(travis): this won't work for text decoders, but we don't support explanations for those yet.
    data_to_predict = [v for _, v in encoded_inputs.items()]
    data_to_predict = [Variable(t, requires_grad=True) for t in data_to_predict]

    return data_to_predict


def explain_ig(
    model: LudwigModel, inputs_df: pd.DataFrame, sample_df: pd.DataFrame, target: str
) -> Tuple[np.ndarray, List[float], np.ndarray]:
    model.model.to(DEVICE)

    inputs_df, sample_df, _, target_feature_name = prepare_data(model, inputs_df, sample_df, target)

    # Convert input data into embedding tensors from the output of the model encoders.
    inputs_encoded = get_input_tensors(model, inputs_df)
    sample_encoded = get_input_tensors(model, sample_df)

    # For a robust baseline, we take the mean of all embeddings in the sample from the training data.
    # TODO(travis): pre-compute this during training from the full training dataset.
    baseline = [torch.unsqueeze(torch.mean(t, dim=0), 0) for t in sample_encoded]

    # Configure the explainer, which includes wrapping the model so its interface conforms to
    # the format expected by Captum.
    target_feature_name = get_feature_name(model, target)
    explanation_model = WrapperModule(model.model, target_feature_name)
    explainer = IntegratedGradients(explanation_model)

    # Lookup from column name to output feature.
    output_feature_map = {feature["column"]: feature for feature in model.config["output_features"]}

    # The second dimension of the attribution tensor corresponds to the cardinality
    # of the output feature. For regression (number) this is 1, for binary it is 2, and
    # for category it is the vocab size.
    vocab_size = 1
    is_category_target = output_feature_map[target_feature_name][TYPE] == CATEGORY
    if is_category_target:
        vocab_size = model.training_set_metadata[target_feature_name]["vocab_size"]

    # Compute attribution for each possible output feature label separately.
    attribution_by_label = []
    expected_values = []
    for target_idx in range(vocab_size):
        attribution, delta = explainer.attribute(
            tuple(inputs_encoded),
            baselines=tuple(baseline),
            target=target_idx if is_category_target else None,
            internal_batch_size=model.config["trainer"]["batch_size"],
            return_convergence_delta=True,
        )

        # Attribution over the feature embeddings returns a vector with the same
        # dimensions, so take the sum over this vector to get a single
        # floating-point attribution value per input feature.
        attribution = np.array([t.detach().numpy().sum(1) for t in attribution])
        attribution_by_label.append(attribution)

        # The convergence delta is given per row, so take the mean to compute the
        # average delta for the feature.
        # TODO(travis): this isn't really the expected value as it is for shap, so
        # find a better name.
        expected_value = delta.detach().numpy().mean()
        expected_values.append(expected_value)

    # For binary outputs, add an extra attribution for the negative class (false).
    is_binary_target = output_feature_map[target_feature_name][TYPE] == BINARY
    if is_binary_target:
        attribution_by_label.append(attribution_by_label[0] * -1)
        expected_values.append(expected_values[0] * -1)

    # Stack the attributions into a single tensor of shape
    # [batch_size, output_feature_cardinality, num_input_features].
    attribution = np.stack(attribution_by_label, axis=1)
    attribution = np.transpose(attribution, (2, 1, 0))

    # Add the predictions to the result.
    pred_df = model.predict(inputs_df, return_type=dict)[0]
    preds = np.array(pred_df[target_feature_name]["predictions"])

    return attribution, expected_values, preds
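
For reference, Integrated Gradients attributes a prediction F(x) by integrating gradients along the straight-line path from a baseline x' to the input x: IG_i(x) = (x_i - x'_i) * \int_0^1 dF(x' + a(x - x'))/dx_i da; here the baseline is the mean encoder embedding of the training sample. A hedged usage sketch for explain_ig, where paths and the target column name are hypothetical:

    import pandas as pd

    from ludwig.api import LudwigModel
    from ludwig.explain.captum import explain_ig

    model = LudwigModel.load("results/experiment_run/model")
    inputs_df = pd.read_csv("rows_to_explain.csv")
    sample_df = pd.read_csv("training_sample.csv")  # sample of training data for the baseline

    # attribution has shape [batch_size, output_feature_cardinality, num_input_features]
    attribution, expected_values, preds = explain_ig(model, inputs_df, sample_df, target="label")
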
38 changes: 38 additions & 0 deletions ludwig/explain/util.py
@@ -0,0 +1,38 @@
import pandas as pd

from ludwig.api import LudwigModel


def filter_cols(df, cols):
    cols = {c.lower() for c in cols}
    retain_cols = [c for c in df.columns if c.lower() in cols]
    return df[retain_cols]


def prepare_data(model: LudwigModel, inputs_df: pd.DataFrame, sample_df: pd.DataFrame, target: str):
    feature_cols = [feature["column"] for feature in model.config["input_features"]]
    target_feature_name = get_feature_name(model, target)

    inputs_df = filter_cols(inputs_df, feature_cols)
    sample_df = filter_cols(sample_df, feature_cols)

    return inputs_df, sample_df, feature_cols, target_feature_name


def get_pred_col(preds, target):
    t = target.lower()
    for c in preds.keys():
        if c.lower() == t:
            if "probabilities" in preds[c]:
                return preds[c]["probabilities"]
            else:
                return preds[c]["predictions"]
    raise ValueError(f"Unable to find target column {t} in {preds.keys()}")


def get_feature_name(model: LudwigModel, target: str) -> str:
    t = target.lower()
    for c in model.training_set_metadata.keys():
        if c.lower() == t:
            return c
    raise ValueError(f"Unable to find target column {t} in {model.training_set_metadata.keys()}")
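
A small hedged example of get_pred_col's case-insensitive lookup, with hypothetical prediction data:

    preds = {"Survived": {"predictions": [True, False], "probabilities": [[0.2, 0.8], [0.9, 0.1]]}}
    get_pred_col(preds, "survived")  # matches "Survived"; prefers probabilities when present
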
40 changes: 16 additions & 24 deletions ludwig/hyperopt/execution.py
@@ -14,7 +14,15 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import ray
from packaging import version
+from ray import tune
+from ray.tune import register_trainable, Stopper
+from ray.tune.schedulers.resource_changing_scheduler import DistributeResources, ResourceChangingScheduler
+from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter
+from ray.tune.utils import wait_for_gpu
+from ray.tune.utils.placement_groups import PlacementGroupFactory
+from ray.util.queue import Queue as RayQueue

from ludwig.api import LudwigModel
from ludwig.backend import initialize_backend, RAY
@@ -30,32 +38,16 @@
from ludwig.utils.fs_utils import has_remote_protocol
from ludwig.utils.misc_utils import get_from_registry

-logger = logging.getLogger(__name__)
-
-try:
-    import ray
-    from ray import tune
-    from ray.tune import register_trainable, Stopper
-    from ray.tune.schedulers.resource_changing_scheduler import DistributeResources, ResourceChangingScheduler
-    from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter
-
-    _ray_114 = version.parse(ray.__version__) >= version.parse("1.14")
-    if _ray_114:
-        from ray.tune.search import SEARCH_ALG_IMPORT
-        from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig
-    else:
-        from ray.tune.syncer import get_cloud_sync_client
-        from ray.tune.suggest import SEARCH_ALG_IMPORT
+_ray_114 = version.parse(ray.__version__) >= version.parse("1.14")
+if _ray_114:
+    from ray.tune.search import SEARCH_ALG_IMPORT
+    from ray.tune.syncer import get_node_to_storage_syncer, SyncConfig
+else:
+    from ray.tune.suggest import SEARCH_ALG_IMPORT
+    from ray.tune.syncer import get_cloud_sync_client

-    from ray.tune.utils import wait_for_gpu
-    from ray.tune.utils.placement_groups import PlacementGroupFactory
-    from ray.util.queue import Queue as RayQueue
-
-except ImportError as e:
-    logger.warning(f"ImportError (execution.py) failed to import ray with error: \n\t{e}")
-    ray = None
-    Stopper = object
-    get_horovod_kwargs = None
+logger = logging.getLogger(__name__)


try:
5 changes: 4 additions & 1 deletion ludwig/hyperopt/run.py
@@ -12,7 +12,6 @@
from ludwig.constants import COMBINED, EXECUTOR, HYPEROPT, LOSS, MINIMIZE, TEST, TRAINING, TYPE, VALIDATION
from ludwig.data.split import get_splitter
from ludwig.features.feature_registries import output_type_registry
-from ludwig.hyperopt.execution import executor_registry, get_build_hyperopt_executor, RayTuneExecutor
from ludwig.hyperopt.results import HyperoptResults
from ludwig.hyperopt.utils import print_hyperopt_results, save_hyperopt_stats, should_tune_preprocessing
from ludwig.utils.defaults import default_random_seed, merge_with_defaults
@@ -162,6 +161,8 @@ def hyperopt(
    :return: (List[dict]) List of results for each trial, ordered by
        descending performance on the target metric.
    """
+    from ludwig.hyperopt.execution import get_build_hyperopt_executor, RayTuneExecutor
+
    # check if config is a path or a dict
    if isinstance(config, str):  # assume path
        with open_file(config, "r") as def_file:
@@ -359,6 +360,8 @@ def hyperopt(


def update_hyperopt_params_with_defaults(hyperopt_params):
+    from ludwig.hyperopt.execution import executor_registry
+
    set_default_value(hyperopt_params, EXECUTOR, {})
    set_default_value(hyperopt_params, "split", VALIDATION)
    set_default_value(hyperopt_params, "output_feature", COMBINED)
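
The function-local imports above defer loading ludwig.hyperopt.execution, which after this commit imports ray unconditionally at module level. A minimal sketch of the deferred-import pattern, with a hypothetical module name:

    def run_search():
        # A missing optional dependency (e.g. ray) now only fails when this
        # code path runs, not when the enclosing module is imported.
        from heavy_optional_module import expensive_search

        return expensive_search()
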
12 changes: 10 additions & 2 deletions ludwig/utils/dataframe_utils.py
@@ -13,9 +13,17 @@ def is_dask_lib(df_lib) -> bool:
    return df_lib.__name__ == DASK_MODULE_NAME


-def is_dask_backend(backend: "Backend") -> bool:  # noqa: F821
+def is_dask_backend(backend: Optional["Backend"]) -> bool:  # noqa: F821
    """Returns whether the backend's dataframe is dask."""
-    return is_dask_lib(backend.df_engine.df_lib)
+    return backend is not None and is_dask_lib(backend.df_engine.df_lib)
+
+
+def is_dask_df(df: DataFrame, backend: Optional["Backend"]) -> bool:  # noqa: F821
+    if is_dask_backend(backend):
+        import dask.dataframe as dd
+
+        return isinstance(df, dd.DataFrame)
+    return False


def flatten_df(df: DataFrame, backend: "Backend") -> Tuple[DataFrame, Dict[str, Tuple]]:  # noqa: F821
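
A hedged usage sketch of the new helper; the maybe_compute wrapper is hypothetical:

    from ludwig.utils.dataframe_utils import is_dask_df

    def maybe_compute(df, backend=None):
        # Materialize only Dask-backed frames; with backend=None, is_dask_df
        # returns False and pandas frames pass through untouched.
        if is_dask_df(df, backend):
            return df.compute()
        return df
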
1 change: 1 addition & 0 deletions requirements_explain.txt
@@ -0,0 +1 @@
captum
3 changes: 3 additions & 0 deletions setup.py
@@ -30,6 +30,9 @@
with open(path.join(here, "requirements_tree.txt"), encoding="utf-8") as f:
    extra_requirements["tree"] = [line.strip() for line in f if line]

+with open(path.join(here, "requirements_explain.txt"), encoding="utf-8") as f:
+    extra_requirements["explain"] = [line.strip() for line in f if line]
+
extra_requirements["full"] = [item for sublist in extra_requirements.values() for item in sublist]

with open(path.join(here, "requirements_test.txt"), encoding="utf-8") as f:
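
With this change the Captum dependency installs as a setuptools extra. Assuming the published package name, installation would look like:

    pip install ludwig[explain]   # pulls in captum via requirements_explain.txt
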