"""Skeletonize molecules."""
import logging
from collections.abc import Callable
import numpy as np
import numpy.typing as npt
from skimage.morphology import medial_axis, skeletonize, thin
from topostats.logs.logs import LOGGER_NAME
LOGGER = logging.getLogger(LOGGER_NAME)
class getSkeleton:  # pylint: disable=too-few-public-methods
    """
    Class skeletonising images.

    Parameters
    ----------
    image : npt.NDArray
        Image used to generate the mask.
    mask : npt.NDArray
        Binary mask of features.
    method : str
        Method for skeletonizing. Options 'zhang' (default), 'lee', 'medial_axis', 'thin' and 'topostats'.
    height_bias : float
        Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all pixels
        similar to Zhang.
    """

    def __init__(self, image: npt.NDArray, mask: npt.NDArray, method: str = "zhang", height_bias: float = 0.6):
        """
        Initialise the class.

        This is a thin wrapper to the methods provided by the `skimage.morphology
        <https://scikit-image.org/docs/stable/api/skimage.morphology.html?highlight=skeletonize>`_
        module. See also the `examples
        <https://scikit-image.org/docs/stable/auto_examples/edges/plot_skeleton.html>`_.

        Parameters
        ----------
        image : npt.NDArray
            Image used to generate the mask.
        mask : npt.NDArray
            Binary mask of features.
        method : str
            Method for skeletonizing. Options 'zhang' (default), 'lee', 'medial_axis', 'thin' and 'topostats'.
        height_bias : float
            Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all
            pixels similar to Zhang.
        """
        # Q What benefit is there to having a class getSkeleton over the get_skeleton() function? Ostensibly the class
        #   is doing only one thing, we don't need to change state/modify anything here. Beyond encapsulating all
        #   functions in a single class this feels like overkill.
        self.image = image
        self.mask = mask
        self.method = method
        self.height_bias = height_bias

    def get_skeleton(self) -> npt.NDArray:
        """
        Skeletonise molecules.

        Returns
        -------
        npt.NDArray
            Skeletonised version of the binary mask (possibly using criteria from the image).
        """
        return self._get_skeletonize()

    def _get_skeletonize(self) -> npt.NDArray:
        """
        Run the skeletonisation method selected by ``self.method``.

        Returns
        -------
        npt.NDArray
            Skeleton of ``self.mask`` cast to ``np.int8``.

        Raises
        ------
        ValueError
            If ``self.method`` is not one of the supported options.
        """
        if self.method == "zhang":
            return self._skeletonize_zhang(mask=self.mask).astype(np.int8)
        if self.method == "lee":
            return self._skeletonize_lee(mask=self.mask).astype(np.int8)
        if self.method == "medial_axis":
            return self._skeletonize_medial_axis(mask=self.mask).astype(np.int8)
        if self.method == "thin":
            return self._skeletonize_thin(mask=self.mask).astype(np.int8)
        if self.method == "topostats":
            return self._skeletonize_topostats(image=self.image, mask=self.mask, height_bias=self.height_bias).astype(
                np.int8
            )
        raise ValueError(self.method)

    @staticmethod
    def _skeletonize_zhang(mask: npt.NDArray) -> npt.NDArray:
        """
        Use scikit-image implementation of the Zhang skeletonisation method.

        Parameters
        ----------
        mask : npt.NDArray
            Binary array to skeletonise.

        Returns
        -------
        npt.NDArray
            Mask array reduced to a single pixel thickness.
        """
        return skeletonize(mask, method="zhang")

    @staticmethod
    def _skeletonize_lee(mask: npt.NDArray) -> npt.NDArray:
        """
        Use scikit-image implementation of the Lee skeletonisation method.

        Parameters
        ----------
        mask : npt.NDArray
            Binary array to skeletonise.

        Returns
        -------
        npt.NDArray
            Mask array reduced to a single pixel thickness.
        """
        return skeletonize(mask, method="lee")

    @staticmethod
    def _skeletonize_medial_axis(mask: npt.NDArray) -> npt.NDArray:
        """
        Use scikit-image implementation of the medial axis skeletonisation method.

        Parameters
        ----------
        mask : npt.NDArray
            Binary array to skeletonise.

        Returns
        -------
        npt.NDArray
            Medial axis skeleton of the mask.
        """
        # Only the skeleton is wanted by the dispatcher, not the distance transform.
        return medial_axis(mask, return_distance=False)

    @staticmethod
    def _skeletonize_thin(mask: npt.NDArray) -> npt.NDArray:
        """
        Use scikit-image implementation of the thinning skeletonisation method.

        Parameters
        ----------
        mask : npt.NDArray
            Binary array to skeletonise.

        Returns
        -------
        npt.NDArray
            Mask array reduced to a single pixel thickness.
        """
        return thin(mask)

    @staticmethod
    def _skeletonize_topostats(image: npt.NDArray, mask: npt.NDArray, height_bias: float = 0.6) -> npt.NDArray:
        """
        Use the height-biased ``topostatsSkeletonize`` implementation.

        This method is based on Zhang's method but produces different results (less branches but slightly less
        accurate).

        Parameters
        ----------
        image : npt.NDArray
            Original image with heights.
        mask : npt.NDArray
            Binary array to skeletonise.
        height_bias : float
            Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all
            pixels similar to Zhang.

        Returns
        -------
        npt.NDArray
            Masked array reduced to a single pixel thickness.
        """
        return topostatsSkeletonize(image, mask, height_bias).do_skeletonising()
