aton.spx.plot

Description

This module contains the plot() function, used to plot aton.spx.classes.Spectra data, which may include optional aton.spx.classes.Plotting parameters.

It is used as aton.spx.plot(Spectra).


  1"""
  2# Description
  3
  4This module contains the `plot()` function,
  5used to plot `aton.spx.classes.Spectra` data,
  6containing optional `aton.spx.classes.Plotting` parameters.
  7
  8It is used as `aton.spx.plot(Spectra)`
  9
 10---
 11"""
 12
 13
 14import matplotlib.pyplot as plt
 15from .classes import *
 16
 17
 18def plot(spectra:Spectra):
 19    """Plots a `spectra`.
 20
 21    Optional `aton.spectra.classes.Plotting` attributes can be used.
 22    """
 23    # To clean the filename
 24    strings_to_delete_from_name = ['.csv', '.dat', '.txt', '_INS', '_ATR', '_FTIR', '_temp', '_RAMAN', '_Raman', '/data/', 'data/', '/csv/', 'csv/', '/INS/', 'INS/', '/FTIR/', 'FTIR/', '/ATR/', 'ATR/', '_smooth', '_smoothed', '_subtracted', '_cellsubtracted']
 25    # Avoid modifying the original Spectra object
 26    sdata = deepcopy(spectra)
 27    # Matplotlib stuff
 28    if hasattr(sdata, 'plotting') and sdata.plotting.figsize:
 29        fig, ax = plt.subplots(figsize=sdata.plotting.figsize)
 30    else:
 31        fig, ax = plt.subplots()
 32    # Optional scaling factor
 33    scale_factor = sdata.plotting.scaling if hasattr(sdata, 'plotting') and sdata.plotting.scaling else 1.0
 34    # Calculate Y limits
 35    calculated_low_ylim, calculated_top_ylim = _get_ylimits(sdata)
 36    low_ylim = calculated_low_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[0] is None else sdata.plotting.ylim[0]
 37    top_ylim = calculated_top_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[1] is None else sdata.plotting.ylim[1]
 38    # Get some plotting parameters
 39    low_xlim = None
 40    top_xlim = None
 41    if getattr(sdata, 'plotting', None) is not None:
 42        title = sdata.plotting.title
 43        low_xlim = sdata.plotting.xlim[0]
 44        top_xlim = sdata.plotting.xlim[1]
 45        xlabel = sdata.plotting.xlabel if sdata.plotting.xlabel is not None else sdata.dfs[0].columns[0]
 46        ylabel = sdata.plotting.ylabel if sdata.plotting.ylabel is not None else sdata.dfs[0].columns[1]
 47    else:
 48        title = sdata.comment
 49    # Set plot offset
 50    number_of_plots = len(sdata.dfs)
 51    height = (top_ylim - low_ylim)
 52    if hasattr(sdata, 'plotting') and sdata.plotting.offset is True:
 53        for i, df in enumerate(sdata.dfs):
 54            reverse_i = (number_of_plots - 1) - i
 55            df[df.columns[1]] = df[df.columns[1]] + (reverse_i * height / scale_factor)
 56    elif hasattr(sdata, 'plotting') and (isinstance(sdata.plotting.offset, float) or isinstance(sdata.plotting.offset, int)):
 57        offset = sdata.plotting.offset
 58        for i, df in enumerate(sdata.dfs):
 59            reverse_i = (number_of_plots - 1) - i
 60            df[df.columns[1]] = df[df.columns[1]] + (reverse_i * offset / scale_factor)
 61    _, calculated_top_ylim = _get_ylimits(sdata)
 62    top_ylim = calculated_top_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[1] is None else sdata.plotting.ylim[1]
 63    # Set legend
 64    if hasattr(sdata, 'plotting') and hasattr(sdata.plotting, 'legend'):
 65        if sdata.plotting.legend == False:
 66            for df in sdata.dfs:
 67                df.plot(x=df.columns[0], y=df.columns[1], ax=ax)
 68        elif sdata.plotting.legend != None:
 69            if len(sdata.plotting.legend) == len(sdata.dfs):
 70                for i, df in enumerate(sdata.dfs):
 71                    if sdata.plotting.legend[i] == False:
 72                        continue  # Skip plots with False in the legend
 73                    clean_name = sdata.plotting.legend[i]
 74                    df.plot(x=df.columns[0], y=df.columns[1], label=clean_name, ax=ax)
 75            elif len(sdata.plotting.legend) == 1:
 76                clean_name = sdata.plotting.legend[0]
 77                for i, df in enumerate(sdata.dfs):
 78                    df.plot(x=df.columns[0], y=df.columns[1], label=clean_name, ax=ax)
 79        elif sdata.plotting.legend == None and len(sdata.files) == len(sdata.dfs):
 80            for df, name in zip(sdata.dfs, sdata.files):
 81                clean_name = name
 82                for string in strings_to_delete_from_name:
 83                    clean_name = clean_name.replace(string, '')
 84                clean_name = clean_name.replace('_', ' ')
 85                df.plot(x=df.columns[0], y=df.columns[1], label=clean_name, ax=ax)
 86    # Matplotlib title and axis, additional margins
 87    plt.title(title)
 88    plt.xlabel(xlabel)
 89    plt.ylabel(ylabel)
 90    add_top = 0
 91    add_low = 0
 92    if hasattr(sdata, 'plotting'):
 93        if sdata.plotting.margins and isinstance(sdata.plotting.margins, list):
 94            add_low = sdata.plotting.margins[0]
 95            add_top = sdata.plotting.margins[1]
 96        if sdata.plotting.log_xscale:
 97            ax.set_xscale('log')
 98        if not sdata.plotting.show_yticks:
 99            ax.set_yticks([])
