From 8db9b9792c9cb49c55a642c2595303900e4af9ee Mon Sep 17 00:00:00 2001 From: rengel Date: Tue, 4 Jan 2022 16:06:59 +0100 Subject: [PATCH] Addition to suppress the plots etc This PR is adding the ability to suppress the plots from showing, if the user only wants write them to disk. For the best genes the range on the y-axis can be used from the defined gene_space (if provided). --- pygad.py | 75 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/pygad.py b/pygad.py index feb2d3f3..32cfb6c7 100644 --- a/pygad.py +++ b/pygad.py @@ -255,7 +255,7 @@ def __init__(self, else: self.valid_parameters = False raise ValueError("The value passed to the 'gene_type' parameter must be either a single integer, floating-point, list, tuple, or numpy.ndarray but ({gene_type_val}) of type {gene_type_type} found.".format(gene_type_val=gene_type, gene_type_type=type(gene_type))) - + # Build the initial population if initial_population is None: if (sol_per_pop is None) or (num_genes is None): @@ -1270,7 +1270,7 @@ def run(self): if self.save_solutions: self.solutions.extend(self.population.copy()) - # If the callback_generation attribute is not None, then cal the callback function after the generation. + # If the callback_generation attribute is not None, then call the callback function after the generation. if not (self.on_generation is None): r = self.on_generation(self) if type(r) is str and r.lower() == "stop": @@ -3121,27 +3121,29 @@ def best_solution(self, pop_fitness=None): return best_solution, best_solution_fitness, best_match_idx - def plot_result(self, - title="PyGAD - Generation vs. Fitness", - xlabel="Generation", - ylabel="Fitness", - linewidth=3, - font_size=14, + def plot_result(self, + title="PyGAD - Generation vs. Fitness", + xlabel="Generation", + ylabel="Fitness", + linewidth=3, + font_size=14, plot_type="plot", color="#3870FF", - save_dir=None): + save_dir=None, + show_fig=True): if not self.suppress_warnings: warnings.warn("Please use the plot_fitness() method instead of plot_result(). The plot_result() method will be removed in the future.") - return self.plot_fitness(title=title, - xlabel=xlabel, - ylabel=ylabel, - linewidth=linewidth, - font_size=font_size, + return self.plot_fitness(title=title, + xlabel=xlabel, + ylabel=ylabel, + linewidth=linewidth, + font_size=font_size, plot_type=plot_type, color=color, - save_dir=save_dir) + save_dir=save_dir, + show_fig=show_fig) def plot_fitness(self, title="PyGAD - Generation vs. Fitness", @@ -3151,7 +3153,8 @@ def plot_fitness(self, font_size=14, plot_type="plot", color="#3870FF", - save_dir=None): + save_dir=None, + show_fig=True): """ Creates, shows, and returns a figure that summarizes how the fitness value evolved by generation. Can only be called after completing at least 1 generation. If no generation is completed, an exception is raised. @@ -3165,6 +3168,7 @@ def plot_fitness(self, plot_type: Type of the plot which can be either "plot" (default), "scatter", or "bar". color: Color of the plot which defaults to "#3870FF". save_dir: Directory to save the figure. + show_fig: shows plot per default, can be set to "False" if plot shall only be saved i.e. Returns the figure. """ @@ -3187,9 +3191,10 @@ def plot_fitness(self, matplotlib.pyplot.ylabel(ylabel, fontsize=font_size) if not save_dir is None: - matplotlib.pyplot.savefig(fname=save_dir, + matplotlib.pyplot.savefig(fname=save_dir, bbox_inches='tight') - matplotlib.pyplot.show() + if show_fig: + matplotlib.pyplot.show() return fig @@ -3201,7 +3206,8 @@ def plot_new_solution_rate(self, font_size=14, plot_type="plot", color="#3870FF", - save_dir=None): + save_dir=None, + show_fig=True): """ Creates, shows, and returns a figure that summarizes the rate of exploring new solutions. This method works only when save_solutions=True in the constructor of the pygad.GA class. @@ -3215,6 +3221,7 @@ def plot_new_solution_rate(self, plot_type: Type of the plot which can be either "plot" (default), "scatter", or "bar". color: Color of the plot which defaults to "#3870FF". save_dir: Directory to save the figure. + show_fig: shows plot per default, can be set to "False" if plot shall only be saved i.e. Returns the figure. """ @@ -3256,22 +3263,25 @@ def plot_new_solution_rate(self, if not save_dir is None: matplotlib.pyplot.savefig(fname=save_dir, bbox_inches='tight') - matplotlib.pyplot.show() + if show_fig: + matplotlib.pyplot.show() return fig - def plot_genes(self, - title="PyGAD - Gene", - xlabel="Gene", - ylabel="Value", - linewidth=3, + def plot_genes(self, + title="PyGAD - Gene", + xlabel="Gene", + ylabel="Value", + linewidth=3, font_size=14, plot_type="plot", graph_type="plot", + range_on_y=False, fill_color="#3870FF", color="black", solutions="all", - save_dir=None): + save_dir=None, + show_fig=True): """ Creates, shows, and returns a figure with number of subplots equal to the number of genes. Each subplot shows the gene value for each generation. @@ -3286,10 +3296,12 @@ def plot_genes(self, font_size: Font size for the labels and title. Defaults to 14. plot_type: Type of the plot which can be either "plot" (default), "scatter", or "bar". graph_type: Type of the graph which can be either "plot" (default), "boxplot", or "histogram". + range_on_y: Derive units for y-axis not from data, but from "gene_space", if definded (default = False). fill_color: Fill color of the graph which defaults to "#3870FF". This has no effect if graph_type="plot". color: Color of the plot which defaults to "black". solutions: Defaults to "all" which means use all solutions. If "best" then only the best solutions are used. save_dir: Directory to save the figure. + show_fig: shows plot per default, can be set to "False" if plot shall only be saved i.e. Returns the figure. """ @@ -3360,11 +3372,16 @@ def plot_genes(self, break if plot_type == "plot": axs[row_idx, col_idx].plot(solutions_to_plot[:, gene_idx], linewidth=linewidth, color=fill_color) + if not self.gene_space is None and range_on_y: + lhs = list(self.gene_space[gene_idx].values()) + # set range only for genes with different "low" and "high" + if lhs[0] != lhs[1]: + axs[row_idx, col_idx].set_ylim([lhs[0], lhs[1]]) elif plot_type == "scatter": axs[row_idx, col_idx].scatter(range(solutions_to_plot.shape[0]), solutions_to_plot[:, gene_idx], linewidth=linewidth, color=fill_color) elif plot_type == "bar": axs[row_idx, col_idx].bar(range(solutions_to_plot.shape[0]), solutions_to_plot[:, gene_idx], linewidth=linewidth, color=fill_color) - axs[row_idx, col_idx].set_xlabel("Gene " + str(gene_idx), fontsize=font_size) + axs[row_idx, col_idx].set_xlabel("Gene " + str(gene_idx), fontsize=font_size) gene_idx += 1 fig.suptitle(title, fontsize=font_size, y=1.001) @@ -3451,8 +3468,8 @@ def plot_genes(self, if not save_dir is None: matplotlib.pyplot.savefig(fname=save_dir, bbox_inches='tight') - - matplotlib.pyplot.show() + if show_fig: + matplotlib.pyplot.show() return fig