class topostatsSkeletonize:  # pylint: disable=too-many-instance-attributes
    """
    Skeletonise a binary array following Zhang's algorithm (Zhang and Suen, 1984).

    Modifications are made to the published algorithm during the removal step to remove a fraction of the smallest pixel
    values opposed to all of them in the aforementioned algorithm. All operations are performed on the mask entered.

    Parameters
    ----------
    image : npt.NDArray
        Original 2D image containing the height data.
    mask : npt.NDArray
        Binary image containing the object to be skeletonised. Dimensions should match those of 'image'.
    height_bias : float
        Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all pixels
        similar to Zhang.
    """

    def __init__(self, image: npt.NDArray, mask: npt.NDArray, height_bias: float = 0.6):
        """
        Initialise the class.

        Parameters
        ----------
        image : npt.NDArray
            Original 2D image containing the height data.
        mask : npt.NDArray
            Binary image containing the object to be skeletonised. Dimensions should match those of 'image'.
        height_bias : float
            Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all
            pixels similar to Zhang.
        """
        self.image = image
        # Work on a copy so the caller's mask is never modified.
        self.mask = mask.copy()
        self.height_bias = height_bias
        self.skeleton_converged = False
        # Neighbours of the pixel currently being assessed (see ``get_local_pixels_binary`` for the layout).
        self.p2 = None
        self.p3 = None
        self.p4 = None
        self.p5 = None
        self.p6 = None
        self.p7 = None
        self.p8 = None
        self.p9 = None
        self.counter = 0

    def do_skeletonising(self) -> npt.NDArray:
        """
        Perform skeletonisation.

        Returns
        -------
        npt.NDArray
            The single pixel thick, skeletonised array.
        """
        while not self.skeleton_converged:
            self._do_skeletonising_iteration()
        # When skeleton converged do an additional iteration of thinning to remove hanging points
        self.final_skeletonisation_iteration()
        self.mask = getSkeleton(
            image=self.image, mask=self.mask, method="zhang"
        ).get_skeleton()  # not sure if this is needed?
        return self.mask

    def _do_skeletonising_iteration(self) -> None:
        """
        Run both Zhang sub-iterations once, removing the lowest candidates after each.

        For each sub-iteration every mask pixel equal to 1 is assessed against the binary deletion criteria, then only
        the ``height_bias`` fraction of lowest (by image height) candidates is removed from the mask. ``height_bias``
        of 1 removes all candidates, similar to Zhang.

        ``self.skeleton_converged`` is set once the second sub-iteration finds no candidates, or when the whole
        iteration removed nothing (which would otherwise loop forever, e.g. ``height_bias=0``).
        """
        removed_any = False
        candidates = np.asarray([])
        for should_delete in (self._delete_pixel_subit1, self._delete_pixel_subit2):
            candidates = self._find_deletion_candidates(should_delete)
            if self._remove_lowest_candidates(candidates):
                removed_any = True
        # NOTE(review): as in the original, convergence is judged on the second sub-iteration's candidates only;
        # deletions made by the first sub-iteration do not keep the loop going - confirm this is intended.
        if len(candidates) == 0 or not removed_any:
            self.skeleton_converged = True

    def _find_deletion_candidates(self, should_delete: Callable) -> npt.NDArray:
        """
        Collect the mask pixels (value 1) that satisfy a sub-iteration's deletion criteria.

        Parameters
        ----------
        should_delete : Callable
            Predicate taking an ``[x, y]`` list and returning whether the pixel is a deletion candidate.

        Returns
        -------
        npt.NDArray
            Array of candidate coordinates, one ``[x, y]`` row per candidate (empty if there are none).
        """
        # The mask is not modified while scanning, so every pixel is judged against the same state.
        return np.asarray([point for point in np.argwhere(self.mask == 1).tolist() if should_delete(point)])

    def _remove_lowest_candidates(self, candidates: npt.NDArray) -> bool:
        """
        Zero the ``height_bias`` fraction of candidates with the lowest image heights.

        Parameters
        ----------
        candidates : npt.NDArray
            Candidate coordinates as returned by ``_find_deletion_candidates``.

        Returns
        -------
        bool
            Whether at least one pixel was removed from the mask.
        """
        if candidates.size == 0:
            return False
        rows, cols = candidates[:, 0], candidates[:, 1]
        heights = self.image[rows, cols]
        n_remove = int(np.ceil(len(heights) * self.height_bias))
        # Indices (into ``candidates``) of the lowest pixels; ties are broken by a seeded shuffle.
        lowest = self.sort_and_shuffle(heights)[1][:n_remove]
        if len(lowest) == 0:
            return False
        self.mask[rows[lowest], cols[lowest]] = 0
        return True

    def _update_local_pixels(self, x: int, y: int) -> None:
        """
        Store the eight neighbours of ``(x, y)`` in ``self.p2`` ... ``self.p9``.

        Parameters
        ----------
        x : int
            First (row) index of the pixel in the mask.
        y : int
            Second (column) index of the pixel in the mask.
        """
        self.p7, self.p8, self.p9, self.p6, self.p2, self.p5, self.p4, self.p3 = self.get_local_pixels_binary(
            self.mask, x, y
        )

    def _delete_pixel_subit1(self, point: list) -> bool:
        """
        Check whether a single point (P1) should be deleted based on its local binary environment.

        (a) 2 ≤ B(P1) ≤ 6, where B(P1) is the number of non-zero neighbours of P1.
        (b) A(P1) = 1, where A(P1) is the # of 01's around P1.
        (c) P2 * P4 * P6 = 0
        (d) P4 * P6 * P8 = 0

        Parameters
        ----------
        point : list
            List of [x, y] coordinate positions.

        Returns
        -------
        bool
            Whether the surrounding points meet all of the a, b, c and d criteria above.
        """
        self._update_local_pixels(point[0], point[1])
        return (
            self._binary_thin_check_a()
            and self._binary_thin_check_b_returncount() == 1
            # c and d remove only north-west corner points and south-east boundary points.
            and self._binary_thin_check_c()
            and self._binary_thin_check_d()
        )

    def _delete_pixel_subit2(self, point: list) -> bool:
        """
        Check whether a single point (P1) should be deleted based on its local binary environment.

        (a) 2 ≤ B(P1) ≤ 6, where B(P1) is the number of non-zero neighbours of P1.
        (b) A(P1) = 1, where A(P1) is the # of 01's around P1.
        (c') P2 * P4 * P8 = 0
        (d') P2 * P6 * P8 = 0

        Parameters
        ----------
        point : list
            List of [x, y] coordinate positions.

        Returns
        -------
        bool
            Whether the surrounding points meet all of the a, b, c' and d' criteria above.
        """
        self._update_local_pixels(point[0], point[1])
        # Add in generic code here to protect high points from being deleted
        return (
            self._binary_thin_check_a()
            and self._binary_thin_check_b_returncount() == 1
            # c' and d' remove only north-west boundary points or south-east corner points.
            and self._binary_thin_check_csharp()
            and self._binary_thin_check_dsharp()
        )

    def _binary_thin_check_a(self) -> bool:
        """
        Check the surrounding area to see if the point lies on the edge of the grain.

        Condition A protects the endpoints (which will be < 2).

        Returns
        -------
        bool
            If the number of non-zero neighbours is between 2 and 6 inclusive.
        """
        return 2 <= self.p2 + self.p3 + self.p4 + self.p5 + self.p6 + self.p7 + self.p8 + self.p9 <= 6

    def _binary_thin_check_b_returncount(self) -> int:
        """
        Count the 0 -> 1 transitions when walking once around P1.

        The neighbours are visited in the cyclic order p2, p3, ..., p9 and back to p2; each consecutive pair that is
        exactly (0, 1) counts as one transition.

        Returns
        -------
        int
            The number of 01's around P1.
        """
        ring = [self.p2, self.p3, self.p4, self.p5, self.p6, self.p7, self.p8, self.p9]
        # Pair every neighbour with its successor, wrapping p9 back round to p2.
        return sum(1 for current, following in zip(ring, ring[1:] + ring[:1]) if current == 0 and following == 1)

    def _binary_thin_check_c(self) -> bool:
        """
        Check if p2, p4 or p6 is 0.

        Returns
        -------
        bool
            If p2, p4 or p6 is 0.
        """
        return self.p2 * self.p4 * self.p6 == 0

    def _binary_thin_check_d(self) -> bool:
        """
        Check if p4, p6 or p8 is 0.

        Returns
        -------
        bool
            If p4, p6 or p8 is 0.
        """
        return self.p4 * self.p6 * self.p8 == 0

    def _binary_thin_check_csharp(self) -> bool:
        """
        Check if p2, p4 or p8 is 0.

        Returns
        -------
        bool
            If p2, p4 or p8 is 0.
        """
        return self.p2 * self.p4 * self.p8 == 0

    def _binary_thin_check_dsharp(self) -> bool:
        """
        Check if p2, p6 or p8 is 0.

        Returns
        -------
        bool
            If p2, p6 or p8 is 0.
        """
        return self.p2 * self.p6 * self.p8 == 0

    def final_skeletonisation_iteration(self) -> None:
        """
        Remove "hanging" pixels.

        Examples of such pixels are:

                    [0, 0, 0]               [0, 1, 0]               [0, 0, 0]
                    [0, 1, 1]               [0, 1, 1]               [0, 1, 1]
            case 1: [0, 1, 0]   or  case 2: [0, 1, 0]   or  case 3: [1, 1, 0]

        This is useful for the future functions that rely on local pixel environment
        to make assessments about the overall shape/structure of traces.
        """
        # The mask is edited in place while scanning, so later pixels see earlier removals.
        for x, y in np.argwhere(self.mask).tolist():
            self._update_local_pixels(x, y)
            transitions = self._binary_thin_check_b_returncount()
            # Checks for case 1 and 3 pixels
            if transitions == 2 and self._binary_final_thin_check_a() and not self.binary_thin_check_diag():
                self.mask[x, y] = 0
            # Checks for case 2 pixels
            elif transitions == 3 and self._binary_final_thin_check_b():
                self.mask[x, y] = 0

    def _binary_final_thin_check_a(self) -> bool:
        """
        Assess if two adjacent 4-connected neighbours are both set.

        Returns
        -------
        bool
            Whether any pair of adjacent 4-connected neighbours (p2/p4, p4/p6, p6/p8, p8/p2) is present.
        """
        return 1 in (self.p2 * self.p4, self.p4 * self.p6, self.p6 * self.p8, self.p8 * self.p2)

    def _binary_final_thin_check_b(self) -> bool:
        """
        Assess if three of the 4-connected neighbours are all set.

        Returns
        -------
        bool
            Whether any triple of the 4-connected neighbours (p2, p4, p6, p8) is present.
        """
        return 1 in (
            self.p2 * self.p4 * self.p6,
            self.p4 * self.p6 * self.p8,
            self.p6 * self.p8 * self.p2,
            self.p8 * self.p2 * self.p4,
        )

    def binary_thin_check_diag(self) -> bool:
        """
        Check if opposite corner diagonals are present.

        Returns
        -------
        bool
            Whether a diagonal exists.
        """
        return 1 in (self.p7 * self.p3, self.p5 * self.p9)

    @staticmethod
    def get_local_pixels_binary(binary_map: npt.NDArray, x: int, y: int) -> npt.NDArray:
        """
        Value of pixels in the local 8-connectivity area around the coordinate (P1) described by x and y.

        Neighbours that fall outside the binary map (P1 on or beyond the border) are returned as 0.

        [[p7, p8, p9],    [[0,1,2],
         [p6, P1, p2], ->  [3,4,5], -> [0,1,2,3,5,6,7,8]
         [p5, p4, p3]]     [6,7,8]]

        delete P1 to only get local area.

        Parameters
        ----------
        binary_map : npt.NDArray
            Binary mask of image (2D).
        x : int
            X coordinate within the binary map.
        y : int
            Y coordinate within the binary map.

        Returns
        -------
        npt.NDArray
            Flattened 8-long array describing the values in the binary map around the x,y point.
        """
        rows, cols = binary_map.shape
        if 0 < x < rows - 1 and 0 < y < cols - 1:
            # Interior pixel: the full 3x3 window exists.
            window = binary_map[x - 1 : x + 2, y - 1 : y + 2]
        else:
            # Border pixel: plain slicing would wrap or truncate, so copy the valid part into a zero window.
            window = np.zeros((3, 3), dtype=binary_map.dtype)
            top, left = x - 1, y - 1
            row_start, row_stop = max(top, 0), min(top + 3, rows)
            col_start, col_stop = max(left, 0), min(left + 3, cols)
            if row_start < row_stop and col_start < col_stop:
                window[row_start - top : row_stop - top, col_start - left : col_stop - left] = binary_map[
                    row_start:row_stop, col_start:col_stop
                ]
        return np.delete(window.flatten(), 4)

    @staticmethod
    def sort_and_shuffle(arr: npt.NDArray, seed: int = 23790101) -> tuple[npt.NDArray, list]:
        """
        Sort array in ascending order, shuffling the relative order of identical values.

        Parameters
        ----------
        arr : npt.NDArray
            A flattened (1D) array.
        seed : int
            Seed for random number generator.

        Returns
        -------
        npt.NDArray
            An ascending order array where identical value orders are also shuffled.
        list
            Indices into ``arr`` giving the above order.
        """
        rng = np.random.default_rng(seed)
        # Walk the distinct values in ascending order, shuffling the positions that share each value.
        sorted_and_shuffled_indices: list = []
        for value in np.unique(arr):
            indices = np.where(arr == value)[0]
            rng.shuffle(indices)
            sorted_and_shuffled_indices.extend(indices)
        # Rearrange the array according to the sorted-and-shuffled indices
        return arr[sorted_and_shuffled_indices], sorted_and_shuffled_indices