Skip to content

Commit

Permalink
ensure that there are enough colors to match the score index in visua… (
Browse files Browse the repository at this point in the history
  • Loading branch information
thelinuxkid authored Aug 31, 2023
1 parent b802cdd commit 63f4924
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions ludwig/utils/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def compare_classifiers_plot(
width = 0.8 / num_metrics if num_metrics > 1 else 0.4
ticks = np.arange(len(scores[0]))

colors = plt.get_cmap("tab10").colors
if num_metrics <= 10:
colors = plt.get_cmap("tab10").colors
else:
colors = plt.get_cmap("tab20").colors
if adaptive:
maximum = max(max(score) for score in scores)
else:
Expand Down Expand Up @@ -211,8 +214,14 @@ def compare_classifiers_line_plot(
filename=None,
callbacks=None,
):
assert len(scores) > 0

sns.set_style("whitegrid")
colors = plt.get_cmap("tab10").colors

if len(scores) <= 10:
colors = plt.get_cmap("tab10").colors
else:
colors = plt.get_cmap("tab20").colors

fig, ax = plt.subplots()

Expand Down Expand Up @@ -267,7 +276,10 @@ def compare_classifiers_multiclass_multimetric_plot(
width = 0.9 / len(scores)
ticks = np.arange(len(scores[0]))

colors = plt.get_cmap("tab10").colors
if len(scores) <= 10:
colors = plt.get_cmap("tab10").colors
else:
colors = plt.get_cmap("tab20").colors
ax.set_xlabel("class")
ax.set_xticks(ticks + width)
if labels is not None:
Expand Down

0 comments on commit 63f4924

Please sign in to comment.