"""
Contains the definition of SAUL's :class:`PSD` class.
"""
import copy
from functools import cache
import matplotlib.pyplot as plt
import numpy as np
from esi_core.gmprocess.waveform_processing.smoothing.konno_ohmachi import (
konno_ohmachi_smooth,
)
from matplotlib.ticker import Formatter
from multitaper import mtspec
from obspy.signal.spectral_estimation import (
get_idc_infra_hi_noise,
get_idc_infra_low_noise,
get_nhnm,
get_nlnm,
)
from scipy.fft import next_fast_len
from scipy.signal import welch
from saul.spectral.helpers import (
_FREQ_TEMPLATE,
CYCLES_PER_WINDOW,
_format_power_label,
_get_db_reference_value,
get_ak_infra_noise,
)
from saul.waveform.stream import Stream
from saul.waveform.units import _validate_provided_vs_inferred_units, get_waveform_units
[docs]
class PSD:
"""A class for calculating and plotting PSDs of one or more waveforms.
Attributes
==========
Attributes:
method (str): See :meth:`__init__`
win_dur (int or float): See :meth:`__init__`; only defined if ``method='welch'``
time_bandwidth_product (float): See :meth:`__init__`; only defined if
``method='multitaper'``
number_of_tapers (int): See :meth:`__init__`; only defined if
``method='multitaper'``
st (SAUL :class:`~saul.waveform.stream.Stream`): Input waveforms (single
:class:`~obspy.core.trace.Trace` input is converted to SAUL
:class:`~saul.waveform.stream.Stream`)
data_kind (str): Input waveform data kind; e.g., ``'infrasound'`` or
``'seismic'`` (inferred from channel code)
db_ref_val (int, float, or None): dB reference value for PSD (data kind
dependent)
waveform_units (str or None): Units of the input waveforms
psd (list): List of PSDs (in dB) calculated from input waveforms; of the form
``[(f1, pxx_db1), (f2, pxx_db2), ...]`` given a
:class:`~saul.waveform.stream.Stream` consisting of
:class:`~obspy.core.trace.Trace` entries ``[tr1, tr2, ...]``
Methods
=======
"""
[docs]
def __init__(
self,
tr_or_st,
method='welch',
win_dur=60,
time_bandwidth_product=4,
number_of_tapers=7,
units='infer',
):
"""Create a :class:`PSD` object.
The PSDs of the input waveforms are estimated in this method. Two spectral
estimation approaches are supported: Welch's method (:func:`scipy.signal.welch`)
and the multitaper method (:class:`mtspec.MTSpec`). The input arguments (below)
relevant for each method are marked with a **[W]** for Welch's method and an
**[M]** for the multitaper method. Arguments corresponding to the non-selected
method are ignored.
Args:
tr_or_st (:class:`~obspy.core.trace.Trace` or :class:`~saul.waveform.stream.Stream`):
Input waveforms
method (str): Either ``'welch'`` **[W]** or ``'multitaper'`` **[M]**
win_dur (int or float): **[W]** Segment length in seconds. This usually must
be tweaked to obtain the cleanest-looking plot and to ensure that the
longest-period signals of interest are included
time_bandwidth_product (float): **[M]** Time-bandwidth product
number_of_tapers (int): **[M]** Number of tapers to use
units (str or None): Units of the input waveforms; either ``'infer'`` to
guess from input response information, a string explicitly defining the
units (see ``_VALID_UNIT_OPTIONS`` in :mod:`saul.waveform.units` for
supported options), or ``None`` for unknown units (e.g., counts) — all
input waveforms must have the same units!
"""
# Pre-processing and checks
assert method in [
'welch',
'multitaper',
], 'Method must be either \'welch\' or \'multitaper\''
self.method = method
if method == 'welch':
self.win_dur = win_dur
else: # self.method == 'multitaper'
self.time_bandwidth_product = time_bandwidth_product
self.number_of_tapers = number_of_tapers
self.st = Stream(tr_or_st).copy() # Always use *copied* saul.Stream objects
assert self.st.count() > 0, 'No waveforms provided!'
# Handle data kind, units, and reference dB value
data_kind_st, inferred_units_st = zip(
*[get_waveform_units(tr) for tr in self.st]
)
data_kind_unique = list(set(data_kind_st))
msg = 'Input waveforms have mixed data kinds — not supported!'
assert len(data_kind_unique) == 1, msg # Do all waveforms have same data kind?
inferred_units_unique = list(set(inferred_units_st))
msg = 'Input waveforms have mixed units — not supported!'
assert len(inferred_units_unique) == 1, msg # Do all waveforms have same units?
self.data_kind = data_kind_unique[0]
self.db_ref_val = _get_db_reference_value(self.data_kind)
self.waveform_units = _validate_provided_vs_inferred_units(
units, inferred_units_unique[0], self.data_kind
)
if self.waveform_units is None:
self.db_ref_val = None # Reference value is meaningless w/o units!
# KEY: Calculate PSD (in dB relative to self.db_ref_val)
self.psd = []
for tr in self.st:
if method == 'welch':
fs = tr.stats.sampling_rate
nperseg = int(win_dur * fs) # Samples
nfft = np.power(2, int(np.ceil(np.log2(nperseg))) + 1) # Pad FFT
f, pxx = welch(tr.data, fs, nperseg=nperseg, nfft=nfft)
else: # method == 'multitaper'
mtspec = self._mtspec(
tuple(tr.data),
nw=time_bandwidth_product,
kspec=number_of_tapers, # After a certain point this saturates
dt=tr.stats.delta,
nfft=next_fast_len(tr.stats.npts),
)
f, pxx = mtspec.rspec()
f, pxx = f.squeeze(), pxx.squeeze()
f, pxx = f[1:], pxx[1:] # Remove DC component
self.psd.append((f, pxx))
# Convert to dB [dB rel. (db_ref_val <db_ref_val_unit>)^2 Hz^-1]
if self.db_ref_val is None:
denominator = max([pxx.max() for _, pxx in self.psd]) # 0 dB for global max
else:
denominator = self.db_ref_val**2
for i, (f, pxx) in enumerate(self.psd):
pxx_db = 10 * np.log10(pxx / denominator)
self.psd[i] = (f, pxx_db)
[docs]
def plot(
self,
db_lim='smart',
use_period=False,
log_x=True,
show_noise_models=False,
infra_noise_model='ak',
):
"""Plot the calculated PSDs.
Args:
db_lim (tuple, str, or None): Tuple defining min and max dB cutoffs,
``'smart'`` for a sensible automatic choice, or ``None`` for no clipping
use_period (bool): If ``True``, *x*-axis will be period [s] instead of
frequency [Hz]
log_x (bool): If ``True``, use log scaling for *x*-axis
show_noise_models (bool): Whether to plot reference noise models
infra_noise_model (str): Which infrasound noise model to use (only used if
``show_noise_models`` is ``True`` and ``self.data_kind`` is
``'infrasound'``), one of ``'ak'`` (Alaska noise model) or ``'idc'``
(IMS array noise model)
"""
assert not (use_period and not log_x), 'Cannot use period with linear x-scale!'
assert infra_noise_model in [
'ak',
'idc',
], 'Infrasound noise model must be either \'ak\' or \'idc\''
fig, ax = plt.subplots()
for tr, (f, pxx_db) in zip(self.st, self.psd):
ax.plot(1 / f if use_period else f, pxx_db, label=tr.id)
if log_x:
ax.set_xscale('log')
if show_noise_models:
if self.waveform_units is None:
msg = 'Can\'t show noise models if waveform units are unknown!'
raise ValueError(msg)
match self.data_kind:
case 'infrasound':
if infra_noise_model == 'ak':
period, *nms = get_ak_infra_noise()
noise_models = [(period, nm) for nm in nms]
else: # infra_noise_model == 'idc':
noise_models = [
get_idc_infra_low_noise(),
get_idc_infra_hi_noise(),
]
# These are all given relative to 1 Pa -> need to convert to ref_val
for i, noise_model in enumerate(noise_models):
period, pxx_db_rel_1_pa = noise_model
pxx_db_rel_ref_val = pxx_db_rel_1_pa - 10 * np.log10(
self.db_ref_val**2
)
noise_models[i] = period, pxx_db_rel_ref_val
case 'seismic':
noise_models = [get_nlnm(), get_nhnm()]
# These are in units of acceleration, so we might need to convert
# them; see Table 3 in Peterson (1993)
# https://pubs.usgs.gov/of/1993/0322/ofr93-322.pdf
match self.waveform_units:
case 'm':
for i, noise_model in enumerate(noise_models):
period, pxx_db_acc = noise_model
pxx_db_disp = pxx_db_acc + 20.0 * np.log10(
period**2 / (4 * np.pi**2)
)
noise_models[i] = period, pxx_db_disp
case 'm/s':
for i, noise_model in enumerate(noise_models):
period, pxx_db_acc = noise_model
pxx_db_vel = pxx_db_acc + 20.0 * np.log10(
period / (2 * np.pi)
)
noise_models[i] = period, pxx_db_vel
case 'm/s**2':
pass # No conversion needed
case _:
msg = (
f'Invalid seismic waveform units: {self.waveform_units}'
)
raise ValueError(msg)
case _:
raise ValueError(f'No noise models for data kind: {self.data_kind}')
xlim, ylim = ax.get_xlim(), ax.get_ylim() # Store these to restore limits
for i, noise_model in enumerate(noise_models):
period, pxx_db = noise_model
ax.plot(
period if use_period else 1 / period,
pxx_db,
color='tab:gray',
linestyle=':',
zorder=-5,
label='Noise model' if not i else None, # Only label one line
)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
legend = ax.legend(draggable=True)
# For every ID in the legend, use monospace font (ignore noise model label!)
for label in legend.get_texts()[: len(self.psd)]:
label.set_family('monospace')
if self.method == 'welch':
fmin = 1 / (self.win_dur / CYCLES_PER_WINDOW) # [Hz] Min. resolvable freq.
else: # self.method == 'multitaper'
fmin = np.min([f for f, _ in self.psd]) # [Hz] Show the full PSD... bad?
fmax = max([tr.stats.sampling_rate for tr in self.st]) / 2 # [Hz] Max. Nyquist
if use_period:
xlabel = 'Period (s)'
ax.set_xlim(1 / fmax, 1 / fmin) # Follow convention (increasing period)
else:
xlabel = 'Frequency (Hz)'
ax.set_xlim(fmin, fmax)
# Pick smart limits "ceiled" to nearest 10 dB
if db_lim == 'smart':
pxx_db_all = []
for _, pxx_db in self.psd:
pxx_db_all += pxx_db.tolist()
db_min = np.percentile(pxx_db_all, 5) # Percentile across all PSDs
db_max = np.max(pxx_db_all) # Max value across all PSDs
db_lim = np.ceil(db_min / 10) * 10, np.ceil(db_max / 10) * 10
ax.set_ylim(db_lim)
ax.set_xlabel(xlabel)
ax.set_ylabel(_format_power_label(self.db_ref_val, self.waveform_units))
ax.format_coord = lambda x, y: Formatter.fix_minus(
f'({_FREQ_TEMPLATE.format(*(1 / x, x) if use_period else (x, 1 / x))}, {y:.1f} dB)'
)
fig.tight_layout()
fig.show()
[docs]
def smooth(self, bandwidth):
"""Smooth the calculated PSDs via the Konno–Ohmachi method.
The Konno–Ohmachi method smooths PSDs using fixed-bandwith windows. The C code
used by this method is
`here <https://code.usgs.gov/ghsc/esi/esi-core/-/blob/main/src/esi_core/gmprocess/waveform_processing/smoothing/smoothing.c>`_.
The
`ObsPy documentation <https://docs.obspy.org/packages/autogen/obspy.signal.konnoohmachismoothing.konno_ohmachi_smoothing.html>`_
for a similar function may also be helpful.
For more information, see equation 4 in Konno and Ohmachi (1998) — the :math:`b`
in that equation is the ``bandwidth`` parameter here.
Konno, K., & Ohmachi, T. (1998). Ground-motion characteristics estimated
from spectral ratio between horizontal and vertical components of
microtremor. *Bulletin of the Seismological Society of America*, *88*\ (1),
228–241. https://doi.org/10.1785/BSSA0880010228
Note:
The smoothing is performed in-place on the existing spectra in this object!
Args:
bandwidth (int or float): Bandwidth for smoothing — lower values produce a
broader smoothing effect
"""
for f, pxx_db in self.psd:
konno_ohmachi_smooth(
spec=pxx_db,
freqs=f,
ko_freqs=f,
spec_smooth=pxx_db,
bandwidth=bandwidth,
)
return self
[docs]
def copy(self):
"""Return a deep copy of the :class:`PSD` object."""
return copy.deepcopy(self)
@staticmethod
@cache
def _mtspec(tr_data_tuple, **kwargs):
"""Wrapper around :class:`mtspec.MTSpec` to facilitate tuple input (needed for memoization).
Warning:
For large input arrays (many samples), conversion to tuple and then back to
:class:`numpy.ndarray` can be **slow**. In this case, memoization may not be
worth it.
"""
return mtspec.MTSpec(np.array(tr_data_tuple), **kwargs)