"""Prune branches from skeletons."""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
import numpy.typing as npt
# from skimage.morphology import binary_dilation, label
from skimage import morphology
from topostats.logs.logs import LOGGER_NAME
from topostats.tracing.skeletonize import getSkeleton
from topostats.tracing.tracingfuncs import coord_dist, genTracingFuncs, order_branch
from topostats.utils import convolve_skeleton
LOGGER = logging.getLogger(LOGGER_NAME)
[docs]
def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray:
    """
    Prune a skeleton using the pruning method named in ``kwargs``.

    This is a thin wrapper: after checking that the image and skeleton agree in shape it hands everything over to
    ``_prune_method`` which selects and runs the requested pruning method.

    Parameters
    ----------
    image : npt.NDArray
        Original image as 2D numpy array.
    skeleton : npt.NDArray
        Skeleton to be pruned, must have the same shape as ``image``.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    **kwargs
        Pruning options (including ``method``) passed to the respective method.

    Returns
    -------
    npt.NDArray
        An array of the skeleton with spurious branching artefacts removed.

    Raises
    ------
    AttributeError
        If ``image`` and ``skeleton`` do not have the same shape.
    """
    if image.shape == skeleton.shape:
        return _prune_method(image, skeleton, pixel_to_nm_scaling, **kwargs)
    raise AttributeError("Error image and skeleton are not the same size.")
[docs]
def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> Callable:
"""
Determine which skeletonize method to use.
Parameters
----------
image : npt.NDArray
Original image as 2D numpy array.
skeleton : npt.NDArray
Skeleton to be pruned.
pixel_to_nm_scaling : float
The pixel to nm scaling for pruning by length.
**kwargs
Pruning options passed to the respective method.
Returns
-------
Callable
Returns the function appropriate for the required skeletonizing method.
Raises
------
ValueError
Invalid method passed.
"""
method = kwargs.pop("method")
if method == "topostats":
return _prune_topostats(image, skeleton, pixel_to_nm_scaling, **kwargs)
# @maxgamill-sheffield I've read about a "Discrete Skeleton Evolultion" (DSE) method that might be useful
# @ns-rse (2024-06-04) : https://en.wikipedia.org/wiki/Discrete_skeleton_evolution
# https://link.springer.com/chapter/10.1007/978-3-540-74198-5_28
# https://dl.acm.org/doi/10.5555/1780074.1780108
# Python implementation : https://github.com/originlake/DSE-skeleton-pruning
raise ValueError(f"Invalid pruning method provided ({method}) please use one of 'topostats'.")
[docs]
def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray:
    """
    Prune using the original TopoStats method.

    This is a modified version of the published Zhang method; the work is delegated to ``topostatsPrune``.

    Parameters
    ----------
    img : npt.NDArray
        Image used to find skeleton, may be original heights or binary mask.
    skeleton : npt.NDArray
        Binary mask of the skeleton.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    **kwargs
        Pruning options passed to the topostatsPrune class.

    Returns
    -------
    npt.NDArray
        The skeleton with spurious branches removed.
    """
    pruner = topostatsPrune(img, skeleton, pixel_to_nm_scaling, **kwargs)
    pruned = pruner.prune_skeleton()
    return pruned
