diff --git a/inference/core/workflows/core_steps/visualizations/common/base_colorable.py b/inference/core/workflows/core_steps/visualizations/common/base_colorable.py index 2fe26c356b..d37012d536 100644 --- a/inference/core/workflows/core_steps/visualizations/common/base_colorable.py +++ b/inference/core/workflows/core_steps/visualizations/common/base_colorable.py @@ -18,6 +18,24 @@ ) from inference.core.workflows.prototypes.block import BlockResult +_MPL_PALETTES_R = { + "Greys_R", + "Purples_R", + "Blues_R", + "Greens_R", + "Oranges_R", + "Reds_R", + "Wistia", + "Pastel1", + "Pastel2", + "Paired", + "Accent", + "Dark2", + "Set1", + "Set2", + "Set3", +} + class ColorableVisualizationManifest(PredictionsVisualizationManifest, ABC): color_palette: Union[ @@ -124,23 +142,7 @@ def getPalette(self, color_palette, palette_size, custom_colors): else: palette_name = color_palette.replace("Matplotlib ", "") - if palette_name in [ - "Greys_R", - "Purples_R", - "Blues_R", - "Greens_R", - "Oranges_R", - "Reds_R", - "Wistia", - "Pastel1", - "Pastel2", - "Paired", - "Accent", - "Dark2", - "Set1", - "Set2", - "Set3", - ]: + if palette_name in _MPL_PALETTES_R: palette_name = palette_name.capitalize() else: palette_name = palette_name.lower() diff --git a/inference/core/workflows/core_steps/visualizations/halo/v1.py b/inference/core/workflows/core_steps/visualizations/halo/v1.py index fee318d842..71cb7b3648 100644 --- a/inference/core/workflows/core_steps/visualizations/halo/v1.py +++ b/inference/core/workflows/core_steps/visualizations/halo/v1.py @@ -103,17 +103,12 @@ def getAnnotator( opacity: float, kernel_size: int, ) -> sv.annotators.base.BaseAnnotator: - key = "_".join( - map( - str, - [ - color_palette, - palette_size, - color_axis, - opacity, - kernel_size, - ], - ) + key = ( + color_palette, + palette_size, + color_axis, + opacity, + kernel_size, ) if key not in self.annotatorCache: