Skip to content

Commit

Permalink
enh: Implements InferenceModule as a pipelined module with separate…
Browse files Browse the repository at this point in the history
… preprocessor, predictor, and postprocessor modules (#2105)

* Adding inference pipeline with seperate pre-processing, predict and post-processing modules

* Update to flatten outputs from predict consistent to support triton

* inference module refactor

* add back InferenceLudwigModel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unify modules into inference.py

* cleaned up inaccurate documentation

* clean up

* clean up type hints and update InferenceLudwigModel

* clean up type hint; passes test_torchscript.py

* added typing to inference module for clarity

* remove inference_module_file_name constant

* unified predict module with postproc

* removed InferencePredictor entirely

* add back the old inference module

* add back training set metadata

* revert change to predict module, move feature filtering to postproc

* cleanup inference_module_v0

* cleanup

* adds device placement to InferenceLudwigModel

* adds ability to save/load torchscript on particular devices

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allows saving torchscript with dict of devices from api.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* correct device inputs

* refactor to expose inference stages (prep for triton refactor)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove magic 'cpu' string

* remove extraneous constants

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add from_directory classmethod for e2e users

* merge

* merge InferenceModule and InferenceLudwigModel

* add comment

* revert small change

* cleanup

* add to_torchscript functionality

* cleanup

* pushes device logic down into inference stages

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move device placement upstream to inference module to ensure stage modules are performant

* adds logs for device placement experiments

* removes logs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove stage_to_dict

* clean up how we get input device in predictor_forward

* first commit

* wip

* updated interfaces

* postproc GPU

* add intelligent device placement

* clean up device api

* revert flatten op in inference_module_v0

* remove dtype workaround

* benchmarking code

* add DEVICE constant as good default for loading/saving

* added helpful logging and style

* cleanup

* cleanup, adding docstrings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docstring

Co-authored-by: Geoffrey Angus <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 27, 2022
1 parent a587181 commit c26e81a
Show file tree
Hide file tree
Showing 19 changed files with 588 additions and 165 deletions.
68 changes: 45 additions & 23 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from ludwig.data.preprocessing import load_metadata, preprocess_for_prediction, preprocess_for_training
from ludwig.features.feature_registries import update_config_with_metadata
from ludwig.globals import (
INFERENCE_MODULE_FILE_NAME,
LUDWIG_VERSION,
MODEL_HYPERPARAMETERS_FILE_NAME,
MODEL_WEIGHTS_FILE_NAME,
Expand All @@ -65,7 +64,7 @@
)
from ludwig.models.calibrator import Calibrator
from ludwig.models.ecd import ECD
from ludwig.models.inference import InferenceModule
from ludwig.models.inference import InferenceModule, save_ludwig_model_for_inference
from ludwig.models.predictor import (
calculate_overall_stats,
print_evaluation_stats,
Expand All @@ -89,7 +88,8 @@
from ludwig.utils.fs_utils import makedirs, open_file, path_exists, upload_output_directory
from ludwig.utils.misc_utils import get_file_names, get_output_directory, set_saved_weights_in_checkpoint_flag
from ludwig.utils.print_utils import print_boxed
from ludwig.utils.torch_utils import get_torch_device
from ludwig.utils.torch_utils import DEVICE, get_torch_device
from ludwig.utils.types import TorchDevice

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1471,35 +1471,57 @@ def save_config(self, save_path: str) -> None:
model_hyperparameters_path = os.path.join(save_path, MODEL_HYPERPARAMETERS_FILE_NAME)
save_json(model_hyperparameters_path, self.config)

def to_torchscript(self, model_only: bool = False):
"""Converts the trained LudwigModule, including preprocessing and postprocessing, to Torchscript.
The scripted module takes in a `Dict[str, Union[List[str], Tensor]]` as input.
More specifically, for every input feature, we provide either a Tensor of batch_size inputs, a list of Tensors
batch_size in length, or a list of strings batch_size in length.
Note that the dimensions of all Tensors and lengths of all lists must match.
Similarly, the output will be a dictionary of dictionaries, where each feature has its own dictionary of
outputs. The outputs will be a list of strings for predictions with string types, while other outputs will be
tensors of varying dimensions for probabilities, logits, etc.
def to_torchscript(
self,
model_only: bool = False,
device: Optional[TorchDevice] = None,
):
"""Converts the trained model to Torchscript.
Args:
model_only (bool, optional): If True, only the ECD model will be converted to Torchscript. Else,
preprocessing and postprocessing will also be converted to Torchscript.
preprocessing and postprocessing steps will also be converted to Torchscript.
device (TorchDevice, optional): If None, the model will be converted to Torchscript on the same device to
ensure maximum model parity.
Returns:
A torch.jit.ScriptModule that can be used to predict on a dictionary of inputs.
"""
if device is None:
device = DEVICE

self._check_initialization()
if model_only:
return self.model.to_torchscript()
return self.model.to_torchscript(device)
else:
inference_module = InferenceModule(self.model, self.config, self.training_set_metadata)
inference_module = InferenceModule.from_ludwig_model(
self.model, self.config, self.training_set_metadata, device=device
)
return torch.jit.script(inference_module)

def save_torchscript(self, save_path: str, model_only: bool = False):
"""Saves the Torchscript model to disk."""
inference_module = self.to_torchscript(model_only=model_only)
inference_module.save(os.path.join(save_path, INFERENCE_MODULE_FILE_NAME))
def save_torchscript(
self,
save_path: str,
model_only: bool = False,
device: Optional[TorchDevice] = None,
):
"""Saves the Torchscript model to disk.
save_path (str): The path to the directory where the model will be saved. model_only (bool, optional): If True,
only the ECD model will be converted to Torchscript. Else, the preprocessing and postprocessing steps will
also be converted to Torchscript. device (TorchDevice, optional): If None, the model will be converted to
Torchscript on the same device to ensure maximum model parity.
"""
if device is None:
device = DEVICE

save_ludwig_model_for_inference(
save_path,
self.model,
self.config,
self.training_set_metadata,
model_only=model_only,
device=device,
)

def _check_initialization(self):
if self.model is None or self.config is None or self.training_set_metadata is None:
Expand Down
4 changes: 3 additions & 1 deletion ludwig/features/audio_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ def _transform_to_feature(

feature_length = audio_feature.shape[0]
broadcast_feature_length = min(feature_length, max_length)
audio_feature_padded = torch.full((max_length, feature_dim), padding_value, dtype=torch.float32)
audio_feature_padded = torch.full(
(max_length, feature_dim), padding_value, dtype=torch.float32, device=audio_feature.device
)
audio_feature_padded[:broadcast_feature_length, :] = audio_feature[:max_length, :]

return audio_feature_padded
Expand Down
11 changes: 6 additions & 5 deletions ludwig/features/binary_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,19 @@ def __init__(self, metadata: Dict[str, Any]):
self.predictions_key = PREDICTIONS
self.probabilities_key = PROBABILITIES

def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
predictions = preds[self.predictions_key]
def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, Any]:
predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)
probabilities = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.probabilities_key)

if self.bool2str is not None:
predictions = predictions.to(dtype=torch.int32)
predictions = [self.bool2str.get(pred, self.bool2str[0]) for pred in predictions]

probs = preds[self.probabilities_key]
probs = torch.stack([1 - probs, probs], dim=-1)
probabilities = torch.stack([1 - probabilities, probabilities], dim=-1)

return {
self.predictions_key: predictions,
self.probabilities_key: probs,
self.probabilities_key: probabilities,
}


Expand Down
11 changes: 7 additions & 4 deletions ludwig/features/category_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,19 @@ class _CategoryPostprocessing(torch.nn.Module):
def __init__(self, metadata: Dict[str, Any]):
super().__init__()
self.idx2str = {i: v for i, v in enumerate(metadata["idx2str"])}
self.unk = UNKNOWN_SYMBOL
self.predictions_key = PREDICTIONS
self.probabilities_key = PROBABILITIES
self.unk = ""

def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
predictions = preds[self.predictions_key]
def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, Any]:
predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)
probabilities = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.probabilities_key)

inv_preds = [self.idx2str.get(pred, self.unk) for pred in predictions]

return {
self.predictions_key: inv_preds,
self.probabilities_key: preds[self.probabilities_key],
self.probabilities_key: probabilities,
}


Expand Down
6 changes: 3 additions & 3 deletions ludwig/features/h3_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def forward(self, v: TorchscriptPreprocessingInput) -> torch.Tensor:
components.base_cell,
]
cells_padding: List[int] = [self.h3_padding_value] * (self.max_h3_resolution - len(components.cells))
output = torch.tensor(header + components.cells + cells_padding, dtype=torch.uint8)
output = torch.tensor(header + components.cells + cells_padding, dtype=torch.uint8, device=v.device)
outputs.append(output)

return torch.stack(outputs)
Expand Down Expand Up @@ -111,8 +111,8 @@ def add_feature_data(
):
column = input_df[feature_config[COLUMN]]
if column.dtype == object:
column = column.map(int)
column = column.map(H3FeatureMixin.h3_to_list)
column = backend.df_engine.map_objects(column, int)
column = backend.df_engine.map_objects(column, H3FeatureMixin.h3_to_list)

proc_df[feature_config[PROC_COLUMN]] = backend.df_engine.map_objects(
column, lambda x: np.array(x, dtype=np.uint8)
Expand Down
6 changes: 4 additions & 2 deletions ludwig/features/number_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,10 @@ def __init__(self, metadata: Dict[str, Any]):
self.numeric_transformer = get_transformer(metadata, metadata["preprocessing"])
self.predictions_key = PREDICTIONS

def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
return {self.predictions_key: self.numeric_transformer.inverse_transform_inference(preds[self.predictions_key])}
def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, Any]:
predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)

return {self.predictions_key: self.numeric_transformer.inverse_transform_inference(predictions)}


class _NumberPredict(PredictModule):
Expand Down
8 changes: 4 additions & 4 deletions ludwig/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ def __init__(self, metadata: Dict[str, Any]):
self.probabilities_key = PROBABILITIES
self.probability_key = PROBABILITY

def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
"""Takes a dictionary of tensors and returns a dictionary of tensors."""
pred_predictions = preds[self.predictions_key]
def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, Any]:
pred_predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)
pred_probabilities = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.probabilities_key)

predictions: List[List[str]] = []
for sequence in pred_predictions:
sequence_predictions: List[str] = []
Expand All @@ -151,7 +152,6 @@ def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
sequence_predictions.append(unit_prediction)
predictions.append(sequence_predictions)

pred_probabilities = preds[self.probabilities_key]
probabilities, _ = torch.max(pred_probabilities, dim=-1)
probability = torch.sum(torch.log(probabilities), dim=-1)

Expand Down
6 changes: 3 additions & 3 deletions ludwig/features/set_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def __init__(self, metadata: Dict[str, Any]):
self.probabilities_key = PROBABILITIES
self.unk = UNKNOWN_SYMBOL

def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
predictions = preds[self.predictions_key]
probabilities = preds[self.probabilities_key]
def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, Any]:
predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)
probabilities = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.probabilities_key)

inv_preds: List[List[str]] = []
filtered_probs: List[torch.Tensor] = []
Expand Down
17 changes: 11 additions & 6 deletions ludwig/features/vector_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,22 @@ def forward(self, v: TorchscriptPreprocessingInput) -> torch.Tensor:


class _VectorPostprocessing(torch.nn.Module):
def forward(self, preds: Dict[str, torch.Tensor]) -> Dict[str, Any]:
# Workaround to convert type annotation from Dict[str, torch.Tensor] to Dict[str, Any]
preds_any: Dict[str, Any] = {}
for k, v in preds.items():
preds_any[k] = v
return preds_any
def __init__(self):
super().__init__()
self.predictions_key = PREDICTIONS
self.logits_key = LOGITS

def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, Any]:
predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)
logits = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.logits_key)

return {self.predictions_key: predictions, self.logits_key: logits}


class _VectorPredict(PredictModule):
def forward(self, inputs: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, torch.Tensor]:
logits = output_feature_utils.get_output_feature_tensor(inputs, feature_name, self.logits_key)

return {self.predictions_key: logits, self.logits_key: logits}


Expand Down
1 change: 0 additions & 1 deletion ludwig/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

LUDWIG_VERSION = "0.5.3"

INFERENCE_MODULE_FILE_NAME = "inference_module"
MODEL_WEIGHTS_FILE_NAME = "model_weights"
MODEL_HYPERPARAMETERS_FILE_NAME = "model_hyperparameters.json"
TRAIN_SET_METADATA_FILE_NAME = "training_set_metadata.json"
Expand Down
22 changes: 17 additions & 5 deletions ludwig/models/ecd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from ludwig.utils.data_utils import clear_data_cache
from ludwig.utils.metric_utils import get_scalar_from_ludwig_metric
from ludwig.utils.misc_utils import get_from_registry
from ludwig.utils.torch_utils import LudwigModule, reg_loss
from ludwig.utils.torch_utils import DEVICE, LudwigModule, reg_loss
from ludwig.utils.types import TorchDevice

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,14 +89,25 @@ def get_model_size(self) -> int:
total_size += tnsr[1].detach().cpu().numpy().size
return total_size

def to_torchscript(self):
def to_torchscript(self, device: Optional[TorchDevice] = None):
"""Converts the ECD model as a TorchScript model."""
if device is None:
device = DEVICE

self.eval()
model_inputs = self.get_model_inputs()

model_to_script = self.to(device)
model_inputs_to_script = {k: v.to(device) for k, v in model_inputs.items()}
# We set strict=False to enable dict inputs and outputs.
return torch.jit.trace(self, model_inputs, strict=False)
return torch.jit.trace(model_to_script, model_inputs_to_script, strict=False)

def save_torchscript(self, save_path, device: Optional[TorchDevice] = None):
"""Saves the ECD model as a TorchScript model."""
if device is None:
device = DEVICE

def save_torchscript(self, save_path):
traced = self.to_torchscript()
traced = self.to_torchscript(device)
traced.save(save_path)

@property
Expand Down
Loading

0 comments on commit c26e81a

Please sign in to comment.