sah.plot

Description

This module contains the plot() function, which plots sah.classes.Spectra data; the Spectra object may carry optional sah.classes.Plotting parameters that customize the figure.

Usage: sah.plot(spectra), where spectra is a sah.classes.Spectra object.


  1"""
  2# Description
  3
  4This module contains the `plot()` function,
  5used to plot `sah.classes.Spectra` data,
  6containing optional `sah.classes.Plotting` parameters.
  7
  8It is used as `sah.plot(Spectra)`
  9
 10---
 11"""
 12
 13
 14import matplotlib.pyplot as plt
 15from .classes import *
 16
 17
 18def plot(spectra:Spectra):
 19    """Plots a `spectra`.
 20
 21    Optional `sah.classes.Plotting` attributes can be used.
 22    """
 23    # To clean the filename
 24    strings_to_delete_from_name = ['.csv', '.dat', '.txt', '_INS', '_ATR', '_FTIR', '_temp', '_RAMAN', '_Raman', '/data/', 'data/', '/csv/', 'csv/', '/INS/', 'INS/', '/FTIR/', 'FTIR/', '/ATR/', 'ATR/', '_smooth', '_smoothed', '_subtracted', '_cellsubtracted']
 25    # Avoid modifying the original Spectra object
 26    sdata = deepcopy(spectra)
 27    # Matplotlib stuff
 28    if hasattr(sdata, 'plotting') and sdata.plotting.figsize:
 29        fig, ax = plt.subplots(figsize=sdata.plotting.figsize)
 30    else:
 31        fig, ax = plt.subplots()
 32    # Optional scaling factor
 33    scale_factor = sdata.plotting.scaling if hasattr(sdata, 'plotting') and sdata.plotting.scaling else 1.0
 34    # Calculate Y limits
 35    calculated_low_ylim, calculated_top_ylim = _get_ylimits(sdata)
 36    low_ylim = calculated_low_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[0] is None else sdata.plotting.ylim[0]
 37    top_ylim = calculated_top_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[1] is None else sdata.plotting.ylim[1]
 38    # Get some plotting parameters
 39    low_xlim = None
 40    top_xlim = None
 41    if getattr(sdata, 'plotting', None) is not None:
 42        title = sdata.plotting.title
 43        low_xlim = sdata.plotting.xlim[0]
 44        top_xlim = sdata.plotting.xlim[1]
 45        xlabel = sdata.plotting.xlabel if sdata.plotting.xlabel is not None else sdata.dfs[0].columns[0]
 46        ylabel = sdata.plotting.ylabel if sdata.plotting.ylabel is not None else sdata.dfs[0].columns[1]
 47    else:
 48        title = sdata.comment
 49    # Set plot offset
 50    number_of_plots = len(sdata.dfs)
 51    height = (top_ylim - low_ylim)
 52    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
 53    if hasattr(sdata, 'plotting') and sdata.plotting.viridis:
 54        colors = plt.cm.viridis(np.linspace(0, 1, number_of_plots+1))  # +1 to avoid the lighter tones
 55    if hasattr(sdata, 'plotting') and sdata.plotting.offset is True:
 56        for i, df in enumerate(sdata.dfs):
 57            reverse_i = (number_of_plots - 1) - i
 58            df[df.columns[1]] = df[df.columns[1]] + (reverse_i * height / scale_factor)
 59    elif hasattr(sdata, 'plotting') and (isinstance(sdata.plotting.offset, float) or isinstance(sdata.plotting.offset, int)):
 60        offset = sdata.plotting.offset
 61        for i, df in enumerate(sdata.dfs):
 62            reverse_i = (number_of_plots - 1) - i
 63            df[df.columns[1]] = df[df.columns[1]] + (reverse_i * offset / scale_factor)
 64    _, calculated_top_ylim = _get_ylimits(sdata)
 65    top_ylim = calculated_top_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[1] is None else sdata.plotting.ylim[1]
 66    # Set legend
 67    if hasattr(sdata, 'plotting') and hasattr(sdata.plotting, 'legend'):
 68        if sdata.plotting.legend == False:
 69            for df in sdata.dfs:
 70                df.plot(x=df.columns[0], y=df.columns[1], color=colors[i], ax=ax)
 71        elif sdata.plotting.legend != None:
 72            if len(sdata.plotting.legend) == len(sdata.dfs):
 73                for i, df in enumerate(sdata.dfs):
 74                    if sdata.plotting.legend[i] == False:
 75                        continue  # Skip plots with False in the legend
 76                    clean_name = sdata.plotting.legend[i]
 77                    df.plot(x=df.columns[0], y=df.columns[1], color=colors[i], label=clean_name, ax=ax)
 78            elif len(sdata.plotting.legend) == 1:
 79                clean_name = sdata.plotting.legend[0]
 80                for i, df in enumerate(sdata.dfs):
 81                    df.plot(x=df.columns[0], y=df.columns[1], color=colors[i], label=clean_name, ax=ax)
 82        elif sdata.plotting.legend == None and len(sdata.files) == len(sdata.dfs):
 83            for df, name in zip(sdata.dfs, sdata.files):
 84                clean_name = name
 85                for string in strings_to_delete_from_name:
 86                    clean_name = clean_name.replace(string, '')
 87                clean_name = clean_name.replace('_', ' ')
 88                df.plot(x=df.columns[0], y=df.columns[1], color=colors[i], label=clean_name, ax=ax)
 89    # Matplotlib title and axis, additional margins
 90    plt.title(title)
 91    plt.xlabel(xlabel)
 92    plt.ylabel(ylabel)
 93    add_top = 0
 94    add_low = 0
 95    if hasattr(sdata, 'plotting'):
 96        if sdata.plotting.margins and isinstance(sdata.plotting.margins, list):
 97            add_low = sdata.plotting.margins[0]
 98            add_top = sdata.plotting.margins[1]
 99        if sdata.plotting.log_xscale:
