# Source code for iblatlas.streamlines.utils

import pandas as pd
import numpy as np
from iblutil.numerical import ismember
import matplotlib.pyplot as plt
from iblatlas.atlas import get_bc, BrainAtlas, aws, AllenAtlas
from iblatlas.regions import BrainRegions


def _download_depth_files(file_name):
    """
    Return the local path of a file from the atlas 'depths' cache, downloading it on first use.

    The file is expected at ``<BrainAtlas cache dir>/depths/<file_name>``. If it is not
    present locally it is fetched from the ``atlas/depths/`` prefix with
    ``aws.s3_download_file``.

    :param file_name: name of the file inside the depths folder (e.g. 'depths_ind_25.npy')
    :return: path object (as built by ``_get_cache_dir().joinpath``) of the local copy
    """
    local_path = BrainAtlas._get_cache_dir().joinpath('depths', file_name)
    if local_path.exists():
        return local_path

    # Make sure the cache folder hierarchy exists before the download writes into it
    local_path.parent.mkdir(exist_ok=True, parents=True)
    aws.s3_download_file(f'atlas/depths/{local_path.name}', local_path)
    return local_path


def xyz_to_depth(xyz, per=True, res_um=25):
    """
    Look up the depth below the cortical surface for a set of xyz coordinates.

    The depth is returned as a percentage if ``per=True`` and in um if ``per=False``.
    The lookup tables only cover the Isocortex of the Allen volume; any coordinate
    whose voxel is not in the table gets a depth of NaN.

    Parameters
    ----------
    xyz : numpy.array
        An (n, 3) array of Cartesian coordinates ordered ML, AP, DV, given in metres
        relative to bregma.
    per : bool
        If True return the depth as a percentage from the surface of the cortex,
        otherwise return the depth in um from the surface of the cortex.
    res_um : float or int
        Resolution of the brain atlas used for the depth lookup.

    Returns
    -------
    numpy.array
        float32 array with one depth per input coordinate; NaN where the coordinate
        does not lie within the Isocortex.
    """
    index_file = f'depths_ind_{res_um}.npy'
    if per:
        value_file = f'depths_per_{res_um}.npy'
    else:
        value_file = f'depths_um_{res_um}.npy'

    # Flat voxel indices covered by the table, and the depth value for each of them
    voxel_ids = np.load(_download_depth_files(index_file))
    depth_table = np.load(_download_depth_files(value_file))

    bc = get_bc(res_um=res_um)
    # NOTE(review): mode='clip' presumably clamps out-of-volume points onto the volume edge
    ixyz = bc.xyz2i(xyz, mode='clip')
    ix, iy, iz = ixyz[:, 0], ixyz[:, 1], ixyz[:, 2]
    # Index order (iy, ix, iz) pairs with the (ny, nx, nz) volume shape
    flat_idx = np.ravel_multi_index((iy, ix, iz), (bc.ny, bc.nx, bc.nz))

    # found: mask over flat_idx; loc: table positions of the matched entries only
    found, loc = ismember(flat_idx, voxel_ids)
    out = np.full(flat_idx.shape, np.nan, dtype=np.float32)
    out[found] = depth_table[loc]
    return out
def get_mask(volume='annotation', br=None):
    """
    Generate a 2D dorsal flatmap image to plot results onto.

    Parameters
    ----------
    volume : str, optional
        Which precomputed dorsal image to return:
        - 'image': the array stored in 'dorsal_image.npy', returned as loaded.
        - 'annotation': the label array stored in 'dorsal_annotation.npy', mapped to
          RGB colours through ``br.rgb``.
        - 'boundary': the label array stored in 'dorsal_annotation.npy', passed through
          ``AllenAtlas.compute_boundaries`` to extract region boundaries.
        Default is 'annotation'.
    br : BrainRegions, optional
        Brain regions used for the colour lookup of the 'annotation' volume. If None,
        a default BrainRegions is created, and only when it is actually needed.

    Returns
    -------
    img : np.ndarray
        The 2D flatmap image: the loaded image for 'image', an RGB image for
        'annotation', or the boundary image for 'boundary'.

    Raises
    ------
    ValueError
        If ``volume`` is not one of 'image', 'annotation' or 'boundary'.
    """
    if volume == 'image':
        return np.load(_download_depth_files('dorsal_image.npy'))

    if volume == 'annotation':
        # Only the annotation view needs the region colour table, so build it lazily
        if br is None:
            br = BrainRegions()
        labels = np.load(_download_depth_files('dorsal_annotation.npy'))
        return br.rgb[labels]

    if volume == 'boundary':
        labels = np.load(_download_depth_files('dorsal_annotation.npy'))
        return AllenAtlas.compute_boundaries(labels)

    # Previously an unknown volume fell through and raised UnboundLocalError on `img`
    raise ValueError(
        f"volume must be one of 'image', 'annotation' or 'boundary', got {volume!r}")
