Add L-BFGS optimizer (#2478)
jppgks authored Sep 13, 2022
1 parent dbdf07a commit 5db9190
Showing 4 changed files with 128 additions and 0 deletions.
40 changes: 40 additions & 0 deletions examples/lbfgs/config.yaml
@@ -0,0 +1,40 @@
input_features:
  - name: RESOURCE
    type: category
  - name: MGR_ID
    type: category
  - name: ROLE_ROLLUP_1
    type: category
  - name: ROLE_ROLLUP_2
    type: category
  - name: ROLE_DEPTNAME
    type: category
  - name: ROLE_TITLE
    type: category
  - name: ROLE_FAMILY_DESC
    type: category
  - name: ROLE_FAMILY
    type: category
  - name: ROLE_CODE
    type: category
output_features:
  - name: ACTION
    type: binary
preprocessing:
  split:
    type: fixed
defaults:
  category:
    encoder:
      type: sparse
trainer:
  batch_size: 32769 # entire training set
  train_steps: 1
  steps_per_checkpoint: 1
  learning_rate: 1
  regularization_lambda: 0.0000057
  optimizer:
    type: lbfgs
    max_iter: 100
    tolerance_grad: 0.0001
    history_size: 10
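Since L-BFGS is a full-batch method, the config pins `batch_size` to the entire training set and runs a single train step, within which the optimizer performs up to `max_iter` inner iterations. As a rough sketch, the `optimizer` block above corresponds to the following raw PyTorch construction (the linear model is a stand-in; hyperparameter values are copied from the config):

import torch

# Stand-in model; hyperparameters mirror the config above.
model = torch.nn.Linear(9, 1)
optimizer = torch.optim.LBFGS(
    model.parameters(),
    lr=1.0,               # trainer.learning_rate
    max_iter=100,         # optimizer.max_iter
    tolerance_grad=1e-4,  # optimizer.tolerance_grad
    history_size=10,      # optimizer.history_size
)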
30 changes: 30 additions & 0 deletions examples/lbfgs/model.py
@@ -0,0 +1,30 @@
import logging

import pandas as pd

from ludwig.api import LudwigModel
from ludwig.datasets import amazon_employee_access_challenge

df = amazon_employee_access_challenge.load()

model = LudwigModel(config="config.yaml", logging_level=logging.INFO)

training_statistics, preprocessed_data, output_directory = model.train(
    df,
    skip_save_processed_input=True,
    skip_save_log=True,
    skip_save_progress=True,
    skip_save_training_description=True,
    skip_save_training_statistics=True,
)

# Predict on the unlabeled test split (split == 2)
model.config["preprocessing"] = {}  # clear the fixed-split preprocessing before predicting
unlabeled_test = df[df.split == 2].reset_index(drop=True)
preds, _ = model.predict(unlabeled_test)

# Save predictions in Kaggle submission format
action = preds.ACTION_probabilities_True
submission = pd.merge(unlabeled_test.id.astype(int), action, left_index=True, right_index=True)
submission.rename(columns={"ACTION_probabilities_True": "Action", "id": "Id"}, inplace=True)
submission.to_csv("submission.csv", index=False)
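Optionally, as a sanity check before writing the submission, one could score a labeled validation fold first; this is a hypothetical addition that assumes the loaded frame marks such a fold with split == 1 (which splits are actually present depends on the dataset loader):

from sklearn.metrics import roc_auc_score

# Hypothetical sanity check: assumes split == 1 marks a labeled validation fold.
val_df = df[df.split == 1].reset_index(drop=True)
val_preds, _ = model.predict(val_df)
print("validation AUC:", roc_auc_score(val_df.ACTION, val_preds.ACTION_probabilities_True))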
35 changes: 35 additions & 0 deletions ludwig/schema/optimizers.py
@@ -13,6 +13,7 @@
    create_cond,
    FloatRange,
    FloatRangeTupleDataclassField,
    Integer,
    NonNegativeFloat,
    StringOptions,
    unload_jsonschema_from_marshmallow_class,
@@ -73,6 +74,40 @@ class SGDOptimizerConfig(BaseOptimizerConfig):
    nesterov: bool = Boolean(default=False, description="Enables Nesterov momentum.")


@register_optimizer(name="lbfgs")
@dataclass
class LBFGSOptimizerConfig(BaseOptimizerConfig):
"""Parameters for stochastic gradient descent."""

optimizer_class: ClassVar[torch.optim.Optimizer] = torch.optim.LBFGS
"""Points to `torch.optim.LBFGS`."""

type: str = StringOptions(["lbfgs"], default="lbfgs", allow_none=False)
"""Must be 'lbfgs' - corresponds to name in `ludwig.modules.optimization_modules.optimizer_registry` (default:
'lbfgs')"""

# Defaults taken from https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html#torch.optim.LBFGS
lr: float = NonNegativeFloat(default=1, description="Learning rate.")
max_iter: int = Integer(default=20, description="Maximum number of iterations per optimization step.")
max_eval: int = Integer(
default=None,
allow_none=True,
description="Maximum number of function evaluations per optimization step. Default: `max_iter` * 1.25.",
)
tolerance_grad: float = NonNegativeFloat(
default=1e-07, description="Termination tolerance on first order optimality."
)
tolerance_change: float = NonNegativeFloat(
default=1e-09, description="Termination tolerance on function value/parameter changes."
)
history_size: int = Integer(default=100, description="Update history size.")
line_search_fn: str = StringOptions(
["strong_wolfe"],
default=None,
description="Line search function to use.",
)


@register_optimizer(name="adam")
@dataclass
class AdamOptimizerConfig(BaseOptimizerConfig):
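As a simplified sketch (not Ludwig's actual factory code), a registered config such as `LBFGSOptimizerConfig` above can be turned into a live optimizer by dropping the registry key `type` and passing the remaining dataclass fields as constructor kwargs:

from dataclasses import asdict

# Simplified sketch, not Ludwig's actual factory. ClassVar attributes such as
# optimizer_class are excluded from asdict() automatically; only the registry
# key `type` needs to be dropped before constructing the torch optimizer.
def build_optimizer(config, parameters):
    kwargs = {k: v for k, v in asdict(config).items() if k != "type"}
    return config.optimizer_class(parameters, **kwargs)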
23 changes: 23 additions & 0 deletions ludwig/trainers/trainer.py
@@ -212,6 +212,29 @@ def train_step(
        Returns:
            A tuple of the loss tensor and a dictionary of loss for every output feature.
        """
        if isinstance(self.optimizer, torch.optim.LBFGS):
            # NOTE: Horovod is not supported for L-BFGS.

            def closure():
                # Allows L-BFGS to reevaluate the loss function
                self.optimizer.zero_grad()
                model_outputs = self.model((inputs, targets))
                loss, all_losses = self.model.train_loss(
                    targets, model_outputs, self.regularization_type, self.regularization_lambda
                )
                loss.backward()
                return loss

            self.optimizer.step(closure)

            # Obtain model predictions and loss
            model_outputs = self.model((inputs, targets))
            loss, all_losses = self.model.train_loss(
                targets, model_outputs, self.regularization_type, self.regularization_lambda
            )

            return loss, all_losses

        self.optimizer.zero_grad()

        # Obtain model predictions and loss
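The closure above is the standard calling convention for `torch.optim.LBFGS`: `step` may evaluate the loss several times, so gradients must be zeroed and recomputed on each call. The same pattern outside Ludwig, with a toy model and data:

import torch

# Toy model and data, purely illustrative.
model = torch.nn.Linear(2, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=1.0, max_iter=20)

x = torch.randn(8, 2)
y = torch.randn(8, 1)

def closure():
    # Called repeatedly by LBFGS; zero and recompute gradients each time.
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss

initial_loss = optimizer.step(closure)  # runs up to max_iter inner iterations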