# class pruneSkeleton: pylint: disable=too-few-public-methods
# """
# Class containing skeletonization pruning code from factory methods to functions dependent on the method.
# Pruning is the act of removing spurious branches commonly found when implementing skeletonization algorithms.
# Parameters
# ----------
# image : npt.NDArray
# Original image from which the skeleton derives including heights.
# skeleton : npt.NDArray
# Single-pixel-thick skeleton pertaining to features of the image.
# """
# def __init__(self, image: npt.NDArray, skeleton: npt.NDArray) -> None:
# """
# Initialise the class.
# Parameters
# ----------
# image : npt.NDArray
# Original image from which the skeleton derives including heights.
# skeleton : npt.NDArray
# Single-pixel-thick skeleton pertaining to features of the image.
# """
# self.image = image
# self.skeleton = skeleton
# def prune_skeleton( pylint: disable=dangerous-default-value
# self,
# prune_args: dict = {"pruning_method": "topostats"}, noqa: B006
# ) -> npt.NDArray:
# """
# Pruning skeletons.
# This is a thin wrapper to the methods provided within the pruning classes below.
# Parameters
# ----------
# prune_args : dict
# Method to use, default is 'topostats'.
# Returns
# -------
# npt.NDArray
# An array of the skeleton with spurious branching artefacts removed.
# """
# return self._prune_method(prune_args)
# def _prune_method(self, prune_args: str = None) -> Callable:
# """
# Determine which skeletonize method to use.
# Parameters
# ----------
# prune_args : str
# Method to use for skeletonizing, methods are 'topostats' other options are 'conv'.
# Returns
# -------
# Callable
# Returns the function appropriate for the required skeletonizing method.
# Raises
# ------
# ValueError
# Invalid method passed.
# """
# method = prune_args.pop("pruning_method")
# if method == "topostats":
# return self._prune_topostats(self.image, self.skeleton, prune_args)
# I've read about a "Discrete Skeleton Evolution" (DSE) method that might be useful
# @ns-rse (2024-06-04) : Citation or link?
# raise ValueError(method)
# @staticmethod
# def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, prune_args: dict) -> npt.NDArray:
# """
# Prune using the original TopoStats method.
# This is a modified version of the published Zhang method.
# Parameters
# ----------
# img : npt.NDArray
# Image used to find skeleton, may be original heights or binary mask.
# skeleton : npt.NDArray
# Binary mask of the skeleton.
# prune_args : dict
# Dictionary of pruning arguments. ??? Needs expanding on what the valid arguments are.
# Returns
# -------
# npt.NDArray
# The skeleton with spurious branches removed.
# """
# return topostatsPrune(img, skeleton, **prune_args).prune_skeleton()
# Might be worth renaming this to reflect what it does which is prune by length and height
[docs]
class topostatsPrune:
    """
    Prune spurious skeletal branches based on their length and/or height.

    Contains all the functions used in the original TopoStats pruning code written by Joe Betton.

    Parameters
    ----------
    img : npt.NDArray
        Original image.
    skeleton : npt.NDArray
        Skeleton to be pruned.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    max_length : float
        Maximum length of the branch to prune in nanometres (nm).
    height_threshold : float
        Absolute height value to remove branches below in nanometres (nm).
    method_values : str
        Method for obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
        'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value).
    method_outlier : str
        Method for pruning branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the
        skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        img: npt.NDArray,
        skeleton: npt.NDArray,
        pixel_to_nm_scaling: float,
        max_length: float | None = None,
        height_threshold: float | None = None,
        method_values: str | None = None,
        method_outlier: str | None = None,
    ) -> None:
        """
        Initialise the class.

        Parameters
        ----------
        img : npt.NDArray
            Original image.
        skeleton : npt.NDArray
            Skeleton to be pruned.
        pixel_to_nm_scaling : float
            The pixel to nm scaling for pruning by length.
        max_length : float
            Maximum length of the branch to prune in nanometres (nm). Length pruning is skipped when ``None``.
        height_threshold : float
            Absolute height value to remove branches below in nanometres (nm). Height pruning is skipped when
            ``None``.
        method_values : str
            Method for obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
            'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value).
        method_outlier : str
            Method for pruning branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below
            the skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
        """
        self.img = img
        # Work on a copy so the caller's skeleton array is never modified.
        self.skeleton = skeleton.copy()
        self.pixel_to_nm_scaling = pixel_to_nm_scaling
        self.max_length = max_length
        self.height_threshold = height_threshold
        self.method_values = method_values
        self.method_outlier = method_outlier

    # Diverges from the change in layout to apply skeletonisation/pruning/tracing to individual grains and then process
    # all grains in an image (possibly in parallel).
    def prune_skeleton(self) -> npt.NDArray:
        """
        Prune each connected skeleton in the image by length and/or height.

        Each connected component of the skeleton is processed independently. If the class was initialised with both
        `max_length is not None` and `height_threshold is not None` then length based pruning is performed prior to
        height based pruning. Every pruned component is re-skeletonised with the "zhang" method (to remove nibs) and
        added to the output mask.

        Returns
        -------
        npt.NDArray
            A pruned skeleton as a ``uint8`` mask.
        """
        pruned_skeleton_mask = np.zeros_like(self.skeleton, dtype=np.uint8)
        labeled_skel = morphology.label(self.skeleton)
        for i in range(1, labeled_skel.max() + 1):
            single_skeleton = np.where(labeled_skel == i, 1, 0)
            if self.max_length is not None:
                LOGGER.debug(f": pruning.py : Pruning by length < {self.max_length}.")
                single_skeleton = self._prune_by_length(single_skeleton, max_length=self.max_length)
            if self.height_threshold is not None:
                LOGGER.debug(": pruning.py : Pruning by height.")
                single_skeleton = heightPruning(
                    self.img,
                    single_skeleton,
                    height_threshold=self.height_threshold,
                    method_values=self.method_values,
                    method_outlier=self.method_outlier,
                ).skeleton_pruned
            # Skeletonise to remove nibs. Adding the un-cast result in place was found (when writing tests) to raise
            # numpy's _UFuncOutputCastingError, hence the explicit cast to the mask's uint8 dtype.
            pruned_skeleton = getSkeleton(self.img, single_skeleton, method="zhang").get_skeleton()
            pruned_skeleton_mask += pruned_skeleton.astype(dtype=np.uint8)
        return pruned_skeleton_mask

    def _prune_by_length(  # pylint: disable=too-many-locals # noqa: C901
        self, single_skeleton: npt.NDArray, max_length: float
    ) -> npt.NDArray:
        """
        Remove hanging branches from a skeleton by their length.

        The skeleton is split into segments by removing junction pixels. In a single pass, every segment that
        contains an endpoint and is shorter than ``max_length`` is removed, then single pixel nibs are removed with
        ``rm_nibs``.

        Parameters
        ----------
        single_skeleton : npt.NDArray
            Binary array of the skeleton. Modified in place.
        max_length : float
            Maximum length of the branch to prune in nanometers (nm).

        Returns
        -------
        npt.NDArray
            Pruned skeleton as binary array.
        """
        # Convolved skeleton: 1 = skeleton pixel, 2 = endpoint, 3 = junction. Dropping junctions splits the skeleton
        # into its individual segments.
        conv_skeleton = convolve_skeleton(single_skeleton)
        conv_skeleton[conv_skeleton == 3] = 0
        labeled_segments = morphology.label(conv_skeleton.astype(bool))
        for segment_idx in range(1, labeled_segments.max() + 1):
            segment_mask = labeled_segments == segment_idx
            # Only hanging branches (containing an endpoint) are candidates, so skip measuring internal segments.
            # NOTE: a skeleton with no junctions is a single segment with endpoints and is removed entirely when
            # shorter than max_length.
            if not (conv_skeleton[segment_mask] == 2).any():
                continue
            ordered_coords = order_branch(np.where(segment_mask, 1, 0), [0, 0])
            segment_length = coord_dist(ordered_coords, self.pixel_to_nm_scaling)[-1]
            if segment_length < max_length:
                single_skeleton[segment_mask] = 0
        return rm_nibs(single_skeleton)

    @staticmethod
    def _find_branch_ends(coordinates: list) -> list:
        """
        Identify branch ends.

        This is achieved by iterating through the coordinates and assessing the local pixel area. Ends have only one
        adjacent pixel.

        Parameters
        ----------
        coordinates : list
            List of x, y coordinates of a branch.

        Returns
        -------
        list
            List of x, y coordinates of the branch ends.
        """
        # Most of the branch ends are just points with one neighbour
        return [
            [x, y] for x, y in coordinates if genTracingFuncs.count_and_get_neighbours(x, y, coordinates)[0] == 1
        ]
[docs]
class heightPruning:  # pylint: disable=too-many-instance-attributes
    """
    Pruning of branches based on height.

    Parameters
    ----------
    image : npt.NDArray
        Original image, typically the height data.
    skeleton : npt.NDArray
        Skeleton to prune branches from.
    max_length : float
        Maximum length of the branch to prune in nanometres (nm). Stored but not used by height pruning.
    height_threshold : float
        Absolute height value to remove branches below in nanometers (nm).
    method_values : str
        Method of obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
        'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value). Only used by the
        'abs' and 'mean_abs' outlier methods.
    method_outlier : str
        Method to prune branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the
        skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
    """  # numpydoc: ignore=PR01

    def __init__(
        self,
        image: npt.NDArray,
        skeleton: npt.NDArray,
        max_length: float | None = None,
        height_threshold: float | None = None,
        method_values: str | None = None,
        method_outlier: str | None = None,
    ) -> None:
        """
        Initialise the class and immediately perform height pruning.

        The pruned skeleton is available as ``self.skeleton_pruned``.

        Parameters
        ----------
        image : npt.NDArray
            Original image, typically the height data.
        skeleton : npt.NDArray
            Skeleton to prune branches from.
        max_length : float
            Maximum length of the branch to prune in nanometres (nm). Stored but not used by height pruning.
        height_threshold : float
            Absolute height value to remove branches below in nanometers (nm).
        method_values : str
            Method of obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
            'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value).
        method_outlier : str
            Method to prune branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the
            skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
        """
        self.image = image
        self.skeleton = skeleton
        self.skeleton_convolved = None
        self.skeleton_branches = None
        self.skeleton_branches_labelled = None
        self.max_length = max_length
        self.height_threshold = height_threshold
        self.method_values = method_values
        self.method_outlier = method_outlier
        self.convolve_skeleton()
        self.segment_skeleton()
        self.label_branches()
        self.skeleton_pruned = self.height_prune()

    def convolve_skeleton(self) -> None:
        """Convolve skeleton."""
        self.skeleton_convolved = convolve_skeleton(self.skeleton)

    def segment_skeleton(self) -> None:
        """Break the convolved skeleton into segments by removing nodes/junctions (value 3)."""
        self.skeleton_branches = np.where(self.skeleton_convolved == 3, 0, self.skeleton)

    def label_branches(self) -> None:
        """Label segmented branches."""
        self.skeleton_branches_labelled = morphology.label(self.skeleton_branches)

    def _get_branch_mins(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Collect the minimum height value of each individually labeled branch.

        Parameters
        ----------
        segments : npt.NDArray
            Integer labeled array matching the dimensions of the image.

        Returns
        -------
        npt.NDArray
            Array of minimum values of each branch, position ``i`` holding label ``i + 1``.
        """
        return np.array([np.min(self.image[segments == i]) for i in range(1, segments.max() + 1)])

    def _get_branch_medians(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Collect the median height value of each individually labeled branch.

        Parameters
        ----------
        segments : npt.NDArray
            Integer labeled array matching the dimensions of the image.

        Returns
        -------
        npt.NDArray
            Array of median values of each branch, position ``i`` holding label ``i + 1``.
        """
        return np.array([np.median(self.image[segments == i]) for i in range(1, segments.max() + 1)])

    def _get_branch_middles(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Collect the positionally ordered middle height value of each labeled branch.

        Where the branch has an even amount of points, average the two middle heights.

        Parameters
        ----------
        segments : npt.NDArray
            Integer labeled array matching the dimensions of the image.

        Returns
        -------
        npt.NDArray
            Array of middle values of each branch, position ``i`` holding label ``i + 1``.
        """
        branch_middles = np.zeros(segments.max())
        for i in range(1, segments.max() + 1):
            segment = np.where(segments == i, 1, 0)
            if segment.sum() > 2:
                ends = np.argwhere(convolve_skeleton(segment) == 2)
                # An endpoint is not always found (e.g. closed loops); fall back to any pixel of the segment rather
                # than failing with an IndexError.
                start = ends[0] if len(ends) else np.argwhere(segment == 1)[0]
                ordered_coords = order_branch_from_end(segment, start)
                # if even no. points, average two middles
                middle_idx, middle_remainder = (len(ordered_coords) + 1) // 2 - 1, (len(ordered_coords) + 1) % 2
                mid_coord = ordered_coords[[middle_idx, middle_idx + middle_remainder]]
                height = self.image[mid_coord[:, 0], mid_coord[:, 1]].mean()
            else:
                # one or two points, average them
                height = self.image[segment == 1].mean()
            branch_middles[i - 1] = height
        return branch_middles

    @staticmethod
    def _get_abs_thresh_idx(height_values: npt.NDArray, threshold: float | int) -> npt.NDArray:
        """
        Identify indices of labelled branches whose height values are less than a given threshold.

        Parameters
        ----------
        height_values : npt.NDArray
            Array of each branches heights.
        threshold : float | int
            Threshold for heights.

        Returns
        -------
        npt.NDArray
            Branch labels (1-based) which are less than threshold.
        """
        return np.flatnonzero(np.asarray(height_values) < threshold) + 1

    @staticmethod
    def _get_mean_abs_thresh_idx(
        height_values: npt.NDArray, threshold: float | int, image: npt.NDArray, skeleton: npt.NDArray
    ) -> npt.NDArray:
        """
        Identify indices of labelled branch whose height values are less than mean skeleton height - absolute threshold.

        For DNA a threshold of 0.85nm (the depth of the major groove) would ideally remove all segments whose lowest
        point is < mean(height) - 0.85nm, i.e. 1.15nm.

        Parameters
        ----------
        height_values : npt.NDArray
            Array of branches heights.
        threshold : float | int
            Threshold to be subtracted from mean heights.
        image : npt.NDArray
            Original image of heights.
        skeleton : npt.NDArray
            Binary array of skeleton used to identify heights from original image to use.

        Returns
        -------
        npt.NDArray
            Branch labels (1-based) which are less than mean(height) - threshold.
        """
        avg = image[skeleton == 1].mean()
        return np.flatnonzero(np.asarray(height_values) < (avg - threshold)) + 1

    @staticmethod
    def _get_iqr_thresh_idx(image: npt.NDArray, segments: npt.NDArray) -> npt.NDArray:
        """
        Identify labelled branches containing a pixel below the lower quartile - 1.5 x inter-quartile range.

        The quartiles are computed over the heights of every pixel in ``segments``.

        Parameters
        ----------
        image : npt.NDArray
            Original image with heights.
        segments : npt.NDArray
            Integer labelled array of skeleton branches.

        Returns
        -------
        npt.NDArray
            Branch labels containing a low outlier pixel, sorted by ascending branch minimum height. Empty when there
            are no segments or no outliers.
        """
        coords = np.argwhere(segments != 0)
        if len(coords) == 0:
            # np.percentile cannot be computed on an empty selection.
            return np.array([], dtype=int)
        # Use every pixel rather than per-branch summaries so the distribution is representative.
        heights = image[coords[:, 0], coords[:, 1]]
        q75, q25 = np.percentile(heights, [75, 25])
        threshold = q25 - 1.5 * (q75 - q25)
        low_coords = coords[heights < threshold]
        # Read each low pixel's branch label directly; matching coordinates by value membership would also match
        # transposed coordinates belonging to other branches.
        low_segment_idxs = np.unique(segments[low_coords[:, 0], low_coords[:, 1]])
        low_segment_mins = np.array([image[segments == idx].min() for idx in low_segment_idxs])
        return low_segment_idxs[np.argsort(low_segment_mins)]  # sort in order of ascending mins

    @staticmethod
    def check_skeleton_one_object(skeleton: npt.NDArray) -> bool:
        """
        Ensure that the skeleton hasn't been broken up upon removing a segment.

        Parameters
        ----------
        skeleton : npt.NDArray
            2D single pixel thick array.

        Returns
        -------
        bool
            True if the skeleton consists of exactly one object, otherwise False.
        """
        skeleton = np.where(skeleton != 0, 1, 0)
        return morphology.label(skeleton).max() == 1

    def _get_height_values(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Obtain one height value per labelled branch using ``self.method_values``.

        Parameters
        ----------
        segments : npt.NDArray
            A labelled 2D array of skeleton segments.

        Returns
        -------
        npt.NDArray
            Height value of each branch, position ``i`` holding label ``i + 1``.

        Raises
        ------
        ValueError
            If ``method_values`` is not one of 'min', 'median' or 'mid'.
        """
        if self.method_values == "min":
            return self._get_branch_mins(segments)
        if self.method_values == "median":
            return self._get_branch_medians(segments)
        if self.method_values == "mid":
            return self._get_branch_middles(segments)
        raise ValueError(f"Invalid method_values ({self.method_values}) please use one of 'min', 'median' or 'mid'.")

    def filter_segments(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Identify and remove segments of a skeleton based on the underlying image height.

        Parameters
        ----------
        segments : npt.NDArray
            A labelled 2D array of skeleton segments.

        Returns
        -------
        npt.NDArray
            The original skeleton without the segments identified by the height criteria.

        Raises
        ------
        ValueError
            If ``method_outlier`` is not recognised or, for 'abs'/'mean_abs', ``method_values`` is not recognised.
        """
        # threshold heights to obtain indexes of branches to be removed
        if self.method_outlier == "abs":
            idxs = self._get_abs_thresh_idx(self._get_height_values(segments), self.height_threshold)
        elif self.method_outlier == "mean_abs":
            idxs = self._get_mean_abs_thresh_idx(
                self._get_height_values(segments), self.height_threshold, self.image, self.skeleton
            )
        elif self.method_outlier == "iqr":
            # IQR works on pixel heights directly so per-branch height values are not required.
            idxs = self._get_iqr_thresh_idx(self.image, segments)
        else:
            raise ValueError(
                f"Invalid method_outlier ({self.method_outlier}) please use one of 'abs', 'mean_abs' or 'iqr'."
            )
        # Only remove a branch if the skeleton remains a single object without it; each branch is tested against the
        # original skeleton independently.
        skeleton_rtn = self.skeleton.copy()
        for i in idxs:
            temp_skel = self.skeleton.copy()
            temp_skel[segments == i] = 0
            if self.check_skeleton_one_object(temp_skel):
                skeleton_rtn[segments == i] = 0
        return skeleton_rtn

    # def remove_bridges(self) -> npt.NDArray:
    #     """
    #     Identify and remove skeleton bridges using the underlying image height.
    #     Bridges cross the skeleton in places they shouldn't and are defined as an internal branch and thus have no
    #     endpoints. They occur due to poor thresholding creating holes in the mask, creating false "bridges" which
    #     misrepresent the skeleton of the molecule.
    #     Returns
    #     -------
    #     npt.NDArray
    #         A skeleton with internal branches removed by height.
    #     """
    #     conv = convolve_skeleton(self.skeleton)
    #     # Split the skeleton into branches by removing junctions/nodes and label
    #     nodeless = np.where(conv == 3, 0, conv)
    #     segments = morphology.label(np.where(nodeless != 0, 1, 0))
    #     # bridges should not concern endpoints so remove these
    #     for i in range(1, segments.max() + 1):
    #         if (conv[segments == i] == 2).any():
    #             segments[segments == i] = 0
    #     segments = morphology.label(np.where(segments != 0, 1, 0))
    #     # filter the segments based on height criteria
    #     return self.filter_segments(segments)

    def height_prune(self) -> npt.NDArray:
        """
        Identify and remove spurious branches (containing endpoints) using the underlying image height.

        Returns
        -------
        npt.NDArray
            A skeleton with outer branches removed by height.
        """
        conv = convolve_skeleton(self.skeleton)
        segments = self._split_skeleton(conv)
        # height pruning should only concern endpoints so remove internal connections
        for i in range(1, segments.max() + 1):
            if not (conv[segments == i] == 2).any():
                segments[segments == i] = 0
        segments = morphology.label(np.where(segments != 0, 1, 0))
        # filter the segments based on height criteria
        return self.filter_segments(segments)

    @staticmethod
    def _split_skeleton(skeleton: npt.NDArray) -> npt.NDArray:
        """
        Split the skeleton into branches by removing junctions/nodes and label branches.

        Parameters
        ----------
        skeleton : npt.NDArray
            Convolved skeleton to be split. This should have nodes labelled as 3, ends as 2 and all other points as 1.

        Returns
        -------
        npt.NDArray
            Removes the junctions (3) and returns all remaining sections as labelled segments.
        """
        nodeless = np.where(skeleton == 3, 0, skeleton)
        return morphology.label(np.where(nodeless != 0, 1, 0))
[docs]
def order_branch_from_end(nodeless: npt.NDArray, start: list, max_length: float = np.inf) -> npt.NDArray:
    """
    Take a linear branch and order its coordinates starting from a specific endpoint.

    Starting at ``start`` the branch is walked pixel by pixel, always moving to the first remaining 8-connected
    neighbour (in row-major order). Each visited pixel is set to 0 in ``nodeless``, i.e. the input array is modified
    in place. Orthogonal steps count as 1 and diagonal steps as sqrt(2); a step is only taken while the cumulative
    distance including that step does not exceed ``max_length``.

    NB - It may be possible to use np.lexsort() to order points, see topostats.measure.feret.sort_coords() for an
    example of how to sort by row or column coordinates, which end of the branch this is from probably doesn't matter
    as one only wants to find the mid-point I think.

    Parameters
    ----------
    nodeless : npt.NDArray
        A 2D binary array where there are no crossing pixels. Modified in place.
    start : list
        A coordinate to start closest to / at.
    max_length : float, optional
        The maximum length to order along the branch, in pixels, by default np.inf.

    Returns
    -------
    npt.NDArray
        The input linear branch ordered from the start coordinate.
    """

    def _next_offsets(point) -> npt.NDArray:
        # Offsets, relative to ``point``, of the neighbouring pixels still set to 1.
        area, _ = local_area_sum(nodeless, point)
        return np.argwhere(area.reshape((3, 3)) == 1) - (1, 1)

    def _step_length(offsets: npt.NDArray) -> float:
        # Length of the step to the neighbour that will actually be taken (the first offset): sqrt(2) if diagonal.
        if len(offsets) == 0:
            return 1
        return np.sqrt(2) if np.abs(offsets[0]).sum() > 1 else 1

    ordered = [start]
    nodeless[start[0], start[1]] = 0  # remove from array
    current_point = start
    local_next_point = _next_offsets(current_point)
    dist = _step_length(local_next_point)
    while len(local_next_point) != 0 and dist <= max_length:
        next_point = (current_point + local_next_point)[0]
        ordered.append(next_point)
        nodeless[next_point[0], next_point[1]] = 0  # mark as visited
        current_point = next_point
        local_next_point = _next_offsets(current_point)
        dist += _step_length(local_next_point)
    return np.array(ordered)
[docs]
def rm_nibs(skeleton):  # pylint: disable=too-many-locals
    """
    Remove single pixel branches (nibs) not identified by nearest neighbour algorithms as there may be >2 neighbours.

    The skeleton is split into junction clusters (convolved value 3) and non-junction segments (values 1 and 2). A
    segment lying entirely within the one pixel ring surrounding a single junction cluster is removed. A segment
    that lies within the rings of more than one cluster is kept.

    Parameters
    ----------
    skeleton : npt.NDArray
        A single pixel thick trace. Modified in place.

    Returns
    -------
    npt.NDArray
        A skeleton with single pixel nibs removed.
    """
    convolved = convolve_skeleton(skeleton)
    node_labels = morphology.label(np.where(convolved == 3, 1, 0))
    branch_labels = morphology.label(np.where((convolved == 1) | (convolved == 2), 1, 0))

    enclosed_branches = []
    for node_id in range(1, node_labels.max() + 1):
        node_mask = np.where(node_labels == node_id, 1, 0)
        dilated = morphology.binary_dilation(node_mask, footprint=np.ones((3, 3)))
        # Pixels added by the dilation, i.e. the ring around this junction cluster.
        ring = np.where(dilated != node_mask, 1, 0)
        touching = branch_labels[ring == 1]
        touching = touching[touching != 0]
        for branch_id in np.unique(touching):
            # The whole branch sits inside the ring around this cluster.
            if (branch_labels == branch_id).sum() == (touching == branch_id).sum():
                enclosed_branches.append(branch_id)

    branch_ids, occurrences = np.unique(np.array(enclosed_branches), return_counts=True)
    for branch_id, occurrence in zip(branch_ids, occurrences):
        if occurrence == 1:
            skeleton[branch_labels == branch_id] = 0
    return skeleton
[docs]
def local_area_sum(img: npt.NDArray, point: list | tuple | npt.NDArray) -> tuple:
    """
    Evaluate the local area around a point in a binary map.

    Parameters
    ----------
    img : npt.NDArray
        Binary 2D array of image.
    point : list | tuple | npt.NDArray
        Coordinates (row, column) of a point within the binary map. The point must not lie on any edge of the array
        as its full 3x3 neighbourhood is required.

    Returns
    -------
    tuple
        Tuple consisting of the flattened (row-major) 3x3 neighbourhood around the point with the centre set to 0,
        and the number of neighbours around the point.

    Raises
    ------
    ValueError
        If the point value is greater than 1 or the neighbourhood sum exceeds 8 (i.e. the map is not binary).
    IndexError
        If the point lies on an edge of the array.
    """
    row, col = point[0], point[1]
    if img[row, col] > 1:
        raise ValueError("binary_map is not binary!")
    n_rows, n_cols = img.shape[0], img.shape[1]
    # Slicing never raises for out-of-range bounds (negative starts wrap around, stops are clipped), so edge points
    # must be detected explicitly before taking the 3x3 window.
    if row < 1 or col < 1:
        raise IndexError("Point can not be on the edge of an array.")
    if row > n_rows - 2 or col > n_cols - 2:
        raise IndexError("'point' is on right or bottom edge of 'binary_map'")
    # flatten() returns a copy so zeroing the centre does not modify ``img``.
    local_pixels = img[row - 1 : row + 2, col - 1 : col + 2].flatten()
    local_pixels[4] = 0  # ensure centre is 0
    neighbours = local_pixels.sum()
    if neighbours <= 8:
        return local_pixels, neighbours
    raise ValueError("'binary_map' is not binary!")