When plot_prediction() is called in PyTorch Forecasting on a model with multiple targets, the function reuses the same axes for every target. As a result, the plots for different targets are drawn on top of each other instead of appearing as separate plots.
To reproduce:
1. Train a PyTorch Forecasting model with multiple targets.
2. Call plot_prediction() with an output dictionary that includes multiple targets.
Expected behavior:
Each target should be plotted on a separate figure.
Actual behavior:
All targets are plotted on the same figure, so their plots overlap.
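The overlap does not depend on the model itself and can be reproduced with matplotlib alone. The sketch below is a hypothetical simplification, not PyTorch Forecasting code: plot_targets_buggy and the toy series are illustrative assumptions, and the loop keeps only the figure-creation logic in question (a new figure is created only while ax is None).

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_targets_buggy(targets, ax=None):
    """Hypothetical simplification of the per-target loop: one line per target."""
    figs = []
    for series in targets:
        # faulty pattern: ax is None only on the first iteration,
        # so every later target reuses the first target's axes
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()
        ax.plot(series)
        figs.append(fig)
    return figs


targets = [[1, 2, 3], [3, 2, 1], [2, 2, 2]]  # three toy target series
figs = plot_targets_buggy(targets)
print(len({id(f) for f in figs}))  # distinct figures created → 1
print(len(figs[0].axes[0].lines))  # lines drawn on that single axes → 3
```

Three targets end up as three overlapping lines on one axes, which matches the actual behavior described above.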
Solution:
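In the current implementation, a new figure is created only while ax is None. The first iteration of the loop over targets assigns ax, so every later target finds ax already set and is drawn onto the same axes. A possible fix is to record whether the caller supplied ax before entering the loop, and to create a new figure for each target unless they did: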
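```python
from typing import Any, Dict, Union

import matplotlib.pyplot as plt
import torch
from pytorch_forecasting.metrics import Metric


# excerpt of the plot_prediction method; unchanged parts are elided with "# ..."
def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
    # ...
    # for each target, plot
    figs = []
    # remember whether the caller passed axes, before the loop reassigns ax
    ax_provided = ax is not None
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...
        # create figure: a fresh one per target unless the caller supplied ax
        if (ax is None) or (not ax_provided):
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()
```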
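The effect of the fix can be checked with the same kind of matplotlib-only sketch. Again, plot_targets_fixed and the toy data are hypothetical stand-ins, not library code. The condition is written as "if not ax_provided", which should be equivalent to the condition in the fix: when the caller supplies ax, the else branch never reassigns it, so ax cannot be None at that point.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_targets_fixed(targets, ax=None):
    """Hypothetical simplification of the per-target loop with the proposed fix."""
    figs = []
    # record the caller's choice before the loop reassigns ax
    ax_provided = ax is not None
    for series in targets:
        if not ax_provided:
            # no axes supplied: give every target its own figure
            fig, ax = plt.subplots()
        else:
            # caller supplied axes: draw every target onto them
            fig = ax.get_figure()
        ax.plot(series)
        figs.append(fig)
    return figs


targets = [[1, 2, 3], [3, 2, 1], [2, 2, 2]]  # three toy target series

# default call: one figure per target, one line per figure
figs = plot_targets_fixed(targets)
print(len({id(f) for f in figs}))  # → 3
print([len(f.axes[0].lines) for f in figs])  # → [1, 1, 1]

# explicitly passed axes are still honored
_, shared_ax = plt.subplots()
shared = plot_targets_fixed(targets, ax=shared_ax)
print(len(shared_ax.lines))  # → 3
```

With this design, passing ax still draws all targets onto that single axes, so callers who want one panel per target should leave ax unset.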
The faulty code is in the base model's plot_prediction() function.