"""Tools to generate and identify spacers.
Spacers are sequences of up and down pulses with a specific, identifiable pattern.
They are generated with a chirp coding to reduce cross-correlation sidelobes.
They are used to mark the beginning of a behaviour sequence within a session.
Example
-------
>>> spacer = Spacer()
>>> spacer.add_spacer_states(sma, next_state='first_state')
>>> for i in range(ntrials):
... sma.add_state(
... state_name='first_state',
... state_timer=tup,
... state_change_conditions={'Tup': f'spacer_low_{i:02d}'},
... output_actions=[('BNC1', 255)], # To FPGA
... )
"""
import numpy as np
from scipy import signal
class Spacer:
    """A chirp-coded train of pulses marking the beginning of a behaviour sequence.

    The intervals between consecutive pulse onsets sweep linearly from
    ``dt_start + tup`` to ``dt_end + tup`` and back again, which yields
    ``2 * n_pulses - 1`` pulses in total (see `times`).
    """

    def __init__(self, dt_start=0.02, dt_end=0.4, n_pulses=8, tup=0.05):
        """Store the spacer parameters and check that the pulses do not overlap.

        Parameters
        ----------
        dt_start : float
            First value of the linear sweep; `tup` is added to every sweep value to obtain
            the onset-to-onset interval.
        dt_end : float
            Last value of the linear sweep.
        n_pulses : int
            Number of spacer up times, one-sided (i.e. 8 means 16 - 1 spacers times)
        tup: float
            Duration of the spacer up time.

        Raises
        ------
        AssertionError
            If two consecutive up times are not separated by more than `tup`.
        """
        self.dt_start = dt_start
        self.dt_end = dt_end
        self.n_pulses = n_pulses
        self.tup = tup
        assert np.all(np.diff(self.times) > self.tup), 'Spacers are overlapping'

    def __repr__(self):
        return f'Spacer(dt_start={self.dt_start}, dt_end={self.dt_end}, n_pulses={self.n_pulses}, tup={self.tup})'

    @property
    def times(self):
        """Computes spacer up times using a chirp up and down pattern.

        Each time corresponds to an up time of the BNC1 signal.

        Returns
        -------
        numpy.array
            Cumulative up times of the ``2 * n_pulses - 1`` spacer pulses.
        """
        # upsweep: onset-to-onset intervals grow linearly
        intervals = np.linspace(self.dt_start, self.dt_end, self.n_pulses) + self.tup
        # downsweep: mirror the upsweep without repeating its first interval
        intervals = np.r_[intervals, np.flipud(intervals[1:])]
        return np.cumsum(intervals)

    def generate_template(self, fs=1000):
        """
        Generates a spacer voltage template to cross-correlate with a voltage trace from a DAQ
        in order to detect the spacer.

        Parameters
        ----------
        fs : int
            DAQ sampling frequency.

        Returns
        -------
        numpy.array
            The template spacer signal, padded with ``10 * tup`` of samples after the last
            up time.
        """
        t = self.times
        ns = int((t[-1] + self.tup * 10) * fs)
        sig = np.zeros(ns)
        # mark rising (+1) and falling (-1) fronts; the cumulative sum turns them into pulses
        sig[(t * fs).astype(np.int32)] = 1
        sig[((t + self.tup) * fs).astype(np.int32)] = -1
        return np.cumsum(sig)

    def add_spacer_states(self, sma=None, next_state='exit'):
        """
        Add spacer states to a state machine.

        For every spacer pulse a ``spacer_high_XX`` state (BNC1 output, lasting `tup`) and a
        ``spacer_low_XX`` state (no output) are added. The last low state transitions to
        `next_state`.

        Parameters
        ----------
        sma : pybpodapi.state_machine.StateMachine
            A Bpod state machine instance. If None, the pulse index, up time and interval
            are printed instead and no state is added.
        next_state : str
            The name of the state to follow the spacer state.

        Raises
        ------
        AssertionError
            If `next_state` is None.
        """
        assert next_state is not None
        t = self.times
        # the interval following the last pulse is set to twice the pulse duration
        dt = np.diff(t, append=t[-1] + self.tup * 2)
        n_states = len(t)
        for i, time in enumerate(t):
            if sma is None:
                print(i, time, dt[i])
                continue
            following = f'spacer_high_{i + 1:02d}' if i < n_states - 1 else next_state
            sma.add_state(
                state_name=f'spacer_high_{i:02d}',
                state_timer=self.tup,
                state_change_conditions={'Tup': f'spacer_low_{i:02d}'},
                output_actions=[('BNC1', 255)],  # To FPGA
            )
            sma.add_state(
                state_name=f'spacer_low_{i:02d}',
                state_timer=dt[i] - self.tup,
                state_change_conditions={'Tup': following},
                output_actions=[],
            )

    def find_spacers_from_fronts(self, fronts, fs=1000):
        """
        Given the timestamps and polarities of a digital signal, returns the timestamps of each
        spacer. This method first finds the locations where there are n consecutive pulses of the
        correct width then cross-correlates this part of the signal with the template signal.

        This method may be relaxed in order to make it robust to noise in the signal.

        Parameters
        ----------
        fronts : dict[str, numpy.array]
            Dictionary with keys ('times', 'polarities') containing the timestamps and polarities
            of the signal fronts, respectively.
        fs : int
            The sampling frequency of the DAQ signal.

        Returns
        -------
        numpy.array
            The times of the protocol spacer signals.
        """
        n_pulses = (self.n_pulses * 2) - 1
        # a front closes a spacer pulse when it follows the previous front by ~tup
        is_pulse = np.isclose(np.diff(fronts['times']), self.tup, rtol=1e-2)
        is_pulse = np.insert(is_pulse, 0, False)
        (ind,) = np.where(is_pulse)
        # Find consecutive pulses that are the correct length close together
        max_d = 1.0  # look for fronts less than 1 second apart
        consecutive = np.logical_and(np.diff(ind) == 2, np.diff(fronts['times'][ind]) < max_d)
        consecutive = np.pad(consecutive, 1, 'constant', constant_values=False)
        # breaks in the chain of consecutive pulses delimit the candidate pulse trains
        (edges,) = np.where(~consecutive)
        spacer_times = []
        for first, last in zip(edges[:-1], edges[1:]):
            if last - first != n_pulses:  # This could be relaxed to allow for noise
                continue
            idx = np.arange(ind[first], ind[last - 1] + 1)  # +1 to include final down
            t = fronts['times'][idx]
            ts = np.arange(t[0], t[-1], 1 / fs)  # Evenly resample at given frequency
            # Reconstruct trace where 1 = high, 0 = low
            trace = np.zeros_like(ts)
            ii = np.searchsorted(ts, t, side='left')
            in_range = ii < len(trace)  # drop fronts falling past the last resampled point
            trace[ii[in_range]] = fronts['polarities'][idx[in_range]]
            trace = np.cumsum(trace) + 1  # {-1, 0} -> {0, 1}
            try:
                (spacer,) = self.find_spacers(trace, fs=fs)
            except IndexError:
                # find_spacers raises IndexError when nothing crosses the detection threshold
                continue
            spacer_times.append(spacer + t[0])
        return np.array(spacer_times)

    def find_spacers(self, signal, threshold=0.9, fs=1000):
        """
        Find spacers in a voltage time series. Assumes that the signal is a digital signal between
        0 and 1.

        Parameters
        ----------
        signal : numpy.ndarray
            The signal in which to find the spacer.
        threshold : float
            The cross-correlation detection threshold.
        fs : int
            The sampling frequency of the DAQ signal.

        Returns
        -------
        numpy.ndarray
            An array containing the times of each spacer signal relative to the first sample.

        Raises
        ------
        IndexError
            If no sample of the normalized cross-correlation exceeds `threshold`.
        """
        template = self.generate_template(fs=fs)
        # normalize so that a perfect match of a 0/1 signal yields a correlation of 1
        xcor = np.correlate(signal, template, mode='full') / np.sum(template)
        idetect = np.where(xcor > threshold)[0]
        # label each run of contiguous supra-threshold samples with a 1-based group number
        iidetect = np.cumsum(np.diff(idetect, prepend=0) > 1)
        nspacers = iidetect[-1]
        tspacer = np.zeros(nspacers)
        for i in range(nspacers):
            ispacer = idetect[iidetect == i + 1]
            imax = np.argmax(xcor[ispacer])
            # convert the 'full' correlation index of the peak to a lag in seconds
            tspacer[i] = (ispacer[imax] - template.size + 1) / fs
        return tspacer

    def find_spacers_from_timestamps(self, timestamps: np.ndarray, atol: np.float64 = 1e-4) -> np.ndarray:
        """
        Finds spacers in a series of timestamps. Returns the indices of the first spacer front.

        Every window of ``2 * n_pulses - 1`` consecutive timestamps is aligned on its first
        element and compared to the spacer up times, including the windows that end on the
        very last timestamps.

        Parameters
        ----------
        timestamps : np.ndarray
            an array of the timestamps to check
        atol : np.float64, optional
            absolute tolerance for the squared sum of the residuals

        Returns
        -------
        np.ndarray
            an array of the indices of the first front of a spacer
        """
        template = self.times  # computed once: the property rebuilds the array on each access
        n_fronts = template.size
        # number of full-length windows; zero or negative when there are too few timestamps
        n_windows = timestamps.shape[0] - n_fronts + 1
        res = []
        for i in range(n_windows):
            window = timestamps[i : i + n_fronts]
            aligned = window - window[0] + template[0]
            res.append(np.sum((aligned - template) ** 2))  # squared sum of residuals
        res = np.array(res)
        return np.where(np.isclose(res, 0, atol=atol))[0]

    def find_spacers_from_positive_fronts(self, timestamps, fs=1000, prominence=4):
        """
        Finds spacers in timestamps that consist of only the positive fronts.

        Parameters
        ----------
        timestamps : np.ndarray
            the positive fronts
        fs : int, optional
            sampling frequency used for resampling, by default 1000
        prominence : int, optional
            prominence for peak detection after convolution, passed to signal.find_peaks, by default 4

        Returns
        -------
        np.ndarray
            for each detected spacer, the index into `timestamps` of its first front, or nan
            when no timestamp lies close enough to the inferred spacer time
        np.ndarray
            for each detected spacer, the matching timestamp, or the inferred time when the
            index is nan
        """

        def digitize(tstamps, fs):
            # local helper to create a 0/1 vector from timestamped data sampled at fs
            t = np.arange(tstamps[0], tstamps[-1], 1 / fs)
            y = np.zeros_like(t)
            y[np.digitize(tstamps, t) - 1] = 1
            return y, t

        y, t = digitize(timestamps, fs)
        y_spacer, _ = digitize(self.times, fs)
        # convolve to find spacers
        y_c = np.convolve(y, y_spacer, mode='same')
        peak_inds, _ = signal.find_peaks(y_c, prominence=prominence)
        # adjust start time by spacer width
        w = (self.times[-1] - self.times[0]) / 2
        spacer_times_ = t[peak_inds] - w - self.times[0]
        # convert spacer onset times to index into timestamps
        spacer_ix = []
        spacer_times = []
        for spacer_time in np.sort(spacer_times_):
            sq_dist = (timestamps - spacer_time) ** 2
            # NOTE(review): a squared distance is compared to 1 / fs here - confirm the intended tolerance
            if np.min(sq_dist) > 1 / fs:  # when first timestamp of spacer is missing
                spacer_ix.append(np.nan)  # the returned index is nan
                spacer_times.append(spacer_time)  # the returned time is the inferred time
            else:
                ix = np.argmin(sq_dist)
                spacer_ix.append(ix)  # otherwise the returned index is the index of the first spacer front
                spacer_times.append(timestamps[ix])  # and its corresponding timestamp
        return np.array(spacer_ix), np.array(spacer_times)