100            ax.set_xscale('log')
101        if not sdata.plotting.show_yticks:
102            ax.set_yticks([])
103        if sdata.plotting.legend != False:
104            ax.legend(title=sdata.plotting.legend_title, fontsize=sdata.plotting.legend_size, loc=sdata.plotting.legend_loc)
105        else:
106            ax.legend().set_visible(False)
107    low_ylim = low_ylim - add_low
108    top_ylim = top_ylim + add_top
109    ax.set_ylim(bottom=low_ylim)
110    ax.set_ylim(top=top_ylim)
111    ax.set_xlim(left=low_xlim)
112    ax.set_xlim(right=top_xlim)
113    # Include optional lines
114    if hasattr(sdata, 'plotting') and sdata.plotting.vline is not None and sdata.plotting.vline_error is not None:
115        for vline, vline_error in zip(sdata.plotting.vline, sdata.plotting.vline_error):
116            lower_bound = vline - vline_error
117            upper_bound = vline + vline_error
118            ax.fill_between([lower_bound, upper_bound], low_ylim, top_ylim, color='gray', alpha=0.1)
119    elif hasattr(sdata, 'plotting') and sdata.plotting.vline is not None:
120        for vline in sdata.plotting.vline:
121            ax.axvline(x=vline, color='gray', alpha=0.5, linestyle='--')
122    # Save the file
123    if hasattr(sdata, 'plotting') and sdata.plotting.save_as:
124        root = os.getcwd()
125        save_name = os.path.join(root, sdata.plotting.save_as)
126        plt.savefig(save_name)
127    # Show the file
128    plt.show()
129
130
def _get_ylimits(spectrum:Spectra) -> tuple[float, float]:
    """Private function to obtain the Y limits to plot.

    Gathers the values of the second column of every dataframe in
    `spectrum.dfs`, keeping only the rows whose first-column value lies
    inside `spectrum.plotting.xlim` (a `None` bound leaves that side open),
    and returns their `(minimum, maximum)`.

    `spectrum.plotting` may be missing or `None`, in which case all rows
    are used.

    Raises:
        ValueError: if no data point falls inside the X window.
    """
    plotting = getattr(spectrum, 'plotting', None)
    low_xlim = plotting.xlim[0] if plotting is not None else None
    top_xlim = plotting.xlim[1] if plotting is not None else None
    all_y_values = []
    for df in spectrum.dfs:
        x_column, y_column = df.columns[0], df.columns[1]
        df_trim = df
        if low_xlim is not None:
            df_trim = df_trim[df_trim[x_column] >= low_xlim]
        if top_xlim is not None:
            df_trim = df_trim[df_trim[x_column] <= top_xlim]
        all_y_values.extend(df_trim[y_column].tolist())
    if not all_y_values:
        # min()/max() on an empty list would fail with an unhelpful message
        raise ValueError('No data points found inside the X limits; cannot compute the Y limits.')
    return min(all_y_values), max(all_y_values)
