Skip to content

Commit

Permalink
[minor] Add static plotly backend (#1286)
Browse files Browse the repository at this point in the history
* [minor] Add static plotly backend

* [minor] Remove unnecessary arguments to ensure API consistency and set default static display type to svg

* Stop plots from showing twice

* Add tests for plotly-resampler and plotly-static

* Update requirements to include kaleido

* Remove requirements

* Add nbformat via poetry

* add plotly-static to self.set_plotting_backend()

* update self.set_plotting_backend()

---------

Co-authored-by: leoniewgnr <[email protected]>
Co-authored-by: LeonieFreisinger <[email protected]>
  • Loading branch information
3 people authored Apr 27, 2023
1 parent 9f08ff6 commit bc6891f
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 22 deletions.
52 changes: 38 additions & 14 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,14 +1797,15 @@ def set_plotting_backend(self, plotting_backend: str):
plotly-resampler package to accelerate visualizing large data by resampling it. Only supported for
jupyterlab notebooks and vscode notebooks.
* ``plotly``: Use the plotly backend for plotting
* ``plotly-static``: Use the plotly backend to generate static svg
* ``matplotlib``: use matplotlib for plotting
"""
if plotting_backend in ["plotly", "matplotlib", "plotly-resampler"]:
if plotting_backend in ["plotly", "matplotlib", "plotly-resampler", "plotly-static"]:
self.plotting_backend = plotting_backend
log_warning_deprecation_plotly(self.plotting_backend)
else:
raise ValueError(
"The parameter `plotting_backend` must be either 'plotly', 'plotly-resampler' or 'matplotlib'."
"The parameter `plotting_backend` must be either 'plotly', 'plotly-resampler', 'plotly-resampler' or 'matplotlib'."
)

def highlight_nth_step_ahead_of_each_forecast(self, step_number: Optional[int] = None):
Expand Down Expand Up @@ -1860,6 +1861,7 @@ def plot(
environments (colab, pycharm interpreter) plotly-resampler might not properly vizualise the figures.
In this case, consider switching to 'plotly-auto'.
* ``plotly``: Use the plotly backend for plotting
* ``plotly-static``: Use the plotly backend to generate static svg
* ``matplotlib``: use matplotlib for plotting
* (default) None: Plotting backend ist set automatically. Use plotly with resampling for jupyterlab
notebooks and vscode notebooks. Automatically switch to plotly without resampling for all other
Expand Down Expand Up @@ -1929,6 +1931,7 @@ def plot(
figsize=tuple(x * 70 for x in figsize),
highlight_forecast=forecast_in_focus,
resampler_active=plotting_backend == "plotly-resampler",
plotly_static=plotting_backend == "plotly-static",
)
else:
return plot(
Expand Down Expand Up @@ -2047,10 +2050,12 @@ def plot_latest_forecast(
environments (colab, pycharm interpreter) plotly-resampler might not properly vizualise the figures.
In this case, consider switching to 'plotly-auto'.
* ``plotly``: Use the plotly backend for plotting
* ``plotly-static``: Use the plotly backend to generate static svg
* ``matplotlib``: use matplotlib for plotting
** (default) None: Plotting backend ist set automatically. Use plotly with resampling for jupyterlab
notebooks and vscode notebooks. Automatically switch to plotly without resampling for all other
environments.
* (default) None
Returns
-------
matplotlib.axes.Axes
Expand Down Expand Up @@ -2097,6 +2102,7 @@ def plot_latest_forecast(
highlight_forecast=self.highlight_forecast_step_n,
line_per_origin=True,
resampler_active=plotting_backend == "plotly-resampler",
plotly_static=plotting_backend == "plotly-static",
)
else:
return plot(
Expand Down Expand Up @@ -2169,6 +2175,7 @@ def plot_components(
environments (colab, pycharm interpreter) plotly-resampler might not properly vizualise the figures.
In this case, consider switching to 'plotly-auto'.
* ``plotly``: Use the plotly backend for plotting
* ``plotly-static``: Use the plotly backend to generate static svg
* ``matplotlib``: use matplotlib for plotting
* (default) None: Plotting backend ist set automatically. Use plotly with resampling for jupyterlab
notebooks and vscode notebooks. Automatically switch to plotly without resampling for all other
Expand Down Expand Up @@ -2260,6 +2267,7 @@ def plot_components(
df_name=df_name,
one_period_per_season=one_period_per_season,
resampler_active=plotting_backend == "plotly-resampler",
plotly_static=plotting_backend == "plotly-static",
)
else:
return plot_components(
Expand Down Expand Up @@ -2323,6 +2331,7 @@ def plot_parameters(
environments (colab, pycharm interpreter) plotly-resampler might not properly vizualise the figures.
In this case, consider switching to 'plotly-auto'.
* ``plotly``: Use the plotly backend for plotting
* ``plotly-static``: Use the plotly backend to generate static svg
* ``matplotlib``: use matplotlib for plotting
* (default) None: Plotting backend ist set automatically. Use plotly with resampling for jupyterlab
notebooks and vscode notebooks. Automatically switch to plotly without resampling for all other
Expand All @@ -2331,7 +2340,6 @@ def plot_parameters(
Note
----
For multiple time series and local modeling of at least one component, the df_name parameter is required.
quantile : float
The quantile for which the model parameters are to be plotted
Expand Down Expand Up @@ -2405,17 +2413,33 @@ def plot_parameters(

log_warning_deprecation_plotly(plotting_backend)
if plotting_backend.startswith("plotly"):
return plot_parameters_plotly(
m=self,
quantile=quantile,
weekly_start=weekly_start,
yearly_start=yearly_start,
figsize=tuple(x * 70 for x in figsize) if figsize else (700, 210),
df_name=valid_plot_configuration["df_name"],
plot_configuration=valid_plot_configuration,
forecast_in_focus=forecast_in_focus,
resampler_active=plotting_backend == "plotly-resampler",
)
if plotting_backend == "plotly-static":
fig = plot_parameters_plotly(
m=self,
quantile=quantile,
weekly_start=weekly_start,
yearly_start=yearly_start,
figsize=tuple(x * 70 for x in figsize) if figsize else (700, 210),
df_name=valid_plot_configuration["df_name"],
plot_configuration=valid_plot_configuration,
forecast_in_focus=forecast_in_focus,
resampler_active=plotting_backend == "plotly-resampler",
plotly_static=plotting_backend == "plotly-static",
)
fig.show("svg")
else:
return plot_parameters_plotly(
m=self,
quantile=quantile,
weekly_start=weekly_start,
yearly_start=yearly_start,
figsize=tuple(x * 70 for x in figsize) if figsize else (700, 210),
df_name=valid_plot_configuration["df_name"],
plot_configuration=valid_plot_configuration,
forecast_in_focus=forecast_in_focus,
resampler_active=plotting_backend == "plotly-resampler",
plotly_static=plotting_backend == "plotly-static",
)
else:
return plot_parameters(
m=self,
Expand Down
10 changes: 10 additions & 0 deletions neuralprophet/plot_forecast_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def plot(
line_per_origin=False,
figsize=(700, 210),
resampler_active=False,
plotly_static=False,
):
"""
Plot the NeuralProphet forecast
Expand All @@ -73,6 +74,8 @@ def plot(
Width, height in inches.
resampler_active : bool
Flag whether to activate the plotly-resampler
plotly_static: bool
Flag whether to generate a static svg image
Returns
-------
Expand Down Expand Up @@ -227,6 +230,8 @@ def plot(
)
fig = go.Figure(data=data, layout=layout)
unregister_plotly_resampler()
if plotly_static:
fig = fig.show("svg")
return fig


Expand All @@ -238,6 +243,7 @@ def plot_components(
one_period_per_season=False,
figsize=(700, 210),
resampler_active=False,
plotly_static=False,
):
"""
Plot the NeuralProphet forecast components.
Expand All @@ -258,6 +264,8 @@ def plot_components(
Width, height in inches.
resampler_active : bool
Flag whether to activate the plotly-resampler
plotly_static: bool
Flag whether to generate a static svg image
Returns
-------
Expand Down Expand Up @@ -339,6 +347,8 @@ def plot_components(
for ax in multiplicative_axes:
ax = set_y_as_percent(ax)
unregister_plotly_resampler()
if plotly_static:
fig = fig.show("svg")
return fig


Expand Down
3 changes: 3 additions & 0 deletions neuralprophet/plot_model_parameters_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,7 @@ def plot_parameters(
df_name=None,
forecast_in_focus=None,
resampler_active=False,
plotly_static=False,
):
"""Plot the parameters that the model is composed of, visually.
Expand Down Expand Up @@ -860,6 +861,8 @@ def plot_parameters(
None (default): plot self.highlight_forecast_step_n by default
resampler_active : bool
Flag whether to activate the plotly-resampler
plotly_static: bool
Flag whether to generate a static svg image
Returns:
Plotly figure
Expand Down
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ matplotlib = "^3.5.3"
numpy = ">=1.22.0,<1.24.0"
pandas = "^1.3.5"
plotly = "^5.13.1"
kaleido = "^0.2.1"
plotly-resampler = "^0.8.3.1"
pytorch-lightning = "^1.9.4"
tensorboard = "^2.11.2"
torch = "^1.13.1"
torchmetrics = "^0.11.3"
typing-extensions = "^4.5.0"
nbformat = ">=4.2.0"

[tool.poetry.group.dev.dependencies]
black = { extras = ["jupyter"], version = "^23.1.0" }
Expand Down
3 changes: 2 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@

# plot tests cover both plotting backends
decorator_input = ["plotting_backend", [("matplotlib"), ("plotly")]]
decorator_input_extended = ["plotting_backend", [("matplotlib"), ("plotly"), ("plotly-static"), ("plotly-resampler")]]


@pytest.mark.parametrize(*decorator_input)
@pytest.mark.parametrize(*decorator_input_extended)
def test_plot(plotting_backend):
log.info(f"testing: Basic plotting with forecast in focus with {plotting_backend}")
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
Expand Down

0 comments on commit bc6891f

Please sign in to comment.