def validate_aggr(aggr: str) -> None:
    """
    Validate that the provided aggregation name is one of the supported ones.

    Parameters
    ----------
    aggr : str
        The aggregation method to validate (e.g., 'mean', 'sum').

    Raises
    ------
    AssertionError
        If the aggregation type is not one of 'sum', 'count', 'mean', 'std', 'median',
        'min', 'max', 'first', 'last'.
    """
    poss_aggrs = ['sum', 'count', 'mean', 'std', 'median', 'min', 'max', 'first', 'last']
    # Raise explicitly instead of using an `assert` statement so the check still runs
    # under `python -O`; AssertionError is kept so existing callers are unaffected.
    if aggr not in poss_aggrs:
        raise AssertionError(f"Aggregation must be one of {poss_aggrs}.")
def project_volume_onto_flatmap(vol: np.ndarray, res_um: int = 25, aggr: str = 'mean', plot: bool = True,
                                cmap: str = 'viridis', clevels: tuple = None, ax: plt.Axes = None) -> np.ndarray:
    """
    Project a 3D volume onto the 2D dorsal flatmap by aggregating values along streamline paths.

    Parameters
    ----------
    vol : np.ndarray
        3D volume to project; its shape must be (ny, nx, nz) of the atlas at ``res_um``.
    res_um : int
        Resolution of the volume. Must be one of 10, 25 or 50.
    aggr : str, optional
        Aggregation applied to the voxels of each path ('sum', 'count', 'mean', ...),
        default is 'mean'.
    plot : bool, optional
        Whether to plot the resulting projection, default is True.
    cmap : str, optional
        Colormap used for the plot, default is 'viridis'.
    clevels : tuple, optional
        (vmin, vmax) colour limits for the plot, default is None (data range).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created when None.

    Returns
    -------
    np.ndarray
        The projected 2D flatmap array.
    matplotlib.figure.Figure
        Figure object, only returned when plot=True.
    matplotlib.axes.Axes
        Axes object, only returned when plot=True.
    """
    bc = get_bc(res_um)
    assert vol.shape == (bc.ny, bc.nx, bc.nz), f"Volume does not have expected shape of {(bc.ny, bc.nx, bc.nz)}"
    validate_aggr(aggr)

    # One row per (path, voxel) pair; 'lookup' holds flat indices into `vol`
    streamlines = pd.read_parquet(_download_depth_files(f'paths_{res_um}.pqt'))
    streamlines['vals'] = vol.flat[streamlines['lookup'].values]

    # Reduce the sampled voxel values to a single value per path
    per_path = streamlines.groupby('paths')['vals'].agg(aggr)
    return _project_onto_flatmap(per_path, plot=plot, cmap=cmap, clevels=clevels, ax=ax)
