Multiple Target Prediction Plotting Bug #1314

Open
terbed opened this issue May 26, 2023 · 0 comments

terbed commented May 26, 2023

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.0
  • Python version: 3.9
  • Operating System: Ubuntu

Description:

When `plot_prediction()` in PyTorch Forecasting is called for a model with multiple targets, the function reuses the same axes for every target. As a result, the plots for different targets overlap instead of appearing in separate figures.

To reproduce:

  • Train a PyTorch Forecasting model with multiple targets.
  • Call `plot_prediction()` with an output dictionary that contains predictions for multiple targets.

Faulty code:

The bug is in the `plot_prediction()` method of the base model:

```python
def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:

    # ...

    # for each target, plot
    figs = []
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...

        # create figure
        # BUG: plt.subplots() rebinds ax on the first iteration, so ax is no
        # longer None for later targets and they are all drawn on the same axes.
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()

        # ...

        figs.append(fig)

    return figs
```

Expected behavior:

Each target should be plotted on a separate figure.

Actual behavior:

All targets are drawn on the same axes of a single figure, so their plots overlap, and the returned list contains that same figure once per target.
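The mechanism can be reproduced without PyTorch Forecasting or matplotlib. The sketch below is a minimal model of the create-figure logic in the loop above, assuming hypothetical stand-ins `StubFigure`, `StubAxes` and `stub_subplots()` for matplotlib's `Figure`, `Axes` and `plt.subplots()`; none of these are library APIs.

```python
from typing import List, Optional, Tuple


class StubFigure:
    """Hypothetical stand-in for matplotlib's Figure; records its axes."""

    def __init__(self) -> None:
        self.axes: List["StubAxes"] = []


class StubAxes:
    """Hypothetical stand-in for matplotlib's Axes; records what was drawn."""

    def __init__(self, figure: StubFigure) -> None:
        self._figure = figure
        self.lines: List[str] = []  # labels of the targets drawn on these axes
        figure.axes.append(self)

    def get_figure(self) -> StubFigure:
        return self._figure

    def plot(self, label: str) -> None:
        self.lines.append(label)


def stub_subplots() -> Tuple[StubFigure, StubAxes]:
    """Hypothetical stand-in for plt.subplots(): fresh figure and axes every call."""
    fig = StubFigure()
    return fig, StubAxes(fig)


def plot_targets_buggy(
    targets: List[str], ax: Optional[StubAxes] = None
) -> List[StubFigure]:
    """Same create-figure logic as the loop in plot_prediction()."""
    figs = []
    for target in targets:
        if ax is None:
            # rebinds ax, so this branch is skipped for every later target
            fig, ax = stub_subplots()
        else:
            fig = ax.get_figure()
        ax.plot(target)
        figs.append(fig)
    return figs


figs = plot_targets_buggy(["target_a", "target_b"])
print(len(figs), figs[0] is figs[1])  # 2 True
print(figs[0].axes[0].lines)  # ['target_a', 'target_b']
```

Both targets end up on one axes, and the returned list holds the same figure twice, which matches the behavior described above.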

Solution:

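In the snippet above, a new `ax` should be created for each target. However, `plt.subplots()` assigns to `ax` on the first iteration, so `ax` is no longer `None` afterwards and every subsequent target is drawn on the same axes. A possible fix is to record whether the caller passed `ax` before the loop and, if not, create a new figure for every target: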

```python
def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:

    # ...
    # for each target, plot
    figs = []
    # remember whether the caller supplied axes before the loop rebinds ax
    ax_provided = ax is not None
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):

        # ...
        # create figure: a new one per target unless the caller passed ax
        # (equivalent to `(ax is None) or (not ax_provided)`)
        if not ax_provided:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()
```
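A self-contained sketch of the proposed fix, again using hypothetical stand-ins (`StubFigure`, `StubAxes`, `stub_subplots()`) in place of matplotlib objects; they are assumptions for illustration, not PyTorch Forecasting or matplotlib APIs.

```python
from typing import List, Optional, Tuple


class StubFigure:
    """Hypothetical stand-in for matplotlib's Figure; records its axes."""

    def __init__(self) -> None:
        self.axes: List["StubAxes"] = []


class StubAxes:
    """Hypothetical stand-in for matplotlib's Axes; records what was drawn."""

    def __init__(self, figure: StubFigure) -> None:
        self._figure = figure
        self.lines: List[str] = []
        figure.axes.append(self)

    def get_figure(self) -> StubFigure:
        return self._figure

    def plot(self, label: str) -> None:
        self.lines.append(label)


def stub_subplots() -> Tuple[StubFigure, StubAxes]:
    """Hypothetical stand-in for plt.subplots()."""
    fig = StubFigure()
    return fig, StubAxes(fig)


def plot_targets_fixed(
    targets: List[str], ax: Optional[StubAxes] = None
) -> List[StubFigure]:
    figs = []
    ax_provided = ax is not None  # captured before the loop can rebind ax
    for target in targets:
        if not ax_provided:
            fig, ax = stub_subplots()  # fresh figure for each target
        else:
            fig = ax.get_figure()  # caller-supplied axes are reused
        ax.plot(target)
        figs.append(fig)
    return figs


separate = plot_targets_fixed(["target_a", "target_b"])
print(separate[0] is separate[1])  # False
print([f.axes[0].lines for f in separate])  # [['target_a'], ['target_b']]

_, user_ax = stub_subplots()
shared = plot_targets_fixed(["target_a", "target_b"], ax=user_ax)
print(user_ax.lines)  # ['target_a', 'target_b']
```

The fix only changes the default path. When a caller passes `ax`, every target is still drawn on that single axes, as the last three lines show.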