def plot(spectra:Spectra):
    """Plots a `spectra`.

    Optional `sah.classes.Plotting` attributes, read from `spectra.plotting`,
    can be used. Each dataframe in `spectra.dfs` is drawn as one curve, with
    its first column as X and its second column as Y. The input is
    deep-copied first, so offsets never modify the caller's data.

    Curve labels are chosen from `plotting.legend`:

    - `False`: every curve is drawn and the legend is hidden.
    - a list with one entry per dataframe: one label per curve; entries
      equal to `False` skip that curve entirely.
    - a list with a single entry: that label is used for every curve.
    - `None`/`True`, or no `plotting` at all: labels are built from
      `spectra.files` (folder names, extensions and known suffixes stripped,
      underscores turned into spaces) when there is one file per dataframe;
      otherwise curves are drawn without explicit labels.

    The figure is saved to `plotting.save_as` (relative to the current
    working directory) when set, and is then shown with `plt.show()`.

    Raises:
        ValueError: if `plotting.legend` is a list whose length is neither 1
            nor the number of dataframes.
    """
    # Substrings removed from file names to build legend labels.
    # '_smoothed' must come before '_smooth', otherwise '_smoothed' would
    # leave a stray 'ed' behind.
    strings_to_delete_from_name = ['.csv', '.dat', '.txt', '_INS', '_ATR', '_FTIR', '_temp', '_RAMAN', '_Raman', '/data/', 'data/', '/csv/', 'csv/', '/INS/', 'INS/', '/FTIR/', 'FTIR/', '/ATR/', 'ATR/', '_smoothed', '_smooth', '_subtracted', '_cellsubtracted']
    # Avoid modifying the original Spectra object
    sdata = deepcopy(spectra)
    # `plotting` may be absent or explicitly None; both mean "no options".
    plotting = getattr(sdata, 'plotting', None)
    has_plotting = plotting is not None
    number_of_plots = len(sdata.dfs)

    # Resolve one label per curve up front, so an invalid legend fails
    # before any figure is created. A label of None means "no explicit label".
    skip = object()  # marks curves hidden by a False entry in the legend list
    legend = getattr(plotting, 'legend', None)
    if legend is None or legend is True:
        files = getattr(sdata, 'files', None) or []
        if len(files) == number_of_plots:
            labels = []
            for name in files:
                clean_name = name
                for string in strings_to_delete_from_name:
                    clean_name = clean_name.replace(string, '')
                labels.append(clean_name.replace('_', ' '))
        else:
            labels = [None] * number_of_plots
    elif legend == False:  # noqa: E712 - equality kept from the original check
        labels = [None] * number_of_plots
    elif len(legend) == number_of_plots:
        labels = [skip if entry == False else entry for entry in legend]  # noqa: E712
    elif len(legend) == 1:
        labels = [legend[0]] * number_of_plots
    else:
        raise ValueError(
            f"plotting.legend has {len(legend)} entries; expected 1 or "
            f"{number_of_plots} (one per dataframe)"
        )

    # Matplotlib stuff
    if has_plotting and plotting.figsize:
        fig, ax = plt.subplots(figsize=plotting.figsize)
    else:
        fig, ax = plt.subplots()
    # Optional scaling factor, only used to scale the offsets below
    scale_factor = plotting.scaling if has_plotting and plotting.scaling else 1.0

    # Y limits: user-given values win over the ones calculated from the data
    calculated_low_ylim, calculated_top_ylim = _get_ylimits(sdata)
    user_low_ylim, user_top_ylim = (plotting.ylim[0], plotting.ylim[1]) if has_plotting else (None, None)
    low_ylim = calculated_low_ylim if user_low_ylim is None else user_low_ylim
    top_ylim = calculated_top_ylim if user_top_ylim is None else user_top_ylim

    # X limits, title and axis labels (axis labels default to the column names)
    low_xlim = None
    top_xlim = None
    title = plotting.title if has_plotting else sdata.comment
    xlabel = sdata.dfs[0].columns[0]
    ylabel = sdata.dfs[0].columns[1]
    if has_plotting:
        low_xlim = plotting.xlim[0]
        top_xlim = plotting.xlim[1]
        if plotting.xlabel is not None:
            xlabel = plotting.xlabel
        if plotting.ylabel is not None:
            ylabel = plotting.ylabel

    # Vertical offsets: the first dataframe gets the largest shift,
    # the last one is not shifted.
    offset = plotting.offset if has_plotting else None
    if offset is True:
        step = top_ylim - low_ylim  # one full data height per curve
    elif isinstance(offset, (int, float)) and not isinstance(offset, bool):
        step = offset
    else:
        step = None
    if step is not None:
        for i, df in enumerate(sdata.dfs):
            reverse_i = (number_of_plots - 1) - i
            y_column = df.columns[1]
            df[y_column] = df[y_column] + (reverse_i * step / scale_factor)
        # The shifted data has a new maximum, so recompute the top limit
        _, calculated_top_ylim = _get_ylimits(sdata)
        top_ylim = calculated_top_ylim if user_top_ylim is None else user_top_ylim

    # Colors; they wrap around when there are more curves than colors
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    if has_plotting and plotting.viridis:
        colors = plt.cm.viridis(np.linspace(0, 1, number_of_plots+1))  # +1 to avoid the lighter tones
    for i, (df, label) in enumerate(zip(sdata.dfs, labels)):
        if label is skip:
            continue
        label_kwargs = {} if label is None else {'label': label}
        df.plot(x=df.columns[0], y=df.columns[1], color=colors[i % len(colors)], ax=ax, **label_kwargs)

    # Matplotlib title and axis, additional margins
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    add_top = 0
    add_low = 0
    if has_plotting:
        margins = plotting.margins
        if margins and isinstance(margins, (list, tuple)):
            add_low = margins[0]
            add_top = margins[1]
        if plotting.log_xscale:
            ax.set_xscale('log')
        if not plotting.show_yticks:
            ax.set_yticks([])
        if legend != False:  # noqa: E712
            ax.legend(title=plotting.legend_title, fontsize=plotting.legend_size, loc=plotting.legend_loc)
        else:
            ax.legend().set_visible(False)
    low_ylim = low_ylim - add_low
    top_ylim = top_ylim + add_top
    ax.set_ylim(bottom=low_ylim, top=top_ylim)
    ax.set_xlim(left=low_xlim, right=top_xlim)

    # Optional vertical lines: shaded bands when errors are given,
    # dashed lines otherwise
    vlines = plotting.vline if has_plotting else None
    if vlines is not None:
        if plotting.vline_error is not None:
            for vline, vline_error in zip(vlines, plotting.vline_error):
                ax.fill_between([vline - vline_error, vline + vline_error], low_ylim, top_ylim, color='gray', alpha=0.1)
        else:
            for vline in vlines:
                ax.axvline(x=vline, color='gray', alpha=0.5, linestyle='--')
    # Save the file, relative to the current working directory
    if has_plotting and plotting.save_as:
        fig.savefig(os.path.join(os.getcwd(), plotting.save_as))
    # Show the file
    plt.show()

Plots the given Spectra object.

Optional sah.classes.Plotting attributes, stored in its plotting attribute, can be used to customize the figure.