"""
Contains the definition of SAUL's :class:`Spectrogram` class.
"""
import copy
from functools import cache
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from multitaper import mtspec
from scipy.signal import spectrogram
from stockwell import st as _st # Avoid conflict with ObsPy `st`
from saul.spectral.helpers import (
_FREQ_TEMPLATE,
CYCLES_PER_WINDOW,
_format_power_label,
_get_db_reference_value,
)
from saul.waveform.helpers import _num2date
from saul.waveform.stream import Stream
from saul.waveform.units import _validate_provided_vs_inferred_units, get_waveform_units
[docs]
class Spectrogram:
"""A class for calculating and plotting spectrograms of waveforms.
Attributes
==========
Attributes:
method (str): See :meth:`__init__`
win_dur (int or float): See :meth:`__init__`; only defined if ``method='scipy'``
or ``method='multitaper'``
time_bandwidth_product (float): See :meth:`__init__`; only defined if
``method='multitaper'``
number_of_tapers (int): See :meth:`__init__`; only defined if
``method='multitaper'``
gamma (float): See :meth:`__init__`; only defined if ``method='s_transform'``
max_fs (int or float): See :meth:`__init__`; only defined if
``method='s_transform'``
tr (:class:`~obspy.core.trace.Trace`): Input waveform
data_kind (str): Input waveform data kind; e.g., ``'infrasound'`` or
``'seismic'`` (inferred from channel code)
db_ref_val (int, float, or None): dB reference value for PSD (data kind
dependent)
waveform_units (str or None): Units of the input waveform
spectrogram (tuple): Spectrogram (in dB) calculated from the input waveform; of
the form ``(f, t, sxx_db)`` where ``f`` and ``t`` are 1D arrays and
``sxx_db`` is a 2D array with shape ``(f.size, t.size)``
Methods
=======
"""
[docs]
def __init__(
self,
tr_or_st,
method='scipy',
win_dur=8,
time_bandwidth_product=4,
number_of_tapers=7,
gamma=1,
max_fs=10,
units='infer',
):
"""Create a :class:`Spectrogram` object.
The spectrogram of the input waveform is estimated in this method (only a single
waveform may be provided). Three spectral estimation approaches are supported:
The method implemented by SciPy (:func:`scipy.signal.spectrogram`), the
multitaper method (:func:`mtspec.spectrogram`), and the :math:`S` transform
(implemented in the `Stockwell <https://github.com/claudiodsf/stockwell>`_
package). The input arguments (below) relevant for each method are marked with a
**[P]** for the SciPy method, an **[M]** for the multitaper method, and an
**[S]** for the :math:`S` transform. Arguments corresponding to the non-selected
method are ignored.
Args:
tr_or_st (:class:`~obspy.core.trace.Trace` or :class:`~saul.waveform.stream.Stream`):
Input waveform
method (str): Either ``'scipy'`` **[P]**, ``'multitaper'`` **[M]**, or
``'s_transform'`` **[S]**
win_dur (int or float): **[P]** **[M]** Segment length in seconds. This
usually must be adjusted, within the constraints of the total signal
duration, to ensure that the longest-period signals of interest are
included
time_bandwidth_product (float): **[M]** Time-bandwidth product
number_of_tapers (int): **[M]** Number of tapers to use
gamma (float): **[S]** Gamma parameter, see `here
<https://github.com/claudiodsf/stockwell/blob/31e1db400aaa0b61c4df8b8ec30f9e58731f611e/stockwell/st.py#L76-L82>`_
for more info
max_fs (int or float): **[S]** Maximum allowed sampling rate in hertz. If an
input signal has a sampling rate higher than this, it will be
downsampled before the :math:`S` transform is computed (this saves
computation time and memory)
units (str or None): Units of the input waveform; either ``'infer'`` to
guess from input response information, a string explicitly defining the
units (see ``_VALID_UNIT_OPTIONS`` in :mod:`saul.waveform.units` for
supported options), or ``None`` for unknown units (e.g., counts)
"""
# Pre-processing and checks
msg = 'Method must be either \'scipy\', \'multitaper\', or \'s_transform\''
assert method in ['scipy', 'multitaper', 's_transform'], msg
self.method = method
match self.method:
case 'scipy':
self.win_dur = win_dur
case 'multitaper':
self.win_dur = win_dur
self.time_bandwidth_product = time_bandwidth_product
self.number_of_tapers = number_of_tapers
case 's_transform':
self.gamma = gamma
self.max_fs = max_fs
st = Stream(tr_or_st) # Cast input to saul.Stream
assert st.count() > 0, 'No waveform provided!'
assert st.count() == 1, 'Must provide only a single waveform!'
self.tr = st[0].copy() # Always use *copied* saul.Stream objects
# Handle data kind, units, and reference dB value
data_kind, inferred_units = get_waveform_units(self.tr)
self.data_kind = data_kind
self.db_ref_val = _get_db_reference_value(self.data_kind)
self.waveform_units = _validate_provided_vs_inferred_units(
units, inferred_units, self.data_kind
)
if self.waveform_units is None:
self.db_ref_val = None # Reference value is meaningless w/o units!
# KEY: Calculate spectrogram (in dB relative to self.db_ref_val)
match self.method:
case 'scipy':
fs = self.tr.stats.sampling_rate
nperseg = int(win_dur * fs) # Samples
nfft = np.power(2, int(np.ceil(np.log2(nperseg))) + 1) # Pad FFT
f, t, sxx = spectrogram(
self.tr.data,
fs,
window='hann',
nperseg=nperseg,
noverlap=nperseg // 2, # 50 % overlap
nfft=nfft,
)
case 'multitaper':
t, f, _, sxx = self._spectrogram(
tuple(self.tr.data),
dt=self.tr.stats.delta,
twin=win_dur,
nw=time_bandwidth_product,
kspec=number_of_tapers,
olap=0.5, # 50 % overlap
iadapt=0, # "Adaptive multitaper" <- change?
)
f = f.squeeze()
t += win_dur / 2 # Make t vector *centered* in each time window
case 's_transform':
if self.tr.stats.sampling_rate > max_fs:
print(f'Downsampling data to {max_fs} Hz for S transform')
_tr = self.tr.copy()
_tr.filter('lowpass_cheby_2', freq=max_fs / 2)
_tr.interpolate(max_fs, method='lanczos', a=20)
else:
_tr = self.tr
f = np.linspace(0, _tr.stats.sampling_rate / 2, _tr.stats.npts // 2)
t = _tr.times()
_sxx = _st.st(
_tr.data, lo=0, hi=f.size - 1, gamma=gamma, win_type='gauss'
)
sxx = np.abs(_sxx) ** 2 # TODO: Convert to power? What about density?
f, sxx = f[1:], sxx[1:, :] # Remove DC component
# Convert to dB [dB rel. (db_ref_val <db_ref_val_unit>)^2 Hz^-1]
if self.db_ref_val is None:
denominator = sxx.max() # 0 dB is max
else:
denominator = self.db_ref_val**2
sxx_db = 10 * np.log10(sxx / denominator)
self.spectrogram = (f, t, sxx_db)
[docs]
def plot(
self,
db_lim='smart',
use_period=False,
log_y=False,
):
"""Plot the calculated spectrogram.
Args:
db_lim (tuple, str, or None): Tuple defining min and max dB cutoffs,
``'smart'`` for a sensible automatic choice, or ``None`` for no clipping
use_period (bool): If ``True``, spectrogram *y*-axis will be period [s]
instead of frequency [Hz]
log_y (bool): If ``True``, use log scaling for spectrogram *y*-axis
"""
assert not (use_period and not log_y), 'Cannot use period with linear y-scale!'
fig = plt.figure(figsize=(7, 5))
# width_ratios effectively controls the colorbar width
gs = GridSpec(2, 2, figure=fig, height_ratios=[2, 1], width_ratios=[40, 1])
# Set up the three required axes
spec_ax = fig.add_subplot(gs[0, 0])
wf_ax = fig.add_subplot(gs[1, 0], sharex=spec_ax) # Common time axis
cax = fig.add_subplot(gs[0, 1])
if self.data_kind == 'seismic' and self.waveform_units is not None:
rescale = 1e6 # Use μ prefix for seismic, unless we have unknown units
else:
rescale = 1
wf_ax.plot(
self.tr.times('matplotlib'), self.tr.data * rescale, 'black', linewidth=0.5
)
match self.waveform_units:
case 'pa':
ylabel = 'Pressure (Pa)'
yunit = 'Pa'
case 'm':
ylabel = 'Displacement (μm)'
yunit = 'μm'
case 'm/s':
ylabel = r'Velocity (μm s$\mathdefault{^{-1}}$)'
yunit = 'μm/s'
case 'm/s**2':
ylabel = r'Acceleration (μm s$\mathdefault{^{-2}}$)'
yunit = 'μm/s²'
case None:
ylabel = 'Amplitude'
yunit = 'unknown units'
case _:
raise ValueError(f'Invalid units: {self.waveform_units}')
wf_ax.set_ylabel(ylabel)
wf_ax.grid(linestyle=':', zorder=-5)
f, t, sxx_db = self.spectrogram
t_mpl = self.tr.stats.starttime.matplotlib_date + (t / mdates.SEC_PER_DAY)
x = t_mpl
dx = np.diff(x)[0]
y = f
dy = np.diff(y)[0]
im = spec_ax.imshow(
sxx_db,
cmap='magma',
interpolation='none',
rasterized=True,
aspect='auto',
origin='lower',
extent=( # Carefully handling the registration here
x.min() - dx / 2,
x.max() + dx / 2,
y.min() - dy / 2,
y.max() + dy / 2,
),
)
if use_period:
grid_axis = 'x' # Just place horizontal gridlines; we'll add vertical later
else:
grid_axis = 'both'
spec_ax.set_ylabel('Frequency (Hz)') # Go ahead and set this now
spec_ax.grid(linestyle=':', zorder=5, axis=grid_axis)
spec_ax.set_facecolor(plt.rcParams['grid.color'])
if self.method == 's_transform':
fmin = f.min() # [Hz] S transform doesn't have frequency resolution limits
else:
fmin = 1 / (self.win_dur / CYCLES_PER_WINDOW) # [Hz] Min. resolvable freq.
fmax = self.tr.stats.sampling_rate / 2 # [Hz] Nyquist
spec_ax.set_ylim(fmin, fmax)
if log_y:
spec_ax.set_yscale('log')
if use_period: # Overcome imshow() limitations by defining an axis overlay
# Set up overlay and scale it properly
spec_ax_overlay = spec_ax.twinx()
spec_ax_overlay.set_zorder(spec_ax.get_zorder() - 1) # Place below spec
spec_ax_overlay.set_ylim(1 / fmin, 1 / fmax)
spec_ax_overlay.set_yscale('log') # log_y is guaranteed to be True
# Remove the ticks and labels from the underlying plot
spec_ax.tick_params(axis='y', which='both', left=False, labelleft=False)
# Configure ticks and axis labels for the overlay
spec_ax_overlay.yaxis.tick_left()
spec_ax_overlay.set_ylabel('Period (s)')
spec_ax_overlay.yaxis.set_label_position('left')
# Finally, we add the y-axis grid (to the overlay to ensure correct ticking)
spec_ax_overlay.grid(linestyle=':', axis='y')
wf_ax.set_xlim(
self.tr.stats.starttime.matplotlib_date,
self.tr.stats.endtime.matplotlib_date,
)
# Tick locating and formatting
locator = mdates.AutoDateLocator()
wf_ax.xaxis.set_major_locator(locator)
formatter = mdates.AutoDateFormatter(locator)
formatter.scaled[30.0] = '%b. %Y'
formatter.scaled[1] = '%-d %b. %Y'
formatter.scaled[1 / mdates.HOURS_PER_DAY] = '%H:%M'
formatter.scaled[1 / mdates.MINUTES_PER_DAY] = '%H:%M'
formatter.scaled[1 / mdates.MUSECONDS_PER_DAY] = '%H:%M:%S.%f'
wf_ax.xaxis.set_major_formatter(formatter)
fig.autofmt_xdate()
start_date = self.tr.stats.starttime.strftime('%-d %B %Y')
wf_ax.set_xlabel(f'UTC time starting on {start_date}')
# Pick smart limits rounded to nearest 10 dB
if db_lim == 'smart':
db_min = np.percentile(sxx_db, 20)
db_max = sxx_db.max()
db_lim = np.ceil(db_min / 10) * 10, np.floor(db_max / 10) * 10
im.set_clim(db_lim)
# Automatically determine whether to show triangle extensions on colorbar (kind
# of adopted from xarray)
if db_lim:
min_extend = sxx_db.min() < db_lim[0]
max_extend = sxx_db.max() > db_lim[1]
else:
min_extend = False
max_extend = False
if min_extend and max_extend:
extend = 'both'
elif min_extend:
extend = 'min'
elif max_extend:
extend = 'max'
else:
extend = 'neither'
extendfrac = 0.04
fig.colorbar(
im,
cax,
extend=extend,
extendfrac=extendfrac,
label=_format_power_label(self.db_ref_val, self.waveform_units),
)
spec_ax.set_title(self.tr.id, family='monospace')
# Layout adjustment
gs.tight_layout(fig)
gs.update(hspace=0.1, wspace=0.07)
# Finnicky formatting to get extension triangles (if they exist) to extend above
# and below the vertical extent of the spectrogram axes
pos = cax.get_position()
triangle_height = extendfrac * pos.height
ymin = pos.ymin
height = pos.height
if min_extend and max_extend:
ymin -= triangle_height
height += 2 * triangle_height
elif min_extend and not max_extend:
ymin -= triangle_height
height += triangle_height
elif max_extend and not min_extend:
height += triangle_height
cax.set_position([pos.xmin, ymin, pos.width, height])
# Cursor formatting
spec_ax.format_coord = (
lambda x, y: f'({_num2date(x)}, {formatter.fix_minus(_FREQ_TEMPLATE.format(y, 1 / y))})'
)
im.format_cursor_data = lambda data: formatter.fix_minus(f'{data:.1f} dB')
wf_ax.format_coord = (
lambda x, y: f'({_num2date(x)}, {formatter.fix_minus(f"{y:.2g}")} {yunit})'
)
cax.format_coord = lambda x, y: '' # Disable colorbar cursor info
fig.show()
[docs]
def copy(self):
"""Return a deep copy of the :class:`Spectrogram` object."""
return copy.deepcopy(self)
@staticmethod
@cache
def _spectrogram(tr_data_tuple, **kwargs):
"""Wrapper around :func:`mtspec.spectrogram` to facilitate tuple input (needed for memoization).
Warning:
For large input arrays (many samples), conversion to tuple and then back to
:class:`numpy.ndarray` can be **slow**. In this case, memoization may not be
worth it.
"""
return mtspec.spectrogram(np.array(tr_data_tuple), **kwargs)