Source code for topostats.tracing.pruning

"""Prune branches from skeletons."""

from __future__ import annotations

import logging
from collections.abc import Callable

import numpy as np
import numpy.typing as npt

# from skimage.morphology import binary_dilation, label
from skimage import morphology

from topostats.logs.logs import LOGGER_NAME
from topostats.tracing.skeletonize import getSkeleton
from topostats.tracing.tracingfuncs import coord_dist, genTracingFuncs, order_branch
from topostats.utils import convolve_skeleton

LOGGER = logging.getLogger(LOGGER_NAME)


def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray:
    """
    Prune spurious branches from a skeleton using the method named in ``kwargs``.

    The inputs are validated here and then handed to the method dispatcher, which forwards them to the pruning
    implementations defined in this module.

    Parameters
    ----------
    image : npt.NDArray
        Original image as 2D numpy array.
    skeleton : npt.NDArray
        Skeleton to be pruned; must have the same shape as ``image``.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    **kwargs
        Pruning options passed to the respective method (must include ``method``).

    Returns
    -------
    npt.NDArray
        An array of the skeleton with spurious branching artefacts removed.

    Raises
    ------
    AttributeError
        If ``image`` and ``skeleton`` do not have the same shape.
    """
    same_shape = image.shape == skeleton.shape
    if not same_shape:
        raise AttributeError("Error image and skeleton are not the same size.")
    return _prune_method(image, skeleton, pixel_to_nm_scaling, **kwargs)
def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray:
    """
    Dispatch to the requested pruning method and return the pruned skeleton.

    Parameters
    ----------
    image : npt.NDArray
        Original image as 2D numpy array.
    skeleton : npt.NDArray
        Skeleton to be pruned.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    **kwargs
        Pruning options. ``method`` selects the pruning method (currently only 'topostats'); all remaining options
        are forwarded to that method.

    Returns
    -------
    npt.NDArray
        The skeleton with spurious branches removed by the selected method.

    Raises
    ------
    ValueError
        If ``method`` is missing or is not a supported pruning method.
    """
    # A missing method is treated as an invalid one so callers get the documented ValueError, not a bare KeyError.
    method = kwargs.pop("method", None)
    if method == "topostats":
        return _prune_topostats(image, skeleton, pixel_to_nm_scaling, **kwargs)
    # @maxgamill-sheffield I've read about a "Discrete Skeleton Evolution" (DSE) method that might be useful
    # @ns-rse (2024-06-04) : https://en.wikipedia.org/wiki/Discrete_skeleton_evolution
    # https://link.springer.com/chapter/10.1007/978-3-540-74198-5_28
    # https://dl.acm.org/doi/10.5555/1780074.1780108
    # Python implementation : https://github.com/originlake/DSE-skeleton-pruning
    raise ValueError(f"Invalid pruning method provided ({method}) please use one of 'topostats'.")
def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray:
    """
    Prune with the original TopoStats approach, a modified version of the published Zhang method.

    Parameters
    ----------
    img : npt.NDArray
        Image used to find the skeleton; may be the original heights or a binary mask.
    skeleton : npt.NDArray
        Binary mask of the skeleton.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    **kwargs
        Pruning options forwarded unchanged to the ``topostatsPrune`` constructor.

    Returns
    -------
    npt.NDArray
        The skeleton with spurious branches removed.
    """
    pruner = topostatsPrune(img, skeleton, pixel_to_nm_scaling, **kwargs)
    return pruner.prune_skeleton()
# class pruneSkeleton: pylint: disable=too-few-public-methods # """ # Class containing skeletonization pruning code from factory methods to functions dependent on the method. # Pruning is the act of removing spurious branches commonly found when implementing skeletonization algorithms. # Parameters # ---------- # image : npt.NDArray # Original image from which the skeleton derives including heights. # skeleton : npt.NDArray # Single-pixel-thick skeleton pertaining to features of the image. # """ # def __init__(self, image: npt.NDArray, skeleton: npt.NDArray) -> None: # """ # Initialise the class. # Parameters # ---------- # image : npt.NDArray # Original image from which the skeleton derives including heights. # skeleton : npt.NDArray # Single-pixel-thick skeleton pertaining to features of the image. # """ # self.image = image # self.skeleton = skeleton # def prune_skeleton( pylint: disable=dangerous-default-value # self, # prune_args: dict = {"pruning_method": "topostats"}, noqa: B006 # ) -> npt.NDArray: # """ # Pruning skeletons. # This is a thin wrapper to the methods provided within the pruning classes below. # Parameters # ---------- # prune_args : dict # Method to use, default is 'topostats'. # Returns # ------- # npt.NDArray # An array of the skeleton with spurious branching artefacts removed. # """ # return self._prune_method(prune_args) # def _prune_method(self, prune_args: str = None) -> Callable: # """ # Determine which skeletonize method to use. # Parameters # ---------- # prune_args : str # Method to use for skeletonizing, methods are 'topostats' other options are 'conv'. # Returns # ------- # Callable # Returns the function appropriate for the required skeletonizing method. # Raises # ------ # ValueError # Invalid method passed. 
# """ # method = prune_args.pop("pruning_method") # if method == "topostats": # return self._prune_topostats(self.image, self.skeleton, prune_args) # I've read about a "Discrete Skeleton Evolultion" (DSE) method that might be useful # @ns-rse (2024-06-04) : Citation or link? # raise ValueError(method) # @staticmethod # def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, prune_args: dict) -> npt.NDArray: # """ # Prune using the original TopoStats method. # This is a modified version of the pubhlished Zhang method. # Parameters # ---------- # img : npt.NDArray # Image used to find skeleton, may be original heights or binary mask. # skeleton : npt.NDArray # Binary mask of the skeleton. # prune_args : dict # Dictionary of pruning arguments. ??? Needs expanding on what the valid arguments are. # Returns # ------- # npt.NDArray # The skeleton with spurious branches removed. # """ # return topostatsPrune(img, skeleton, **prune_args).prune_skeleton() # Might be worth renaming this to reflect what it does which is prune by length and height
class topostatsPrune:
    """
    Prune spurious skeletal branches based on their length and/or height.

    Contains all the functions used in the original TopoStats pruning code written by Joe Betton.

    Parameters
    ----------
    img : npt.NDArray
        Original image.
    skeleton : npt.NDArray
        Skeleton to be pruned.
    pixel_to_nm_scaling : float
        The pixel to nm scaling for pruning by length.
    max_length : float
        Maximum length of the branch to prune in nanometres (nm).
    height_threshold : float
        Absolute height value to remove branches below in nanometres (nm).
    method_values : str
        Method for obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
        'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value).
    method_outlier : str
        Method for pruning branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the
        skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        img: npt.NDArray,
        skeleton: npt.NDArray,
        pixel_to_nm_scaling: float,
        max_length: float = None,
        height_threshold: float = None,
        method_values: str = None,
        method_outlier: str = None,
    ) -> None:
        """
        Initialise the class.

        Parameters
        ----------
        img : npt.NDArray
            Original image.
        skeleton : npt.NDArray
            Skeleton to be pruned. A copy is stored so the caller's array is never modified.
        pixel_to_nm_scaling : float
            The pixel to nm scaling for pruning by length.
        max_length : float
            Maximum length of the branch to prune in nanometres (nm). ``None`` disables length pruning.
        height_threshold : float
            Absolute height value to remove branches below in nanometres (nm). ``None`` disables height pruning.
        method_values : str
            Method for obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
            'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value).
        method_outlier : str
            Method for pruning branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below
            the skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
        """
        self.img = img
        self.skeleton = skeleton.copy()
        self.pixel_to_nm_scaling = pixel_to_nm_scaling
        self.max_length = max_length
        self.height_threshold = height_threshold
        self.method_values = method_values
        self.method_outlier = method_outlier

    # Diverges from the change in layout to apply skeletonisation/pruning/tracing to individual grains and then process
    # all grains in an image (possibly in parallel).
    def prune_skeleton(self) -> npt.NDArray:
        """
        Prune skeleton by length and/or height.

        Each connected component of the skeleton is pruned independently. If the class was initialised with both
        ``max_length is not None`` and ``height_threshold is not None`` then length based pruning is performed prior
        to height based pruning. Each pruned component is re-skeletonised (Zhang) to remove nibs left behind.

        Returns
        -------
        npt.NDArray
            A pruned skeleton as a ``uint8`` array.
        """
        pruned_skeleton_mask = np.zeros_like(self.skeleton, dtype=np.uint8)
        labeled_skel = morphology.label(self.skeleton)
        for i in range(1, labeled_skel.max() + 1):
            single_skeleton = np.where(labeled_skel == i, 1, 0)
            if self.max_length is not None:
                LOGGER.debug(f": pruning.py : Pruning by length < {self.max_length}.")
                single_skeleton = self._prune_by_length(single_skeleton, max_length=self.max_length)
            if self.height_threshold is not None:
                LOGGER.debug(": pruning.py : Pruning by height.")
                single_skeleton = heightPruning(
                    self.img,
                    single_skeleton,
                    height_threshold=self.height_threshold,
                    method_values=self.method_values,
                    method_outlier=self.method_outlier,
                ).skeleton_pruned
            # Skeletonise to remove nibs. The result is cast explicitly because adding the skeleton's own dtype in
            # place onto the uint8 mask raised a ufunc casting error (found while writing tests).
            pruned_skeleton = getSkeleton(self.img, single_skeleton, method="zhang").get_skeleton()
            pruned_skeleton_mask += pruned_skeleton.astype(dtype=np.uint8)
        return pruned_skeleton_mask

    def _prune_by_length(  # pylint: disable=too-many-locals # noqa: C901
        self, single_skeleton: npt.NDArray, max_length: float
    ) -> npt.NDArray:
        """
        Remove hanging branches from a skeleton by their length.

        The skeleton is split into segments at junctions in a single pass. Every segment that contains an endpoint
        (a hanging branch) and is shorter than ``max_length`` is removed, then single pixel nibs are cleaned up.

        Parameters
        ----------
        single_skeleton : npt.NDArray
            Binary array of the skeleton. It is modified in place.
        max_length : float
            Maximum length of the branch to prune in nanometers (nm).

        Returns
        -------
        npt.NDArray
            Pruned skeleton as binary array.
        """
        # get segments via convolution and removing junctions (value 3)
        conv_skeleton = convolve_skeleton(single_skeleton)
        conv_skeleton[conv_skeleton == 3] = 0
        labeled_segments = morphology.label(conv_skeleton.astype(bool))
        for segment_idx in range(1, labeled_segments.max() + 1):
            # single segment keeping convolved values so endpoints are marked with 2
            segment = np.where(labeled_segments == segment_idx, conv_skeleton, 0)
            # Only hanging branches (containing an endpoint) may be pruned, so skip measuring internal segments.
            if not (segment == 2).any():
                continue
            ordered_coords = order_branch(np.where(segment != 0, 1, 0), [0, 0])
            segment_length = coord_dist(ordered_coords, self.pixel_to_nm_scaling)[-1]
            if segment_length < max_length:
                single_skeleton[labeled_segments == segment_idx] = 0
        return rm_nibs(single_skeleton)

    @staticmethod
    def _find_branch_ends(coordinates: list) -> list:
        """
        Identify branch ends.

        This is achieved by iterating through the coordinates and assessing the local pixel area. Ends have only one
        adjacent pixel.

        Parameters
        ----------
        coordinates : list
            List of x, y coordinates of a branch.

        Returns
        -------
        list
            List of x, y coordinates of the branch ends.
        """
        # Most of the branch ends are just points with one neighbour
        return [[x, y] for x, y in coordinates if genTracingFuncs.count_and_get_neighbours(x, y, coordinates)[0] == 1]
class heightPruning:  # pylint: disable=too-many-instance-attributes
    """
    Pruning of branches based on height.

    Parameters
    ----------
    image : npt.NDArray
        Original image, typically the height data.
    skeleton : npt.NDArray
        Skeleton to prune branches from.
    max_length : float
        Maximum length of the branch to prune in nanometres (nm). Stored only; not used by height pruning.
    height_threshold : float
        Absolute height value to remove branches below in nanometers (nm).
    method_values : str
        Method of obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
        'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value). Not required when
        ``method_outlier`` is 'iqr'.
    method_outlier : str
        Method to prune branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the
        skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
    """

    # numpydoc: ignore=PR01
    def __init__(
        self,
        image: npt.NDArray,
        skeleton: npt.NDArray,
        max_length: float = None,
        height_threshold: float = None,
        method_values: str = None,
        method_outlier: str = None,
    ) -> None:
        """
        Initialise the class and immediately perform height pruning.

        The pruned skeleton is stored in ``self.skeleton_pruned``.

        Parameters
        ----------
        image : npt.NDArray
            Original image, typically the height data.
        skeleton : npt.NDArray
            Skeleton to prune branches from.
        max_length : float
            Maximum length of the branch to prune in nanometres (nm). Stored only; not used by height pruning.
        height_threshold : float
            Absolute height value to remove branches below in nanometers (nm).
        method_values : str
            Method of obtaining the height thresholding values. Options are 'min' (minimum value of the branch),
            'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value).
        method_outlier : str
            Method to prune branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the
            skeleton mean - absolute threshold) or 'iqr' (below the lower quartile - 1.5 * inter-quartile range).
        """
        self.image = image
        self.skeleton = skeleton
        self.skeleton_convolved = None
        self.skeleton_branches = None
        self.skeleton_branches_labelled = None
        self.max_length = max_length
        self.height_threshold = height_threshold
        self.method_values = method_values
        self.method_outlier = method_outlier
        self.convolve_skeleton()
        self.segment_skeleton()
        self.label_branches()
        self.skeleton_pruned = self.height_prune()

    def convolve_skeleton(self) -> None:
        """Convolve skeleton."""
        self.skeleton_convolved = convolve_skeleton(self.skeleton)

    def segment_skeleton(self) -> None:
        """Break the convolved skeleton into segments by zeroing nodes/junctions (value 3)."""
        self.skeleton_branches = np.where(self.skeleton_convolved == 3, 0, self.skeleton)

    def label_branches(self) -> None:
        """Label segmented branches."""
        self.skeleton_branches_labelled = morphology.label(self.skeleton_branches)

    def _get_branch_mins(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Collect the minimum height value of each individually labeled branch.

        Parameters
        ----------
        segments : npt.NDArray
            Integer labeled array matching the dimensions of the image.

        Returns
        -------
        npt.NDArray
            Array of minimum values of each branch, at position ``label - 1``.
        """
        return np.array([np.min(self.image[segments == i]) for i in range(1, segments.max() + 1)])

    def _get_branch_medians(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Collect the median height value of each labeled branch.

        Parameters
        ----------
        segments : npt.NDArray
            Integer labeled array matching the dimensions of the image.

        Returns
        -------
        npt.NDArray
            Array of median values of each branch, at position ``label - 1``.
        """
        return np.array([np.median(self.image[segments == i]) for i in range(1, segments.max() + 1)])

    def _get_branch_middles(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Collect the positionally ordered middle height value of each labeled branch.

        Where the branch has an even number of points, the two middle heights are averaged. Branches of two or fewer
        points, or branches without a free end (closed loops), use the mean height of the branch instead.

        Parameters
        ----------
        segments : npt.NDArray
            Integer labeled array matching the dimensions of the image.

        Returns
        -------
        npt.NDArray
            Array of middle values of each branch, at position ``label - 1``.
        """
        branch_middles = np.zeros(segments.max())
        for label in range(1, segments.max() + 1):
            segment = np.where(segments == label, 1, 0)
            ends = np.argwhere(convolve_skeleton(segment) == 2) if segment.sum() > 2 else np.empty((0, 2), dtype=int)
            if len(ends) == 0:
                # too short to have a distinct middle, or no endpoint to start ordering from
                branch_middles[label - 1] = self.image[segment == 1].mean()
                continue
            ordered_coords = order_branch_from_end(segment, ends[0])
            # if even no. points, average the two middles
            n_points = len(ordered_coords)
            middle_idx, middle_remainder = (n_points + 1) // 2 - 1, (n_points + 1) % 2
            mid_coord = ordered_coords[[middle_idx, middle_idx + middle_remainder]]
            branch_middles[label - 1] = self.image[mid_coord[:, 0], mid_coord[:, 1]].mean()
        return branch_middles

    @staticmethod
    def _get_abs_thresh_idx(height_values: npt.NDArray, threshold: float | int) -> npt.NDArray:
        """
        Identify indices of labelled branches whose height values are less than a given threshold.

        Parameters
        ----------
        height_values : npt.NDArray
            Array of each branches heights.
        threshold : float | int
            Threshold for heights.

        Returns
        -------
        npt.NDArray
            Branch labels (1-based) which are less than threshold.
        """
        return np.flatnonzero(np.asarray(height_values) < threshold) + 1

    @staticmethod
    def _get_mean_abs_thresh_idx(
        height_values: npt.NDArray, threshold: float | int, image: npt.NDArray, skeleton: npt.NDArray
    ) -> npt.NDArray:
        """
        Identify indices of labelled branch whose height values are less than mean skeleton height - absolute threshold.

        For DNA a threshold of 0.85nm (the depth of the major groove) would ideally remove all segments whose lowest
        point is < mean(height) - 0.85nm, i.e. 1.15nm.

        Parameters
        ----------
        height_values : npt.NDArray
            Array of branches heights.
        threshold : float | int
            Threshold to be subtracted from mean heights.
        image : npt.NDArray
            Original image of heights.
        skeleton : npt.NDArray
            Binary array of skeleton used to identify heights from original image to use.

        Returns
        -------
        npt.NDArray
            Branch labels (1-based) which are less than mean(height) - threshold.
        """
        avg = image[skeleton == 1].mean()
        return np.flatnonzero(np.asarray(height_values) < (avg - threshold)) + 1

    @staticmethod
    def _get_iqr_thresh_idx(image: npt.NDArray, segments: npt.NDArray) -> npt.NDArray:
        """
        Identify labelled branches containing a pixel below the lower quartile - 1.5 x interquartile range.

        The quartiles are computed from the heights of all pixels in ``segments``.

        Parameters
        ----------
        image : npt.NDArray
            Original image with heights.
        segments : npt.NDArray
            Labelled array of skeleton branches.

        Returns
        -------
        npt.NDArray
            Branch labels with at least one low pixel, sorted by ascending branch minimum height. Empty if
            ``segments`` contains no branches.
        """
        branch_mask = segments != 0
        if not branch_mask.any():
            return np.array([], dtype=int)
        heights = image[branch_mask]
        q75, q25 = np.percentile(heights, [75, 25])
        threshold = q25 - 1.5 * (q75 - q25)
        # Labels are taken per pixel; matching coordinates element-wise with np.isin wrongly matched permuted
        # coordinates (e.g. (c, r) for a low pixel at (r, c)) and flagged unrelated branches.
        low_labels = np.unique(segments[branch_mask & (image < threshold)])
        low_mins = [image[segments == label].min() for label in low_labels]
        return low_labels[np.argsort(low_mins)]  # sort in order of ascending mins

    @staticmethod
    def check_skeleton_one_object(skeleton: npt.NDArray) -> bool:
        """
        Ensure that the skeleton hasn't been broken up upon removing a segment.

        Parameters
        ----------
        skeleton : npt.NDArray
            2D single pixel thick array.

        Returns
        -------
        bool
            True if the skeleton is exactly one connected object, otherwise False.
        """
        skeleton = np.where(skeleton != 0, 1, 0)
        return morphology.label(skeleton).max() == 1

    def filter_segments(self, segments: npt.NDArray) -> npt.NDArray:
        """
        Identify and remove segments of a skeleton based on the underlying image height.

        Parameters
        ----------
        segments : npt.NDArray
            A labelled 2D array of skeleton segments.

        Returns
        -------
        npt.NDArray
            A copy of the original skeleton without the segments identified by the height criteria. A segment is
            only removed if the skeleton remains a single object without it.

        Raises
        ------
        ValueError
            If ``method_outlier`` is not one of 'abs', 'mean_abs' or 'iqr', or if ``method_values`` is not one of
            'min', 'median' or 'mid' when a per-branch height value is required.
        """
        outlier_methods = ("abs", "mean_abs", "iqr")
        if self.method_outlier not in outlier_methods:
            raise ValueError(
                f"Invalid height pruning outlier method ({self.method_outlier}), please use one of {outlier_methods}."
            )
        if self.method_outlier == "iqr":
            # IQR thresholding works on pixel heights directly; no per-branch summary value is needed.
            idxs = self._get_iqr_thresh_idx(self.image, segments)
        else:
            # Obtain the height of each branch via the min | median | mid methods
            value_getters = {
                "min": self._get_branch_mins,
                "median": self._get_branch_medians,
                "mid": self._get_branch_middles,
            }
            if self.method_values not in value_getters:
                raise ValueError(
                    f"Invalid height pruning value method ({self.method_values}), "
                    f"please use one of {tuple(value_getters)}."
                )
            height_values = value_getters[self.method_values](segments)
            if self.method_outlier == "abs":
                idxs = self._get_abs_thresh_idx(height_values, self.height_threshold)
            else:
                idxs = self._get_mean_abs_thresh_idx(height_values, self.height_threshold, self.image, self.skeleton)
        LOGGER.debug(f": pruning.py : Height pruning ({self.method_outlier}) candidate branches : {list(idxs)}")
        # Only remove a branch if the skeleton remains a single object; each candidate is tested on its own.
        skeleton_rtn = self.skeleton.copy()
        for i in idxs:
            temp_skel = self.skeleton.copy()
            temp_skel[segments == i] = 0
            if self.check_skeleton_one_object(temp_skel):
                skeleton_rtn[segments == i] = 0
        return skeleton_rtn

    # def remove_bridges(self) -> npt.NDArray:
    #     """
    #     Identify and remove skeleton bridges using the underlying image height.
    #
    #     Bridges cross the skeleton in places they shouldn't and are defined as an internal branch and thus have no
    #     endpoints. They occur due to poor thresholding creating holes in the mask, creating false "bridges" which
    #     misrepresent the skeleton of the molecule.
    #
    #     Returns
    #     -------
    #     npt.NDArray
    #         A skeleton with internal branches removed by height.
    #     """
    #     conv = convolve_skeleton(self.skeleton)
    #     # Split the skeleton into branches by removing junctions/nodes and label
    #     nodeless = np.where(conv == 3, 0, conv)
    #     segments = morphology.label(np.where(nodeless != 0, 1, 0))
    #     # bridges should not concern endpoints so remove these
    #     for i in range(1, segments.max() + 1):
    #         if (conv[segments == i] == 2).any():
    #             segments[segments == i] = 0
    #     segments = morphology.label(np.where(segments != 0, 1, 0))
    #     # filter the segments based on height criteria
    #     return self.filter_segments(segments)

    def height_prune(self) -> npt.NDArray:
        """
        Identify and remove spurious branches (containing endpoints) using the underlying image height.

        Returns
        -------
        npt.NDArray
            A skeleton with outer branches removed by height.
        """
        conv = convolve_skeleton(self.skeleton)
        segments = self._split_skeleton(conv)
        # height pruning only concerns branches with an endpoint (value 2); drop internal connections and relabel
        endpoint_labels = np.unique(segments[(conv == 2) & (segments != 0)])
        segments = morphology.label(np.where(np.isin(segments, endpoint_labels), 1, 0))
        # filter the segments based on height criteria
        return self.filter_segments(segments)

    @staticmethod
    def _split_skeleton(skeleton: npt.NDArray) -> npt.NDArray:
        """
        Split the skeleton into branches by removing junctions/nodes and label branches.

        Parameters
        ----------
        skeleton : npt.NDArray
            Convolved skeleton to be split. This should have nodes labelled as 3, ends as 2 and all other points as 1.

        Returns
        -------
        npt.NDArray
            Removes the junctions (3) and returns all remaining sections as labelled segments.
        """
        nodeless = np.where(skeleton == 3, 0, skeleton)
        return morphology.label(np.where(nodeless != 0, 1, 0))
def order_branch_from_end(nodeless: npt.NDArray, start: list, max_length: float = np.inf) -> npt.NDArray:
    """
    Take a linear branch and order its coordinates starting from a specific endpoint.

    Starting at ``start``, the walk repeatedly steps to the first unvisited neighbouring pixel (first in the row-major
    order returned by ``np.argwhere`` over the 3x3 neighbourhood). It stops when no unvisited neighbour remains or when
    the accumulated path length would exceed ``max_length``. Orthogonal steps count as 1 pixel and diagonal steps as
    sqrt(2) pixels.

    NB - It may be possible to use np.lexsort() to order points, see topostats.measure.feret.sort_coords() for an
    example of how to sort by row or column coordinates. Which end of the branch this is from probably doesn't matter
    as one only wants to find the mid-point.

    Parameters
    ----------
    nodeless : npt.NDArray
        A 2D binary array where there are no crossing pixels. It is copied internally; the caller's array is not
        modified.
    start : list
        The coordinate to start the ordering at.
    max_length : float, optional
        The maximum length to order along the branch, in pixels, by default np.inf.

    Returns
    -------
    npt.NDArray
        The input linear branch ordered from the start coordinate.
    """
    # Work on a copy: visited pixels are zeroed so they are not revisited.
    remaining = np.array(nodeless, copy=True)
    remaining[start[0], start[1]] = 0
    ordered = [start]
    current_point = start
    dist = 0
    while True:
        area, _ = local_area_sum(remaining, current_point)  # look at local area
        neighbours = np.argwhere(area.reshape((3, 3)) == 1) - (1, 1)
        if len(neighbours) == 0:
            break
        step = neighbours[0]
        # Length of the step actually taken: diagonal steps have |dr| + |dc| == 2.
        dist += np.sqrt(2) if np.abs(step).sum() > 1 else 1
        if dist > max_length:
            break
        current_point = current_point + step
        ordered.append(current_point)
        remaining[current_point[0], current_point[1]] = 0
    return np.array(ordered)
def rm_nibs(skeleton):  # pylint: disable=too-many-locals
    """
    Remove single pixel branches (nibs) that nearest neighbour algorithms miss because they can have >2 neighbours.

    The skeleton is split into junction pixels (convolved value 3) and nodeless sections (values 1 and 2). A nodeless
    section whose pixels all lie in the one-pixel ring around a junction cluster is a nib candidate. Candidates that
    qualify for exactly one junction cluster are deleted.

    Parameters
    ----------
    skeleton : npt.NDArray
        A single pixel thick trace. It is modified in place and also returned.

    Returns
    -------
    npt.NDArray
        A skeleton with single pixel nibs removed.
    """
    convolved = convolve_skeleton(skeleton)
    junction_labels = morphology.label(np.where(convolved == 3, 1, 0))
    branch_labels = morphology.label(np.where((convolved == 1) | (convolved == 2), 1, 0))

    candidates = []
    for junction_id in range(1, junction_labels.max() + 1):
        junction = np.where(junction_labels == junction_id, 1, 0)
        dilated = morphology.binary_dilation(junction, footprint=np.ones((3, 3)))
        ring = np.where(dilated != junction, 1, 0)
        touching = branch_labels[ring == 1]
        touching = touching[touching != 0]
        # a branch is a candidate when every one of its pixels sits in this junction's ring
        candidates.extend(
            branch_id
            for branch_id in np.unique(touching)
            if (branch_labels == branch_id).sum() == (touching == branch_id).sum()
        )

    labels, occurrences = np.unique(np.array(candidates), return_counts=True)
    for branch_id in labels[occurrences == 1]:
        skeleton[branch_labels == branch_id] = 0
    return skeleton
def local_area_sum(img: npt.NDArray, point: list | tuple | npt.NDArray) -> tuple:
    """
    Evaluate the local area around a point in a binary map.

    Parameters
    ----------
    img : npt.NDArray
        Binary array of image.
    point : list | tuple | npt.NDArray
        Coordinates (row, column) of a point within the binary_map. Must not lie on the array edge.

    Returns
    -------
    tuple
        The flattened 3x3 neighbourhood values (centre set to 0) and the number of neighbours around the point.

    Raises
    ------
    ValueError
        If the values indicate the map is not binary.
    IndexError
        If the point lies on an edge of the array.
    """
    if img[point[0], point[1]] > 1:
        raise ValueError("binary_map is not binary!")
    row, col = point[0], point[1]
    # Slicing with a negative start does not raise; it wraps around, so top/left edges must be checked explicitly.
    if row < 1 or col < 1:
        raise IndexError("Point can not be on the edge of an array.")
    local_pixels = img[row - 1 : row + 2, col - 1 : col + 2].flatten()
    # Slicing beyond the end silently truncates, so a right/bottom edge point yields fewer than 9 values.
    if local_pixels.shape[0] != 9:
        raise IndexError("'point' is on right or bottom edge of 'binary_map'")
    local_pixels[4] = 0  # ensure centre is 0 (flatten returns a copy, so img is untouched)
    if local_pixels.sum() <= 8:
        return local_pixels, local_pixels.sum()
    raise ValueError("'binary_map' is not binary!")