def project_points_onto_flatmap(xyz: np.ndarray, values: np.ndarray, res_um: int = 25, aggr: str = 'mean',
                                plot: bool = True, cmap: str = 'viridis', clevels: tuple = None,
                                ax: plt.Axes = None) -> np.ndarray:
    """
    Project 3D points with associated values onto the 2D dorsal flatmap.

    Parameters
    ----------
    xyz : array-like
        (n, 3) xyz coordinates of the points to be projected, given in metres.
    values : array-like
        1D array of n values associated with the points.
    res_um : int
        Resolution used to index the atlas and to load the corresponding streamline paths.
    aggr : str, optional
        Aggregation applied per path ('sum', 'count', 'mean', ...), default is 'mean'.
    plot : bool, optional
        Whether to plot the resulting projection, default is True.
    cmap : str, optional
        Colormap used for the plot, default is 'viridis'.
    clevels : tuple, optional
        (vmin, vmax) colour limits for the plot, default is None (data range).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created when None.

    Returns
    -------
    np.ndarray
        The projected 2D flatmap array.
    matplotlib.figure.Figure
        Figure object, only returned when plot=True.
    matplotlib.axes.Axes
        Axes object, only returned when plot=True.

    Raises
    ------
    AssertionError
        If the number of rows of ``xyz`` differs from the number of ``values``, or if
        ``aggr`` is not a supported aggregation.
    """
    # Accept any array-like input (lists, tuples, pandas objects); no-op for ndarrays
    xyz = np.asarray(xyz)
    values = np.asarray(values)

    # Explicit raise (not `assert`) so the check survives `python -O`;
    # AssertionError is kept for backward compatibility.
    if xyz.shape[0] != values.size:
        raise AssertionError("xyz must have the same number of rows as values.")

    validate_aggr(aggr)

    bc = get_bc(res_um)
    ixyz = bc.xyz2i(xyz, mode='clip')

    # One row per point: its value and its flat voxel index in the (ny, nx, nz) volume
    val_df = pd.DataFrame()
    val_df['vals'] = values
    val_df['lookup'] = np.ravel_multi_index((ixyz[:, 1], ixyz[:, 0], ixyz[:, 2]), (bc.ny, bc.nx, bc.nz))
    # NOTE(review): with mode='clip' the indices presumably all lie inside the volume, so this
    # only drops points mapped to flat voxel 0; kept as-is to preserve existing behaviour.
    val_df = val_df[val_df['lookup'] != 0]

    path_df = pd.read_parquet(_download_depth_files(f'paths_{res_um}.pqt'))

    # Restrict both tables to the voxels they have in common before merging
    vals_in_paths, _ = ismember(val_df['lookup'].values, path_df['lookup'].values)
    val_df = val_df[vals_in_paths]
    paths_in_vals, _ = ismember(path_df['lookup'].values, val_df['lookup'].values)
    path_df = path_df[paths_in_vals]

    # Attach point values to every path voxel they fall in, then reduce per path
    flat_df = path_df.merge(val_df, on='lookup', how='left').groupby('paths').vals.agg(aggr)

    return _project_onto_flatmap(flat_df, plot=plot, cmap=cmap, clevels=clevels, ax=ax)
def _project_onto_flatmap(flat_df: pd.Series, plot: bool = True, cmap: str = 'viridis', clevels: tuple = None,
                          ax: plt.Axes = None) -> np.ndarray:
    """
    Paint per-path aggregated values onto the 2D dorsal flatmap.

    Parameters
    ----------
    flat_df : pd.Series
        Aggregated values indexed by path id; the ids are matched against the pixel
        values of the dorsal flatmap. Ids that do not occur in the flatmap are ignored.
    plot : bool, optional
        Whether to plot the resulting projection, default is True.
    cmap : str, optional
        Colormap used for the plot, default is 'viridis'.
    clevels : tuple, optional
        (vmin, vmax) colour limits for the plot. Default is None, which uses the
        nan-aware min/max of the projection (pixels without data are 0 and so count
        towards this range).
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created when None.

    Returns
    -------
    np.ndarray
        The projected 2D array, same shape as the flatmap; pixels without data are 0.
    plt.figure
        Figure object, only returned when plot=True.
    plt.axes
        Axes object, only returned when plot=True.
    """
    flatmap = np.load(_download_depth_files('dorsal_flatmap.npy'))

    # `ismember` returns a mask over flat_df.index plus flatmap positions for the matched
    # entries only, so the values must be masked as well; assigning all values would fail
    # with a length mismatch as soon as one path id is absent from the flatmap.
    in_map, pix = ismember(flat_df.index, flatmap.flat)

    proj = np.zeros(flatmap.size)
    proj[pix] = flat_df.to_numpy()[in_map]
    proj = proj.reshape(flatmap.shape)

    if not plot:
        return proj

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    if clevels is None:
        clevels = (np.nanmin(proj), np.nanmax(proj))
    ax.imshow(proj, cmap=cmap, vmin=clevels[0], vmax=clevels[1])
    return proj, fig, ax