100        if sdata.plotting.legend != False:
101            ax.legend(title=sdata.plotting.legend_title, fontsize=sdata.plotting.legend_size, loc=sdata.plotting.legend_loc)
102        else:
103            ax.legend().set_visible(False)
104    low_ylim = low_ylim - add_low
105    top_ylim = top_ylim + add_top
106    ax.set_ylim(bottom=low_ylim)
107    ax.set_ylim(top=top_ylim)
108    ax.set_xlim(left=low_xlim)
109    ax.set_xlim(right=top_xlim)
110    # Include optional lines
111    if hasattr(sdata, 'plotting') and sdata.plotting.vline is not None and sdata.plotting.vline_error is not None:
112        for vline, vline_error in zip(sdata.plotting.vline, sdata.plotting.vline_error):
113            lower_bound = vline - vline_error
114            upper_bound = vline + vline_error
115            ax.fill_between([lower_bound, upper_bound], low_ylim, top_ylim, color='gray', alpha=0.1)
116    elif hasattr(sdata, 'plotting') and sdata.plotting.vline is not None:
117        for vline in sdata.plotting.vline:
118            ax.axvline(x=vline, color='gray', alpha=0.5, linestyle='--')
119    # Save the file
120    if hasattr(sdata, 'plotting') and sdata.plotting.save_as:
121        root = os.getcwd()
122        save_name = os.path.join(root, sdata.plotting.save_as)
123        plt.savefig(save_name)
124    # Show the file
125    plt.show()
126
127
128def _get_ylimits(spectrum:Spectra) -> tuple[float, float]:
129    """Private function to obtain the ylimits to plot."""
130    # Optional scaling factor
131    scale_factor = spectrum.plotting.scaling if hasattr(spectrum, 'plotting') and spectrum.plotting.scaling else 1.0
132    # Get the Y limits
133    all_y_values = []
134    for df in spectrum.dfs:
135        df_trim = df
136        if hasattr(spectrum, 'plotting') and spectrum.plotting.xlim[0] is not None:
137            df_trim = df_trim[(df_trim[df_trim.columns[0]] >= spectrum.plotting.xlim[0])]
138        if hasattr(spectrum, 'plotting') and spectrum.plotting.xlim[1] is not None:
139            df_trim = df_trim[(df_trim[df_trim.columns[0]] <= spectrum.plotting.xlim[1])]
140        all_y_values.extend(df_trim[df_trim.columns[1]].tolist())
141    calculated_low_ylim = min(all_y_values)
142    calculated_top_ylim = max(all_y_values)
143    return calculated_low_ylim, calculated_top_ylim
def plot(spectra: aton.spx.classes.Spectra):
 19def plot(spectra:Spectra):
 20    """Plots a `spectra`.
 21
 22    Optional `aton.spectra.classes.Plotting` attributes can be used.
 23    """
 24    # To clean the filename
 25    strings_to_delete_from_name = ['.csv', '.dat', '.txt', '_INS', '_ATR', '_FTIR', '_temp', '_RAMAN', '_Raman', '/data/', 'data/', '/csv/', 'csv/', '/INS/', 'INS/', '/FTIR/', 'FTIR/', '/ATR/', 'ATR/', '_smooth', '_smoothed', '_subtracted', '_cellsubtracted']
 26    # Avoid modifying the original Spectra object
 27    sdata = deepcopy(spectra)
 28    # Matplotlib stuff
 29    if hasattr(sdata, 'plotting') and sdata.plotting.figsize:
 30        fig, ax = plt.subplots(figsize=sdata.plotting.figsize)
 31    else:
 32        fig, ax = plt.subplots()
 33    # Optional scaling factor
 34    scale_factor = sdata.plotting.scaling if hasattr(sdata, 'plotting') and sdata.plotting.scaling else 1.0
 35    # Calculate Y limits
 36    calculated_low_ylim, calculated_top_ylim = _get_ylimits(sdata)
 37    low_ylim = calculated_low_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[0] is None else sdata.plotting.ylim[0]
 38    top_ylim = calculated_top_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[1] is None else sdata.plotting.ylim[1]
 39    # Get some plotting parameters
 40    low_xlim = None
 41    top_xlim = None
 42    if getattr(sdata, 'plotting', None) is not None:
 43        title = sdata.plotting.title
 44        low_xlim = sdata.plotting.xlim[0]
 45        top_xlim = sdata.plotting.xlim[1]
 46        xlabel = sdata.plotting.xlabel if sdata.plotting.xlabel is not None else sdata.dfs[0].columns[0]
 47        ylabel = sdata.plotting.ylabel if sdata.plotting.ylabel is not None else sdata.dfs[0].columns[1]
 48    else:
 49        title = sdata.comment
 50    # Set plot offset
 51    number_of_plots = len(sdata.dfs)
 52    height = (top_ylim - low_ylim)
 53    if hasattr(sdata, 'plotting') and sdata.plotting.offset is True:
 54        for i, df in enumerate(sdata.dfs):
 55            reverse_i = (number_of_plots - 1) - i
 56            df[df.columns[1]] = df[df.columns[1]] + (reverse_i * height / scale_factor)
 57    elif hasattr(sdata, 'plotting') and (isinstance(sdata.plotting.offset, float) or isinstance(sdata.plotting.offset, int)):
 58        offset = sdata.plotting.offset
 59        for i, df in enumerate(sdata.dfs):
 60            reverse_i = (number_of_plots - 1) - i
 61            df[df.columns[1]] = df[df.columns[1]] + (reverse_i * offset / scale_factor)
 62    _, calculated_top_ylim = _get_ylimits(sdata)
 63    top_ylim = calculated_top_ylim if not hasattr(sdata, 'plotting') or sdata.plotting.ylim[1] is None else sdata.plotting.ylim[1]
 64    # Set legend
 65    if hasattr(sdata, 'plotting') and hasattr(sdata.plotting, 'legend'):
 66        if sdata.plotting.legend == False:
 67            for df in sdata.dfs:
 68                df.plot(x=df.columns[0], y=df.columns[1], ax=ax)
 69        elif sdata.plotting.legend != None:
 70            if len(sdata.plotting.legend) == len(sdata.dfs):
 71                for i, df in enumerate(sdata.dfs):
 72                    if sdata.plotting.legend[i] == False:
 73                        continue  # Skip plots with False in the legend
 74                    clean_name = sdata.plotting.legend[i]
 75                    df.plot(x=df.columns[0], y=df.columns[1], label=clean_name, ax=ax)
 76            elif len(sdata.plotting.legend) == 1:
 77                clean_name = sdata.plotting.legend[0]
 78                for i, df in enumerate(sdata.dfs):
 79                    df.plot(x=df.columns[0], y=df.columns[1], label=clean_name, ax=ax)
 80        elif sdata.plotting.legend == None and len(sdata.files) == len(sdata.dfs):
 81            for df, name in zip(sdata.dfs, sdata.files):
 82                clean_name = name
 83                for string in strings_to_delete_from_name:
 84                    clean_name = clean_name.replace(string, '')
 85                clean_name = clean_name.replace('_', ' ')
 86                df.plot(x=df.columns[0], y=df.columns[1], label=clean_name, ax=ax)
 87    # Matplotlib title and axis, additional margins
 88    plt.title(title)
 89    plt.xlabel(xlabel)
 90    plt.ylabel(ylabel)
 91    add_top = 0
 92    add_low = 0
 93    if hasattr(sdata, 'plotting'):
 94        if sdata.plotting.margins and isinstance(sdata.plotting.margins, list):
 95            add_low = sdata.plotting.margins[0]
 96            add_top = sdata.plotting.margins[1]
 97        if sdata.plotting.log_xscale:
 98            ax.set_xscale('log')
 99        if not sdata.plotting.show_yticks:
100            ax.set_yticks([])
101        if sdata.plotting.legend != False:
102            ax.legend(title=sdata.plotting.legend_title, fontsize=sdata.plotting.legend_size, loc=sdata.plotting.legend_loc)
103        else:
104            ax.legend().set_visible(False)
105    low_ylim = low_ylim - add_low
106    top_ylim = top_ylim + add_top
107    ax.set_ylim(bottom=low_ylim)
108    ax.set_ylim(top=top_ylim)
109    ax.set_xlim(left=low_xlim)
110    ax.set_xlim(right=top_xlim)
111    # Include optional lines
112    if hasattr(sdata, 'plotting') and sdata.plotting.vline is not None and sdata.plotting.vline_error is not None:
113        for vline, vline_error in zip(sdata.plotting.vline, sdata.plotting.vline_error):
114            lower_bound = vline - vline_error
115            upper_bound = vline + vline_error
116            ax.fill_between([lower_bound, upper_bound], low_ylim, top_ylim, color='gray', alpha=0.1)
117    elif hasattr(sdata, 'plotting') and sdata.plotting.vline is not None:
118        for vline in sdata.plotting.vline:
119            ax.axvline(x=vline, color='gray', alpha=0.5, linestyle='--')
120    # Save the file
121    if hasattr(sdata, 'plotting') and sdata.plotting.save_as:
122        root = os.getcwd()
123        save_name = os.path.join(root, sdata.plotting.save_as)
124        plt.savefig(save_name)
125    # Show the file
126    plt.show()

Plots a Spectra object.

Optional aton.spx.classes.Plotting attributes can be used.