"""Perform Crossing Region Processing and Analysis."""
from __future__ import annotations
import logging
from itertools import combinations
from typing import TypedDict
import networkx as nx
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.ndimage import binary_dilation
from scipy.signal import argrelextrema
from skimage.morphology import label
from topostats.logs.logs import LOGGER_NAME
from topostats.measure.geometry import (
calculate_shortest_branch_distances,
connect_best_matches,
find_branches_for_nodes,
)
from topostats.tracing.pruning import prune_skeleton
from topostats.tracing.skeletonize import getSkeleton
from topostats.tracing.tracingfuncs import order_branch, order_branch_from_start
from topostats.utils import ResolutionError, convolve_skeleton
LOGGER = logging.getLogger(LOGGER_NAME)
# pylint: disable=too-many-arguments
# pylint: disable=too-many-branches
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-lines
# pylint: disable=too-many-locals
# pylint: disable=too-many-nested-blocks
# pylint: disable=too-many-public-methods
# pylint: disable=too-many-statements
[docs]
class NodeDict(TypedDict):
    """
    Dictionary containing the statistics for a single crossing node.

    Keys
    ----
    error : bool
        Whether an error occurred while analysing the node.
    pixel_to_nm_scaling : np.float64
        Pixel to nanometre scaling factor of the image the node came from.
    branch_stats : dict[int, MatchedBranch] | None
        Statistics of the branches matched (paired) through the node, or None.
    unmatched_branch_stats : dict[int, dict[str, np.float64]] | None
        Statistics (currently just "angles") for branches that could not be
        paired through the node, or None.
    node_coords : npt.NDArray[np.int32] | None
        Coordinates of the pixels making up the node region, or None.
    confidence : np.float64 | None
        Confidence of the crossing; None when fewer than two branch pairs exist.
    """

    error: bool
    pixel_to_nm_scaling: np.float64
    branch_stats: dict[int, MatchedBranch] | None
    # Populated by nodeStats.analyse_nodes alongside the other keys; previously
    # missing from this TypedDict even though every node dict carries it.
    unmatched_branch_stats: dict[int, dict[str, np.float64]] | None
    node_coords: npt.NDArray[np.int32] | None
    confidence: np.float64 | None
[docs]
class MatchedBranch(TypedDict):
    """
    Dictionary of statistics for a single branch matched (paired) through a node.

    Keys
    ----
    ordered_coords : npt.NDArray[np.int32]
        The ordered coordinates of the branch.
    heights : npt.NDArray[np.number]
        Heights of the branch coordinates.
    distances : npt.NDArray[np.number]
        Accumulated distances along the branch coordinates.
    fwhm : dict[str, np.float64 | tuple[np.float64]]
        Full width half maximum information for the branch.
    angles : np.float64 | None
        The initial direction angle of the branch, added in later steps.
    """

    ordered_coords: npt.NDArray[np.int32]
    heights: npt.NDArray[np.number]
    distances: npt.NDArray[np.number]
    fwhm: dict[str, np.float64 | tuple[np.float64]]
    angles: np.float64 | None
[docs]
class ImageDict(TypedDict):
    """
    Dictionary containing the image information for a grain and its nodes.

    Keys
    ----
    nodes : dict[str, dict[str, npt.NDArray[np.int32]]]
        Per-node images keyed by "node_<n>" (area skeleton, branch mask, average mask).
    grain : dict[str, npt.NDArray[np.int32] | dict[str, npt.NDArray[np.int32]]]
        Whole-grain images: "grain_image", "grain_mask" and "grain_skeleton".
    """

    nodes: dict[str, dict[str, npt.NDArray[np.int32]]]
    grain: dict[str, npt.NDArray[np.int32] | dict[str, npt.NDArray[np.int32]]]
[docs]
class nodeStats:
"""
Class containing methods to find and analyse the nodes/crossings within a grain.
Parameters
----------
filename : str
The name of the file being processed. For logging purposes.
image : npt.NDArray
The array of pixels.
mask : npt.NDArray
The binary segmentation mask.
smoothed_mask : npt.NDArray
A smoothed version of the binary segmentation mask.
skeleton : npt.NDArray
A binary single-pixel wide mask of objects in the 'image'.
pixel_to_nm_scaling : np.float32
The pixel to nm scaling factor.
n_grain : int
The grain number.
node_joining_length : float
The length over which to join skeletal intersections to be counted as one crossing.
node_extend_dist : float
The distance under which to join odd-branched node regions.
branch_pairing_length : float
The length from the crossing point to pair and trace, obtaining FWHM's.
pair_odd_branches : bool
Whether to try and pair odd-branched nodes.
"""
def __init__(
    self,
    filename: str,
    image: npt.NDArray,
    mask: npt.NDArray,
    smoothed_mask: npt.NDArray,
    skeleton: npt.NDArray,
    pixel_to_nm_scaling: np.float64,
    n_grain: int,
    node_joining_length: float,
    node_extend_dist: float,
    branch_pairing_length: float,
    pair_odd_branches: bool,
) -> None:
    """
    Initialise the nodeStats class.

    Parameters
    ----------
    filename : str
        The name of the file being processed. For logging purposes.
    image : npt.NDArray
        The array of pixels.
    mask : npt.NDArray
        The binary segmentation mask.
    smoothed_mask : npt.NDArray
        A smoothed version of the binary segmentation mask.
    skeleton : npt.NDArray
        A binary single-pixel wide mask of objects in the 'image'.
    pixel_to_nm_scaling : np.float64
        The pixel to nm scaling factor.
    n_grain : int
        The grain number.
    node_joining_length : float
        The length over which to join skeletal intersections to be counted as one crossing.
    node_extend_dist : float
        The distance under which to join odd-branched node regions (in nm).
    branch_pairing_length : float
        The length from the crossing point to pair and trace, obtaining FWHM's.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    """
    self.filename = filename
    self.image = image
    self.mask = mask
    self.smoothed_mask = smoothed_mask  # only used to average traces
    self.skeleton = skeleton
    self.pixel_to_nm_scaling = pixel_to_nm_scaling
    self.n_grain = n_grain
    self.node_joining_length = node_joining_length
    # Convert the joining distance from nanometres into pixels.
    self.node_extend_dist = node_extend_dist / self.pixel_to_nm_scaling
    self.branch_pairing_length = branch_pairing_length
    self.pair_odd_branches = pair_odd_branches
    # Working arrays populated during get_node_stats(); all share the skeleton's shape.
    self.conv_skelly = np.zeros_like(self.skeleton)
    self.connected_nodes = np.zeros_like(self.skeleton)
    self.all_connected_nodes = np.zeros_like(self.skeleton)
    # Graph representation of the whole skeleton, built in get_node_stats().
    self.whole_skel_graph: nx.classes.graph.Graph | None = None
    self.node_centre_mask = np.zeros_like(self.skeleton)
    # Grain-level crossing metrics, filled in by compile_metrics().
    self.metrics = {
        "num_crossings": np.int64(0),
        "avg_crossing_confidence": None,
        "min_crossing_confidence": None,
    }
    # Per-node statistics keyed "node_<n>", populated by analyse_nodes().
    self.node_dicts: dict[str, NodeDict] = {}
    self.image_dict: ImageDict = {
        "nodes": {},
        "grain": {
            "grain_image": self.image,
            "grain_mask": self.mask,
            "grain_skeleton": self.skeleton,
        },
    }
    self.full_dict = {}
    self.mol_coords = {}
    self.visuals = {}
    self.all_visuals_img = None
[docs]
def get_node_stats(self) -> tuple[dict, dict]:
    """
    Run the workflow to obtain the node statistics.

    .. code-block:: RST

        node_dict key structure: <grain_number>
            └-> <node_number>
                |-> 'error'
                └-> 'node_coords'
                └-> 'branch_stats'
                    └-> <branch_number>
                        |-> 'ordered_coords'
                        |-> 'heights'
                        |-> 'gaussian_fit'
                        |-> 'fwhm'
                        └-> 'angles'

        image_dict key structure: 'nodes'
            <node_number>
                |-> 'node_area_skeleton'
                |-> 'node_branch_mask'
                └-> 'node_avg_mask'
            'grain'
                |-> 'grain_image'
                |-> 'grain_mask'
                └-> 'grain_skeleton'

    Returns
    -------
    tuple[dict, dict]
        Dictionaries of the node_information and images.
    """
    LOGGER.debug(f"Node Stats - Processing Grain: {self.n_grain}")
    # Convolve to label the skeleton: 1 = skeleton, 2 = endpoint, 3 = crossing.
    self.conv_skelly = convolve_skeleton(self.skeleton)
    if len(self.conv_skelly[self.conv_skelly == 3]) != 0:  # check if any nodes
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.")
        # convolve to see crossing and end points
        # self.conv_skelly = self.tidy_branches(self.conv_skelly, self.image)
        # reset skeleton var as tidy branches may have modified it
        self.skeleton = np.where(self.conv_skelly != 0, 1, 0)
        self.image_dict["grain"]["grain_skeleton"] = self.skeleton
        # get graph of skeleton
        self.whole_skel_graph = self.skeleton_image_to_graph(self.skeleton)
        # connect the close nodes
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.")
        self.connected_nodes = self.connect_close_nodes(self.conv_skelly, node_width=self.node_joining_length)
        # connect the odd-branch nodes
        self.connected_nodes = self.connect_extended_nodes_nearest(
            self.connected_nodes, node_extend_dist=self.node_extend_dist
        )
        # obtain a mask of node centers and their count
        self.node_centre_mask = self.highlight_node_centres(self.connected_nodes)
        # Begin the hefty crossing analysis
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.")
        self.analyse_nodes(max_branch_length=self.branch_pairing_length)
        self.compile_metrics()
    else:
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.")
    return self.node_dicts, self.image_dict
    # self.all_visuals_img = dnaTrace.concat_images_in_dict(self.image.shape, self.visuals)
[docs]
@staticmethod
def skeleton_image_to_graph(skeleton: npt.NDArray) -> nx.classes.graph.Graph:
    """
    Convert a skeletonised mask into a graph representation.

    The pixel coordinates are preserved as the node labels of the graph.

    Parameters
    ----------
    skeleton : npt.NDArray
        A binary single-pixel wide mask, or result from conv_skelly().

    Returns
    -------
    nx.classes.graph.Graph
        A networkX graph connecting the pixels in the skeleton to their neighbours.
    """
    pixel_coords = np.argwhere(skeleton)
    graph = nx.Graph()
    # 8-connectivity: orthogonal plus diagonal neighbour offsets.
    offsets = np.array([[0, 1], [0, -1], [1, 0], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]])
    for row, col in pixel_coords:
        for d_row, d_col in offsets:
            neigh_row, neigh_col = row + d_row, col + d_col
            # Skip neighbours that fall outside the image bounds.
            if not (0 <= neigh_row < skeleton.shape[0] and 0 <= neigh_col < skeleton.shape[1]):
                continue
            if skeleton[neigh_row, neigh_col] > 0:
                # Edges between two crossing pixels (value 3) get zero weight so that
                # shortest paths prefer travelling through node regions (non-binary input).
                if skeleton[row, col] == 3 and skeleton[neigh_row, neigh_col] == 3:
                    weight = 0
                else:
                    weight = 1
                graph.add_edge((row, col), (neigh_row, neigh_col), weight=weight)
    graph.graph["physicalPos"] = pixel_coords
    return graph
[docs]
@staticmethod
def graph_to_skeleton_image(g: nx.Graph, im_shape: tuple[int]) -> npt.NDArray:
"""
Convert the skeleton graph back to a binary image.
Parameters
----------
g : nx.Graph
Graph with coordinates as node labels.
im_shape : tuple[int]
The shape of the image to dump.
Returns
-------
npt.NDArray
Skeleton binary image from the graph representation.
"""
im = np.zeros(im_shape)
for node in g:
im[node] = 1
return im
[docs]
def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> npt.NDArray:
    """
    Wrangle distant connected nodes back towards the main cluster.

    Works by filling and reskeletonising solely the node areas.

    Parameters
    ----------
    connect_node_mask : npt.NDArray
        The connected node mask - a skeleton where node regions = 3, endpoints = 2, and skeleton = 1.
    image : npt.NDArray
        The intensity image.

    Returns
    -------
    npt.NDArray
        The wrangled connected_node_mask.
    """
    new_skeleton = np.where(connect_node_mask != 0, 1, 0)
    labeled_nodes = label(np.where(connect_node_mask == 3, 1, 0))
    for node_num in range(1, labeled_nodes.max() + 1):
        solo_node = np.where(labeled_nodes == node_num, 1, 0)
        coords = np.argwhere(solo_node == 1)
        node_centre = coords.mean(axis=0).astype(np.int32)
        node_wid = coords[:, 0].max() - coords[:, 0].min() + 2  # +2 so always 2 by default
        node_len = coords[:, 1].max() - coords[:, 1].min() + 2  # +2 so always 2 by default
        # Margin of roughly 10 nm, converted to pixels (at least one pixel).
        overflow = int(10 / self.pixel_to_nm_scaling) if int(10 / self.pixel_to_nm_scaling) != 0 else 1
        # grain mask fill: replace the node neighbourhood with the full grain mask
        new_skeleton[
            node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow,
            node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow,
        ] = self.mask[
            node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow,
            node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow,
        ]
    # remove any artifacts of the grain caught in the overflow areas
    new_skeleton = self.keep_biggest_object(new_skeleton)
    # Re-skeletonise the filled node regions with the height-biased method.
    new_skeleton = getSkeleton(image, new_skeleton, method="topostats", height_bias=0.6).get_skeleton()
    # new_skeleton = pruneSkeleton(image, new_skeleton).prune_skeleton(
    #     {"method": "topostats", "max_length": -1}
    # )
    new_skeleton = prune_skeleton(
        image, new_skeleton, self.pixel_to_nm_scaling, **{"method": "topostats", "max_length": -1}
    )
    # cleanup around nibs
    new_skeleton = getSkeleton(image, new_skeleton, method="zhang").get_skeleton()
    # might also need to remove segments that have squares connected
    return convolve_skeleton(new_skeleton)
[docs]
@staticmethod
def keep_biggest_object(mask: npt.NDArray) -> npt.NDArray:
    """
    Retain the largest object in a binary mask.

    Parameters
    ----------
    mask : npt.NDArray
        Binary mask.

    Returns
    -------
    npt.NDArray
        A binary mask containing only the largest connected object; the original
        mask is returned unchanged if it contains no objects.
    """
    labelled_mask = label(mask)
    # Count pixels per connected-component label. Counting must be done on the
    # LABELLED mask: the raw binary mask only contains {0, 1}, so the previous
    # code always selected label 1 regardless of which object was biggest.
    idxs, counts = np.unique(labelled_mask, return_counts=True)
    try:
        # Skip index 0 (background) when looking for the biggest object.
        max_idx = idxs[np.argmax(counts[1:]) + 1]
        return np.where(labelled_mask == max_idx, 1, 0)
    except ValueError as e:
        # An empty mask has no foreground labels -> argmax over an empty slice.
        LOGGER.debug(f"{e}: mask is empty.")
        return mask
[docs]
def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray:
    """
    Connect nodes within the 'node_width' boundary distance.

    This labels them as part of the same node.

    Parameters
    ----------
    conv_skelly : npt.NDArray
        A labeled skeleton image with skeleton = 1, endpoints = 2, crossing points = 3.
    node_width : float
        The width of the dna in the grain, used to connect close nodes.

    Returns
    -------
    np.ndarray
        The skeleton (label=1) with close nodes connected (label=3).
    """
    self.connected_nodes = conv_skelly.copy()
    # Strip out crossing (3) and endpoint (2) pixels, keeping only plain branch segments.
    branches_only = conv_skelly.copy()
    branches_only[(branches_only == 3) | (branches_only == 2)] = 0
    branch_labels = label(branches_only)
    # Pixel-count threshold equivalent to the physical molecule width.
    min_pixels = node_width / self.pixel_to_nm_scaling
    for branch_id in range(1, branch_labels.max() + 1):
        # A segment shorter than the molecule width sits between two crossings,
        # so absorb it into the node label.
        # maybe also need to select based on height? and also ensure small branches classified
        if branches_only[branch_labels == branch_id].size < min_pixels:
            self.connected_nodes[branch_labels == branch_id] = 3
    return self.connected_nodes
[docs]
def highlight_node_centres(self, mask: npt.NDArray) -> npt.NDArray:
    """
    Calculate the node centres based on height and re-plot on the mask.

    Parameters
    ----------
    mask : npt.NDArray
        2D array with background = 0, skeleton = 1, endpoints = 2, node_centres = 3.

    Returns
    -------
    npt.NDArray
        2D array with the highest node coordinate for each node labeled as 3.
    """
    # Demote whole node regions back to plain skeleton; centres are re-added below.
    centre_mask = mask.copy()
    centre_mask[mask == 3] = 1
    # Label each connected node region so a centre can be found per region.
    node_regions = label(np.where(mask == 3, 1, 0))
    for region_id in np.delete(np.unique(node_regions), 0):  # skip background label 0
        # The node centre is the highest image pixel within the region.
        masked_heights = self.image * (node_regions == region_id).astype(int)
        centre = np.unravel_index(masked_heights.argmax(), self.image.shape)
        centre_mask[centre] = 3
    return centre_mask
[docs]
def connect_extended_nodes_nearest(
    self, connected_nodes: npt.NDArray, node_extend_dist: float = -1
) -> npt.NDArray[np.int32]:
    """
    Extend the odd branched nodes to other odd branched nodes within the 'extend_dist' threshold.

    Parameters
    ----------
    connected_nodes : npt.NDArray
        A 2D array representing the network with background = 0, skeleton = 1, endpoints = 2,
        node_centres = 3.
    node_extend_dist : int | float, optional
        The distance over which to connect odd-branched nodes, by default -1 for no-limit.

    Returns
    -------
    npt.NDArray[np.int32]
        Connected nodes array with odd-branched nodes connected.
    """
    just_nodes = np.where(connected_nodes == 3, 1, 0)  # remove branches & termini points
    labelled_nodes = label(just_nodes)
    just_branches = np.where(connected_nodes == 1, 1, 0)  # remove node & termini points
    # NOTE(review): this assigns one constant to all branch pixels before labelling;
    # label() only uses the non-zero support, so the value itself appears unused — confirm.
    just_branches[connected_nodes == 1] = labelled_nodes.max() + 1
    labelled_branches = label(just_branches)
    nodes_with_branch_starting_coords = find_branches_for_nodes(
        network_array_representation=connected_nodes,
        labelled_nodes=labelled_nodes,
        labelled_branches=labelled_branches,
    )
    # If there is only one node, then there is no need to connect the nodes since there is nothing to
    # connect it to. Return the original connected_nodes instead.
    if len(nodes_with_branch_starting_coords) <= 1:
        self.connected_nodes = connected_nodes
        return self.connected_nodes
    assert self.whole_skel_graph is not None, "Whole skeleton graph is not defined."  # for type safety
    shortest_node_dists, shortest_dists_branch_idxs, _shortest_dist_coords = calculate_shortest_branch_distances(
        nodes_with_branch_starting_coords=nodes_with_branch_starting_coords,
        whole_skeleton_graph=self.whole_skel_graph,
    )
    # Matches is an Nx2 numpy array of indexes of the best matching nodes.
    # Eg: np.array([[1, 0], [2, 3]]) means that the best matching nodes are
    # node 1 and node 0, and node 2 and node 3.
    matches: npt.NDArray[np.int32] = self.best_matches(shortest_node_dists, max_weight_matching=False)
    # Connect the nodes by their best matches, using the shortest distances between their branch starts.
    connected_nodes = connect_best_matches(
        network_array_representation=connected_nodes,
        whole_skeleton_graph=self.whole_skel_graph,
        match_indexes=matches,
        shortest_distances_between_nodes=shortest_node_dists,
        shortest_distances_branch_indexes=shortest_dists_branch_idxs,
        emanating_branch_starts_by_node=nodes_with_branch_starting_coords,
        extend_distance=node_extend_dist,
    )
    self.connected_nodes = connected_nodes
    return self.connected_nodes
[docs]
@staticmethod
def find_branch_starts(reduced_node_image: npt.NDArray) -> npt.NDArray:
"""
Find the coordinates where the branches connect to the node region through binary dilation of the node.
Parameters
----------
reduced_node_image : npt.NDArray
A 2D numpy array containing a single node region (=3) and its connected branches (=1).
Returns
-------
npt.NDArray
Coordinate array of pixels next to crossing points (=3 in input).
"""
node = np.where(reduced_node_image == 3, 1, 0)
nodeless = np.where(reduced_node_image == 1, 1, 0)
thick_node = binary_dilation(node, structure=np.ones((3, 3)))
return np.argwhere(thick_node * nodeless == 1)
# pylint: disable=too-many-locals
[docs]
def analyse_nodes(self, max_branch_length: float = 20) -> None:
    """
    Obtain the main analyses for the nodes of a single molecule along the 'max_branch_length'(nm) from the node.

    Parameters
    ----------
    max_branch_length : float
        The side length of the box around the node to analyse (in nm).
    """
    # Get coordinates of nodes. This is a numpy array of coords, shape Nx2.
    assert self.node_centre_mask is not None, "Node centre mask is not defined."
    node_coords: npt.NDArray[np.int32] = np.argwhere(self.node_centre_mask.copy() == 3)
    # Check whether average trace resides inside the grain mask: if the twice-dilated
    # skeleton fits entirely within the smoothed mask, averaging 3 traces is safe.
    dilate = binary_dilation(self.skeleton, iterations=2)
    # This flag determines whether to use average of 3 traces in calculation of FWHM
    average_trace_advised = dilate[self.smoothed_mask == 1].sum() == dilate.sum()
    LOGGER.debug(f"[{self.filename}] : Branch height traces will be averaged: {average_trace_advised}")
    # Iterate over the nodes and analyse the branches
    matched_branches = None
    branch_image = None
    avg_image = np.zeros_like(self.image)
    real_node_count = 0
    for node_no, (node_x, node_y) in enumerate(node_coords):
        unmatched_branches = {}
        error = False
        # Reset per-node state: previously `matched_branches` leaked the previous
        # node's results into a failed node's dict, and `confidence` was unbound
        # (NameError) if the first analysed node hit a ResolutionError.
        matched_branches = None
        confidence = None
        # Get branches relevant to the node
        max_length_px = max_branch_length / (self.pixel_to_nm_scaling * 1)
        reduced_node_area: npt.NDArray[np.int32] = nodeStats.only_centre_branches(
            self.connected_nodes, np.array([node_x, node_y])
        )
        # Reduced skel graph is a networkx graph of the reduced node area.
        reduced_skel_graph: nx.classes.graph.Graph = nodeStats.skeleton_image_to_graph(reduced_node_area)
        # Binarise the reduced node area
        branch_mask = reduced_node_area.copy()
        branch_mask[branch_mask == 3] = 0
        branch_mask[branch_mask == 2] = 1
        # NOTE: rebinds `node_coords` to the current node region's pixels; the loop
        # is unaffected because `enumerate` holds a reference to the original array.
        node_coords = np.argwhere(reduced_node_area == 3)
        # Find the starting coordinates of any branches connected to the node
        branch_start_coords = self.find_branch_starts(reduced_node_area)
        # Stop processing if nib (node has 2 branches)
        if branch_start_coords.shape[0] <= 2:
            LOGGER.debug(
                f"node {node_no} has only two branches - skipped & nodes removed.{len(node_coords)}"
                "pixels in nib node."
            )
        else:
            try:
                real_node_count += 1
                LOGGER.debug(f"Node: {real_node_count}")
                # Analyse the node branches
                (
                    pairs,
                    matched_branches,
                    ordered_branches,
                    masked_image,
                    branch_under_over_order,
                    confidence,
                    singlet_branch_vectors,
                ) = nodeStats.analyse_node_branches(
                    p_to_nm=self.pixel_to_nm_scaling,
                    reduced_node_area=reduced_node_area,
                    branch_start_coords=branch_start_coords,
                    max_length_px=max_length_px,
                    reduced_skeleton_graph=reduced_skel_graph,
                    image=self.image,
                    average_trace_advised=average_trace_advised,
                    node_coord=(node_x, node_y),
                    pair_odd_branches=self.pair_odd_branches,
                    filename=self.filename,
                    resolution_threshold=np.float64(1000 / 512),
                )
                # Add the analysed branches to the labelled image
                branch_image, avg_image = nodeStats.add_branches_to_labelled_image(
                    branch_under_over_order=branch_under_over_order,
                    matched_branches=matched_branches,
                    masked_image=masked_image,
                    branch_start_coords=branch_start_coords,
                    ordered_branches=ordered_branches,
                    pairs=pairs,
                    average_trace_advised=average_trace_advised,
                    image_shape=(self.image.shape[0], self.image.shape[1]),
                )
                # Calculate crossing angles of unpaired branches and add to stats dict
                nodestats_calc_singlet_angles_result = nodeStats.calc_angles(np.asarray(singlet_branch_vectors))
                angles_between_singlet_branch_vectors: npt.NDArray[np.float64] = (
                    nodestats_calc_singlet_angles_result[0]
                )
                for branch_index, angle in enumerate(angles_between_singlet_branch_vectors):
                    unmatched_branches[branch_index] = {"angles": angle}
                # Get the vector of each branch based on ordered_coords. Ordered_coords is only the first N nm
                # of the branch so this is just a general vibe on what direction a branch is going.
                if len(branch_start_coords) % 2 == 0 or self.pair_odd_branches:
                    vectors: list[npt.NDArray[np.float64]] = []
                    for _, values in matched_branches.items():
                        vectors.append(nodeStats.get_vector(values["ordered_coords"], np.array([node_x, node_y])))
                    # Calculate angles between the vectors
                    nodestats_calc_angles_result = nodeStats.calc_angles(np.asarray(vectors))
                    angles_between_vectors_along_branch: npt.NDArray[np.float64] = nodestats_calc_angles_result[0]
                    # Eg: length 2 array: [array([ nan, 79.00]), array([79.00, 0.0])]
                    # (The previous duplicated condition check here was redundant — the
                    # enclosing `if` already guarantees it holds.)
                    for branch_index, angle in enumerate(angles_between_vectors_along_branch):
                        matched_branches[branch_index]["angles"] = angle
                else:
                    # Odd-branched node we are not pairing: remove it from the grain skeleton.
                    self.image_dict["grain"]["grain_skeleton"][node_coords[:, 0], node_coords[:, 1]] = 0
            except ResolutionError:
                LOGGER.debug(f"Node stats skipped as resolution too low: {self.pixel_to_nm_scaling}nm per pixel")
                error = True
            self.node_dicts[f"node_{real_node_count}"] = {
                "error": error,
                "pixel_to_nm_scaling": self.pixel_to_nm_scaling,
                "branch_stats": matched_branches,
                "unmatched_branch_stats": unmatched_branches,
                "node_coords": node_coords,
                "confidence": confidence,
            }
            assert reduced_node_area is not None, "Reduced node area is not defined."
            assert branch_image is not None, "Branch image is not defined."
            assert avg_image is not None, "Average image is not defined."
            node_images_dict: dict[str, npt.NDArray[np.int32]] = {
                "node_area_skeleton": reduced_node_area,
                "node_branch_mask": branch_image,
                "node_avg_mask": avg_image,
            }
            self.image_dict["nodes"][f"node_{real_node_count}"] = node_images_dict
    # Accumulate this grain's connected nodes into the running whole-image mask.
    self.all_connected_nodes[self.connected_nodes != 0] = self.connected_nodes[self.connected_nodes != 0]
# pylint: disable=too-many-arguments
[docs]
@staticmethod
def add_branches_to_labelled_image(
branch_under_over_order: npt.NDArray[np.int32],
matched_branches: dict[int, MatchedBranch],
masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]],
branch_start_coords: npt.NDArray[np.int32],
ordered_branches: list[npt.NDArray[np.int32]],
pairs: npt.NDArray[np.int32],
average_trace_advised: bool,
image_shape: tuple[int, int],
) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]:
"""
Add branches to a labelled image.
Parameters
----------
branch_under_over_order : npt.NDArray[np.int32]
The order of the branches.
matched_branches : dict[int, dict[str, MatchedBranch]]
Dictionary where the key is the index of the pair and the value is a dictionary containing the following
keys:
- "ordered_coords" : npt.NDArray[np.int32].
- "heights" : npt.NDArray[np.number]. Heights of the branches.
- "distances" :
- "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
masked_image : dict[int, dict[str, npt.NDArray[np.bool_]]]
Dictionary where the key is the index of the pair and the value is a dictionary containing the following
keys:
- "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
branch_start_coords : npt.NDArray[np.int32]
An Nx2 numpy array of the coordinates of the branches connected to the node.
ordered_branches : list[npt.NDArray[np.int32]]
List of numpy arrays of ordered branch coordinates.
pairs : npt.NDArray[np.int32]
Nx2 numpy array of pairs of branches that are matched through a node.
average_trace_advised : bool
Flag to determine whether to use the average trace.
image_shape : tuple[int]
The shape of the image, to create a mask from.
Returns
-------
tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]
The branch image and the average image.
"""
branch_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32)
avg_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32)
for i, branch_index in enumerate(branch_under_over_order):
branch_coords = matched_branches[branch_index]["ordered_coords"]
# Add the matched branch to the image, starting at index 1
branch_image[branch_coords[:, 0], branch_coords[:, 1]] = i + 1
if average_trace_advised:
# For type safety, check if avg_image is None and skip if so.
# This is because the type hinting does not allow for None in the array.
avg_image[masked_image[branch_index]["avg_mask"] != 0] = i + 1
# Determine branches that were not able to be paired
unpaired_branches = np.delete(np.arange(0, branch_start_coords.shape[0]), pairs.flatten())
LOGGER.debug(f"Unpaired branches: {unpaired_branches}")
# Ensure that unpaired branches start at index I where I is the number of paired branches.
branch_label = branch_image.max()
# Add the unpaired branches back to the branch image
for i in unpaired_branches:
branch_label += 1
branch_image[ordered_branches[i][:, 0], ordered_branches[i][:, 1]] = branch_label
return branch_image, avg_image
[docs]
@staticmethod
def analyse_node_branches(
    p_to_nm: np.float64,
    reduced_node_area: npt.NDArray[np.int32],
    branch_start_coords: npt.NDArray[np.int32],
    max_length_px: np.float64,
    reduced_skeleton_graph: nx.classes.graph.Graph,
    image: npt.NDArray[np.number],
    average_trace_advised: bool,
    node_coord: tuple[np.int32, np.int32],
    pair_odd_branches: bool,
    filename: str,
    resolution_threshold: np.float64,
) -> tuple[
    npt.NDArray[np.int32],
    dict[int, MatchedBranch],
    list[npt.NDArray[np.int32]],
    dict[int, dict[str, npt.NDArray[np.bool_]]],
    npt.NDArray[np.int32],
    np.float64 | None,
]:
    """
    Analyse the branches of a single node.

    Parameters
    ----------
    p_to_nm : np.float64
        The pixel to nm scaling factor.
    reduced_node_area : npt.NDArray[np.int32]
        An NxM numpy array of the node in question and the branches connected to it.
        Node is marked by 3, and branches by 1.
    branch_start_coords : npt.NDArray[np.int32]
        An Nx2 numpy array of the coordinates of the branches connected to the node.
    max_length_px : np.float64
        The maximum length in pixels to traverse along while ordering.
    reduced_skeleton_graph : nx.classes.graph.Graph
        The graph representation of the reduced node area.
    image : npt.NDArray[np.number]
        The full image of the grain.
    average_trace_advised : bool
        Flag to determine whether to use the average trace.
    node_coord : tuple[np.int32, np.int32]
        The node coordinates.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    filename : str
        The filename of the image.
    resolution_threshold : np.float64
        The resolution threshold below which to warn the user that the node is difficult to analyse.

    Returns
    -------
    pairs: npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node.
    matched_branches: dict[int, MatchedBranch]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32].
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch.
        - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches.
        - "angles" : np.float64. The angle of the branch, added in later steps.
    ordered_branches: list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
    branch_under_over_order: npt.NDArray[np.int32]
        The order of the branches based on the FWHM.
    confidence: np.float64 | None
        The confidence of the crossing. Optional.
    """
    # p_to_nm is nm-per-pixel, so a value ABOVE the threshold means a coarser image.
    if not p_to_nm <= resolution_threshold:
        LOGGER.debug(f"Resolution {p_to_nm} is below suggested {resolution_threshold}, node difficult to analyse.")
    # Pixel-wise order the branches coming from the node and calculate the starting vector for each branch
    ordered_branches, singlet_branch_vectors = nodeStats.get_ordered_branches_and_vectors(
        reduced_node_area, branch_start_coords, max_length_px
    )
    # Pair the singlet branch vectors based on their suitability using vector orientation.
    if len(branch_start_coords) % 2 == 0 or pair_odd_branches:
        pairs = nodeStats.pair_vectors(np.asarray(singlet_branch_vectors))
    else:
        # Odd number of branches and pairing disabled -> no pairs to trace.
        pairs = np.array([], dtype=np.int32)
    # Match the branches up
    matched_branches, masked_image = nodeStats.join_matching_branches_through_node(
        pairs,
        ordered_branches,
        reduced_skeleton_graph,
        image,
        average_trace_advised,
        node_coord,
        filename,
    )
    # Redo the FWHMs after the processing for more accurate determination of under/overs.
    hms = []
    for _, values in matched_branches.items():
        hms.append(values["fwhm"]["half_maxs"][2])
    # Recompute each branch's FWHM against the largest shared half-max so the
    # branch FWHMs are directly comparable.
    for _, values in matched_branches.items():
        values["fwhm"] = nodeStats.calculate_fwhm(values["heights"], values["distances"], hm=max(hms))
    # Get the confidence of the crossing
    crossing_fwhms = []
    for _, values in matched_branches.items():
        crossing_fwhms.append(values["fwhm"]["fwhm"])
    if len(crossing_fwhms) <= 1:
        confidence = None
    else:
        # Confidence is derived from pairwise comparisons of the branch FWHMs.
        crossing_fwhm_combinations = list(combinations(crossing_fwhms, 2))
        confidence = np.float64(nodeStats.cross_confidence(crossing_fwhm_combinations))
    # Order the branch indexes based on the FWHM of the branches.
    branch_under_over_order = np.array(list(matched_branches.keys()))[np.argsort(np.array(crossing_fwhms))]
    return (
        pairs,
        matched_branches,
        ordered_branches,
        masked_image,
        branch_under_over_order,
        confidence,
        singlet_branch_vectors,
    )
[docs]
@staticmethod
def join_matching_branches_through_node(
    pairs: npt.NDArray[np.int32],
    ordered_branches: list[npt.NDArray[np.int32]],
    reduced_skeleton_graph: nx.classes.graph.Graph,
    image: npt.NDArray[np.number],
    average_trace_advised: bool,
    node_coords: tuple[np.int32, np.int32],
    filename: str,
) -> tuple[dict[int, MatchedBranch], dict[int, dict[str, npt.NDArray[np.bool_]]]]:
    """
    Join branches that are matched through a node.

    For each matched pair, the graphical shortest path across the node is found, the two branches (plus
    any crossing pixels) are merged into a single ordered trace, and the trace's height profile, distance
    profile and FWHM are measured.

    Parameters
    ----------
    pairs : npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node.
    ordered_branches : list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    reduced_skeleton_graph : nx.classes.graph.Graph
        Graph representation of the skeleton.
    image : npt.NDArray[np.number]
        The full image of the grain.
    average_trace_advised : bool
        Flag to determine whether to use the average trace.
    node_coords : tuple[np.int32, np.int32]
        The node coordinates.
    filename : str
        The filename of the image. For logging purposes.

    Returns
    -------
    matched_branches: dict[int, MatchedBranch]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32]. Ordered coordinates of the joined branch.
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. Distance of each branch coordinate from the node centre.
        - "fwhm" : dict. Full width half maximum information of the branches.
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches (present only when averaging succeeds).
    """
    matched_branches: dict[int, MatchedBranch] = {}
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = (
        {}
    )  # Masked image is a dictionary of pairs of branches
    for i, (branch_1, branch_2) in enumerate(pairs):
        # Placeholder entry; the fields are populated as they are computed below.
        matched_branches[i] = MatchedBranch(
            ordered_coords=np.array([], dtype=np.int32),
            heights=np.array([], dtype=np.float64),
            distances=np.array([], dtype=np.float64),
            fwhm={},
            angles=None,
        )
        masked_image[i] = {}
        # find close ends by rearranging branch coords
        branch_1_coords, branch_2_coords = nodeStats.order_branches(
            ordered_branches[branch_1], ordered_branches[branch_2]
        )
        # Get graphical shortest path between branch ends on the skeleton
        crossing = nx.shortest_path(
            reduced_skeleton_graph,
            source=tuple(branch_1_coords[-1]),
            target=tuple(branch_2_coords[0]),
            weight="weight",
        )
        crossing = np.asarray(crossing[1:-1])  # remove start and end points & turn into array
        # Branch coords and crossing - stack into one continuous trace through the node
        if crossing.shape == (0,):  # adjacent ends, no intermediate crossing pixels
            branch_coords = np.vstack([branch_1_coords, branch_2_coords])
        else:
            branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords])
        # make images of single branch joined and multiple branches joined
        single_branch_img: npt.NDArray[np.bool_] = np.zeros_like(image).astype(bool)
        single_branch_img[branch_coords[:, 0], branch_coords[:, 1]] = True
        single_branch_coords = order_branch(single_branch_img.astype(bool), [0, 0])
        # calc image-wide coords
        matched_branches[i]["ordered_coords"] = single_branch_coords
        # get heights and trace distance of branch
        try:
            # The assertion deliberately drops into the single-trace fallback when averaging is not advised.
            assert average_trace_advised
            distances, heights, mask, _ = nodeStats.average_height_trace(
                image, single_branch_img, single_branch_coords, [node_coords[0], node_coords[1]]
            )
            masked_image[i]["avg_mask"] = mask
        except (
            AssertionError,
            IndexError,
        ) as e:  # Assertion - avg trace not advised, Index - wiggy branches
            LOGGER.debug(f"[(unknown)] : avg trace failed with {e}, single trace only.")
            average_trace_advised = False
            distances = nodeStats.coord_dist_rad(single_branch_coords, np.array([node_coords[0], node_coords[1]]))
            # distances = self.coord_dist(single_branch_coords)
            # Distance value of the trace coordinate closest to the node centre - used as the zero point.
            zero_dist = distances[
                np.argmin(
                    np.sqrt(
                        (single_branch_coords[:, 0] - node_coords[0]) ** 2
                        + (single_branch_coords[:, 1] - node_coords[1]) ** 2
                    )
                )
            ]
            heights = image[single_branch_coords[:, 0], single_branch_coords[:, 1]]  # self.hess
            distances = distances - zero_dist
            distances, heights = nodeStats.average_uniques(
                distances, heights
            )  # needs to be paired with coord_dist_rad
        matched_branches[i]["heights"] = heights
        matched_branches[i]["distances"] = distances
        # identify over/under
        matched_branches[i]["fwhm"] = nodeStats.calculate_fwhm(heights, distances)
    return matched_branches, masked_image
[docs]
@staticmethod
def get_ordered_branches_and_vectors(
    reduced_node_area: npt.NDArray[np.int32],
    branch_start_coords: npt.NDArray[np.int32],
    max_length_px: np.float64,
) -> tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]:
    """
    Get ordered branches and vectors for a node.

    Each branch is turned from a disordered set of coordinates into an ordered trace, and a unit vector
    describing the branch's general direction is computed for use in alignment matching later on.

    Parameters
    ----------
    reduced_node_area : npt.NDArray[np.int32]
        An NxM numpy array of the node in question and the branches connected to it.
        Node is marked by 3, and branches by 1.
    branch_start_coords : npt.NDArray[np.int32]
        An Px2 numpy array of coordinates representing the start of branches where P is the number of branches.
    max_length_px : np.int32
        The maximum length in pixels to traverse along while ordering.

    Returns
    -------
    tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]
        A tuple containing a list of ordered branches and a list of vectors.
    """
    # Keep only branch pixels (value 1); the node itself (value 3) is excluded from tracing.
    branch_only_mask = np.where(reduced_node_area == 1, 1, 0)
    ordered_branches = []
    vectors = []
    for start_coord in branch_start_coords:
        # Order the branch coordinates into a contiguous trace starting at the node edge.
        traced_branch = order_branch_from_start(branch_only_mask.copy(), start_coord, max_length=max_length_px)
        ordered_branches.append(traced_branch)
        # Direction vector of the branch, used for pairing branches across the node.
        vectors.append(nodeStats.get_vector(traced_branch, start_coord))
    return ordered_branches, vectors
[docs]
@staticmethod
def cross_confidence(pair_combinations: list) -> float:
    """
    Obtain the average confidence of the combinations using a reciprocal function.

    Parameters
    ----------
    pair_combinations : list
        List of length 2 combinations of FWHM values.

    Returns
    -------
    float
        The average crossing confidence.
    """
    # Mean of the per-pair reciprocal confidences; raises ZeroDivisionError on an empty list,
    # matching the behaviour expected by callers.
    total = sum(nodeStats.recip(pair) for pair in pair_combinations)
    return total / len(pair_combinations)
[docs]
@staticmethod
def recip(vals: list) -> float:
"""
Compute 1 - (max / min) of the two values provided.
Parameters
----------
vals : list
List of 2 values.
Returns
-------
float
Result of applying the 1-(min / max) function to the two values.
"""
try:
if min(vals) == 0: # means fwhm variation hasn't worked
return 0
return 1 - min(vals) / max(vals)
except ZeroDivisionError:
return 0
[docs]
@staticmethod
def get_vector(coords: npt.NDArray, origin: npt.NDArray) -> npt.NDArray:
"""
Calculate the normalised vector of the coordinate means in a branch.
Parameters
----------
coords : npt.NDArray
2xN array of x, y coordinates.
origin : npt.NDArray
2x1 array of an x, y coordinate.
Returns
-------
npt.NDArray
Normalised vector from origin to the mean coordinate.
"""
vector = coords.mean(axis=0) - origin
norm = np.sqrt(vector @ vector)
return vector if norm == 0 else vector / norm # normalise vector so length=1
[docs]
@staticmethod
def calc_angles(vectors: npt.NDArray) -> npt.NDArray[np.float64]:
"""
Calculate the angles between vectors in an array.
Uses the formula:
.. code-block:: RST
cos(theta) = |a|•|b|/|a||b|
Parameters
----------
vectors : npt.NDArray
Array of 2x1 vectors.
Returns
-------
npt.NDArray
An array of the cosine of the angles between the vectors.
"""
dot = vectors @ vectors.T
norm = np.diag(dot) ** 0.5
cos_angles = dot / (norm.reshape(-1, 1) @ norm.reshape(1, -1))
np.fill_diagonal(cos_angles, 1) # ensures vector_x • vector_x angles are 0
return abs(np.arccos(cos_angles) / np.pi * 180) # angles in degrees
[docs]
@staticmethod
def pair_vectors(vectors: npt.NDArray) -> npt.NDArray[np.int32]:
    """
    Take a list of vectors and pair them based on the angle between them.

    Parameters
    ----------
    vectors : npt.NDArray
        Array of 2x1 vectors to be paired.

    Returns
    -------
    npt.NDArray
        An array of the matching pair indices.
    """
    # Build the pairwise angle matrix, then best-match indices on it.
    angle_matrix = nodeStats.calc_angles(vectors)
    return nodeStats.best_matches(angle_matrix)
[docs]
@staticmethod
def best_matches(arr: npt.NDArray, max_weight_matching: bool = True) -> npt.NDArray:
    """
    Turn a matrix into a graph and calculates the best matching index pairs.

    Parameters
    ----------
    arr : npt.NDArray
        Transpose symmetric MxM array where the value of index i, j represents a weight between i and j.
    max_weight_matching : bool
        Whether to obtain best matching pairs via maximum weight, or minimum weight matching.

    Returns
    -------
    npt.NDArray
        Array of pairs of indexes.
    """
    if max_weight_matching:
        G = nodeStats.create_weighted_graph(arr)
        matching = np.array(list(nx.max_weight_matching(G, maxcardinality=True)))
    else:
        # Work on a copy - np.fill_diagonal mutates in place and would otherwise
        # silently alter the caller's weight matrix.
        arr = arr.copy()
        # Inflate the diagonal so self-matches are never the minimum.
        np.fill_diagonal(arr, arr.max() + 1)
        G = nodeStats.create_weighted_graph(arr)
        matching = np.array(list(nx.min_weight_matching(G)))
    return matching
[docs]
@staticmethod
def create_weighted_graph(matrix: npt.NDArray) -> nx.Graph:
    """
    Create a bipartite graph connecting i <-> j from a square matrix of weights matrix[i, j].

    Parameters
    ----------
    matrix : npt.NDArray
        Square array of weights between rows and columns.

    Returns
    -------
    nx.Graph
        Bipatrite graph with edge weight i->j matching matrix[i,j].
    """
    graph = nx.Graph()
    # Upper triangle only - the matrix is symmetric so each unordered pair is added once.
    for row, col in combinations(range(len(matrix)), 2):
        graph.add_edge(row, col, weight=matrix[row, col])
    return graph
[docs]
@staticmethod
def pair_angles(angles: npt.NDArray) -> list:
"""
Pair angles that are 180 degrees to each other and removes them before selecting the next pair.
Parameters
----------
angles : npt.NDArray
Square array (i,j) of angles between i and j.
Returns
-------
list
A list of paired indexes in a list.
"""
angles_cp = angles.copy()
pairs = []
for _ in range(int(angles.shape[0] / 2)):
pair = np.unravel_index(np.argmax(angles_cp), angles.shape)
pairs.append(pair) # add to list
angles_cp[[pair]] = 0 # set rows 0 to avoid picking again
angles_cp[:, [pair]] = 0 # set cols 0 to avoid picking again
return np.asarray(pairs)
[docs]
@staticmethod
def gaussian(x: npt.NDArray, h: float, mean: float, sigma: float):
"""
Apply the gaussian function.
Parameters
----------
x : npt.NDArray
X values to be passed into the gaussian.
h : float
The peak height of the gaussian.
mean : float
The mean of the x values.
sigma : float
The standard deviation of the image.
Returns
-------
npt.NDArray
The y-values of the gaussian performed on the x values.
"""
return h * np.exp(-((x - mean) ** 2) / (2 * sigma**2))
[docs]
@staticmethod
def interpolate_between_yvalue(x: npt.NDArray, y: npt.NDArray, yvalue: float) -> float:
    """
    Calculate the x value between the two points either side of yvalue in y.

    Parameters
    ----------
    x : npt.NDArray
        An array of length y.
    y : npt.NDArray
        An array of length x.
    yvalue : float
        A value within the bounds of the y array.

    Returns
    -------
    float
        The linearly interpolated x value between the arrays, or 0 when no consecutive
        pair of y values brackets yvalue.
    """
    for idx in range(len(y) - 1):
        y_lo, y_hi = y[idx], y[idx + 1]
        # yvalue is bracketed whether the trace is rising or falling at this segment.
        if min(y_lo, y_hi) <= yvalue <= max(y_lo, y_hi):
            return nodeStats.lin_interp([x[idx], y_lo], [x[idx + 1], y_hi], yvalue=yvalue)
    return 0
[docs]
@staticmethod
def calculate_fwhm(
    heights: npt.NDArray, distances: npt.NDArray, hm: float | None = None
) -> dict[str, np.float64 | list[np.float64 | float | None]]:
    """
    Calculate the FWHM value.

    First identifies the HM, then finds the closest values in the distances array and uses
    linear interpolation to calculate the FWHM.

    Parameters
    ----------
    heights : npt.NDArray
        Array of heights.
    distances : npt.NDArray
        Array of distances.
    hm : Union[None, float], optional
        The halfmax value to match (if wanting the same HM between curves), by default None.

    Returns
    -------
    dict[str, np.float64 | list[np.float64 | float | None]]
        Dictionary with keys:
        - "fwhm" : np.float64. The full width half maximum.
        - "half_maxs" : list. [distance at hm for 1st half of trace, distance at hm for 2nd half of trace,
          HM value].
        - "peaks" : list. [index of the highest point, distance at highest point, height at highest point].
    """
    centre_fraction = int(len(heights) * 0.2)  # in case zone approaches another node, look around centre for max
    if centre_fraction == 0:
        high_idx = np.argmax(heights)
    else:
        # Restrict the peak search to the central 60% of the trace.
        high_idx = np.argmax(heights[centre_fraction:-centre_fraction]) + centre_fraction
    # get array halves to find first points that cross hm
    # arr1/dist1 are reversed so both halves are traversed outwards from the peak.
    arr1 = heights[:high_idx][::-1]
    dist1 = distances[:high_idx][::-1]
    arr2 = heights[high_idx:]
    dist2 = distances[high_idx:]
    if hm is None:
        # Get half max
        hm = (heights.max() - heights.min()) / 2 + heights.min()
        # half max value -> try to make it the same as other crossing branch?
        # increase make hm = lowest of peak if it doesn’t hit one side
        if np.min(arr1) > hm:
            arr1_local_min = argrelextrema(arr1, np.less)[-1]  # closest to end
            try:
                hm = arr1[arr1_local_min][0]
            except IndexError:  # index error when no local minima
                hm = np.min(arr1)
        elif np.min(arr2) > hm:
            arr2_local_min = argrelextrema(arr2, np.less)[0]  # closest to start
            try:
                hm = arr2[arr2_local_min][0]
            except IndexError:  # index error when no local minima
                hm = np.min(arr2)
    # Interpolated distance at which each half of the trace crosses the half max.
    arr1_hm = nodeStats.interpolate_between_yvalue(x=dist1, y=arr1, yvalue=hm)
    arr2_hm = nodeStats.interpolate_between_yvalue(x=dist2, y=arr2, yvalue=hm)
    fwhm = np.float64(abs(arr2_hm - arr1_hm))
    return {
        "fwhm": fwhm,
        "half_maxs": [arr1_hm, arr2_hm, hm],
        "peaks": [high_idx, distances[high_idx], heights[high_idx]],
    }
[docs]
@staticmethod
def lin_interp(point_1: list, point_2: list, xvalue: float | None = None, yvalue: float | None = None) -> float:
"""
Linear interp 2 points by finding line equation and subbing.
Parameters
----------
point_1 : list
List of an x and y coordinate.
point_2 : list
List of an x and y coordinate.
xvalue : Union[float, None], optional
Value at which to interpolate to get a y coordinate, by default None.
yvalue : Union[float, None], optional
Value at which to interpolate to get an x coordinate, by default None.
Returns
-------
float
Value of x or y linear interpolation.
"""
m = (point_1[1] - point_2[1]) / (point_1[0] - point_2[0])
c = point_1[1] - (m * point_1[0])
if xvalue is not None:
return m * xvalue + c # interp_y
if yvalue is not None:
return (yvalue - c) / m # interp_x
raise ValueError
[docs]
@staticmethod
def order_branches(branch1: npt.NDArray, branch2: npt.NDArray) -> tuple:
"""
Order the two ordered arrays based on the closest endpoint coordinates.
Parameters
----------
branch1 : npt.NDArray
An Nx2 array describing coordinates.
branch2 : npt.NDArray
An Nx2 array describing coordinates.
Returns
-------
tuple
An tuple with the each coordinate array ordered to follow on from one-another.
"""
endpoints1 = np.asarray([branch1[0], branch1[-1]])
endpoints2 = np.asarray([branch2[0], branch2[-1]])
sum1 = abs(endpoints1 - endpoints2).sum(axis=1)
sum2 = abs(endpoints1[::-1] - endpoints2).sum(axis=1)
if sum1.min() < sum2.min():
if np.argmin(sum1) == 0:
return branch1[::-1], branch2
return branch1, branch2[::-1]
if np.argmin(sum2) == 0:
return branch1, branch2
return branch1[::-1], branch2[::-1]
[docs]
@staticmethod
def binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray:
"""
Create a binary path following the straight line between 2 points.
Parameters
----------
start : npt.NDArray
A coordinate.
end : npt.NDArray
Another coordinate.
Returns
-------
npt.NDArray
An Nx2 coordinate array that the line passes through.
"""
arr = []
m_swap = False
x_swap = False
slope = (end - start)[1] / (end - start)[0]
if abs(slope) > 1: # swap x and y if slope will cause skips
start, end = start[::-1], end[::-1]
slope = 1 / slope
m_swap = True
if start[0] > end[0]: # swap x coords if coords wrong way around
start, end = end, start
x_swap = True
# code assumes slope < 1 hence swap
x_start, y_start = start
x_end, _ = end
for x in range(x_start, x_end + 1):
y_true = slope * (x - x_start) + y_start
y_pixel = np.round(y_true)
arr.append([x, y_pixel])
if m_swap: # if swapped due to slope, return
arr = np.asarray(arr)[:, [1, 0]].reshape(-1, 2).astype(int)
if x_swap:
return arr[::-1]
return arr
arr = np.asarray(arr).reshape(-1, 2).astype(int)
if x_swap:
return arr[::-1]
return arr
[docs]
@staticmethod
def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray:
"""
Calculate the distance from the centre coordinate to a point along the ordered coordinates.
This differs to traversal along the coordinates taken. This also averages any common distance
values and makes those in the trace before the node index negative.
Parameters
----------
coords : npt.NDArray
Nx2 array of branch coordinates.
centre : npt.NDArray
A 1x2 array of the centre coordinates to identify a 0 point for the node.
pixel_to_nm_scaling : float, optional
The pixel to nanometer scaling factor to provide real units, by default 1.
Returns
-------
npt.NDArray
A Nx1 array of the distance from the node centre.
"""
diff_coords = coords - centre
if np.all(coords == centre, axis=1).sum() == 0: # if centre not in coords, reassign centre
diff_dists = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2)
centre = coords[np.argmin(diff_dists)]
cross_idx = np.argwhere(np.all(coords == centre, axis=1))
rad_dist = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2)
rad_dist[0 : cross_idx[0][0]] *= -1
return rad_dist * pixel_to_nm_scaling
[docs]
@staticmethod
def above_below_value_idx(array: npt.NDArray, value: float) -> list:
    """
    Identify indices of the array neighbouring the specified value.

    Parameters
    ----------
    array : npt.NDArray
        Array of values.
    value : float
        Value to identify indices between.

    Returns
    -------
    list
        Sorted list of the lower index and higher index around the value, or ``None`` when the value is
        not strictly bracketed by the closest element and one of its neighbours (e.g. when the value
        equals an array element, or lies outside the array).
    """
    # Index of the element closest to value.
    idx1 = abs(array - value).argmin()
    try:
        if array[idx1] < value < array[idx1 + 1]:
            idx2 = idx1 + 1
        elif array[idx1 - 1] < value < array[idx1]:
            # NOTE(review): when idx1 == 0 this compares against array[-1] (wrap-around); presumably the
            # array is monotonically increasing so the comparison fails there - TODO confirm.
            idx2 = idx1 - 1
        else:
            raise IndexError  # this will be if the number is the same
        indices = [idx1, idx2]
        indices.sort()
        return indices
    except IndexError:
        # idx1 at the end of the array, or value equal to an element - no bracketing pair.
        return None
[docs]
@staticmethod
def average_height_trace(  # noqa: C901
    img: npt.NDArray, branch_mask: npt.NDArray, branch_coords: npt.NDArray, centre=(0, 0)
) -> tuple:
    """
    Average two side-by-side ordered skeleton distance and height traces.

    Dilate the original branch to create two additional side-by-side branches
    in order to get a more accurate average of the height traces. This function produces
    the common distances between these 3 branches, and their averaged heights.

    Parameters
    ----------
    img : npt.NDArray
        An array of numbers pertaining to an image.
    branch_mask : npt.NDArray
        A binary array of the branch, must share the same dimensions as the image.
    branch_coords : npt.NDArray
        Ordered coordinates of the branch mask.
    centre : Union[float, None]
        The coordinates to centre the branch around.

    Returns
    -------
    tuple
        A tuple of: the common distances from the crossing, the averaged heights over the 3 traces, a
        binary mask of the traces used, and the raw per-trace heights and distances.
    """
    # get heights and dists of the original (middle) branch
    branch_dist = nodeStats.coord_dist_rad(branch_coords, centre)
    # branch_dist = self.coord_dist(branch_coords)
    branch_heights = img[branch_coords[:, 0], branch_coords[:, 1]]
    branch_dist, branch_heights = nodeStats.average_uniques(
        branch_dist, branch_heights
    )  # needs to be paired with coord_dist_rad
    # Distance value of the coordinate closest to the centre - the zero point of the distance axis.
    dist_zero_point = branch_dist[
        np.argmin(np.sqrt((branch_coords[:, 0] - centre[0]) ** 2 + (branch_coords[:, 1] - centre[1]) ** 2))
    ]
    branch_dist_norm = branch_dist - dist_zero_point  # - 0 # branch_dist[branch_heights.argmax()]
    # want to get a 3 pixel line trace, one on each side of orig
    dilate = binary_dilation(branch_mask, iterations=1)
    dilate = nodeStats.fill_holes(dilate)
    dilate_minus = np.where(dilate != branch_mask, 1, 0)
    dilate2 = binary_dilation(dilate, iterations=1)
    dilate2[(dilate == 1) | (branch_mask == 1)] = 0  # keep only the outer dilation ring
    labels = label(dilate2)
    # Cleanup stages - re-entering, early terminating, closer traces
    # if parallel trace out and back in zone, can get > 2 labels
    labels = nodeStats._remove_re_entering_branches(labels, remaining_branches=2)
    # if parallel trace doesn't exit window, can get 1 label
    # occurs when skeleton has poor connections (extra branches which cut corners)
    if labels.max() == 1:
        conv = convolve_skeleton(branch_mask)
        endpoints = np.argwhere(conv == 2)
        for endpoint in endpoints:  # may be >1 endpoint
            # Break the single wrap-around trace at the pixels nearest each branch endpoint.
            para_trace_coords = np.argwhere(labels == 1)
            abs_diff = np.absolute(para_trace_coords - endpoint).sum(axis=1)
            min_idxs = np.where(abs_diff == abs_diff.min())
            trace_coords_remove = para_trace_coords[min_idxs]
            labels[trace_coords_remove[:, 0], trace_coords_remove[:, 1]] = 0
        labels = label(labels)
    # reduce binary dilation distance
    parallel = np.zeros_like(branch_mask).astype(np.int32)
    for i in range(1, labels.max() + 1):
        single = labels.copy()
        single[single != i] = 0
        single[single == i] = 1
        sing_dil = binary_dilation(single)
        # keep only the dilated pixels that fall on the first dilation ring
        parallel[(sing_dil == dilate_minus) & (sing_dil == 1)] = i
    labels = parallel.copy()
    binary = labels.copy()
    binary[binary != 0] = 1
    binary += branch_mask
    # get and order coords, then get heights and distances relitive to node centre / highest point
    heights = []
    distances = []
    for i in np.unique(labels)[1:]:
        trace_img = np.where(labels == i, 1, 0)
        trace_img = getSkeleton(img, trace_img, method="zhang").get_skeleton()
        trace = order_branch(trace_img, branch_coords[0])
        height_trace = img[trace[:, 0], trace[:, 1]]
        dist = nodeStats.coord_dist_rad(trace, centre)  # self.coord_dist(trace)
        dist, height_trace = nodeStats.average_uniques(dist, height_trace)  # needs to be paired with coord_dist_rad
        heights.append(height_trace)
        distances.append(
            dist - dist_zero_point  # - 0
        )  # branch_dist[branch_heights.argmax()]) #dist[central_heights.argmax()])
    # Make like coord system using original branch
    avg1 = []
    avg2 = []
    for mid_dist in branch_dist_norm:
        for i, (distance, height) in enumerate(zip(distances, heights)):
            # check if distance already in traces array
            if (mid_dist == distance).any():
                idx = np.where(mid_dist == distance)
                if i == 0:
                    avg1.append([mid_dist, height[idx][0]])
                else:
                    avg2.append([mid_dist, height[idx][0]])
            # if not, linearly interpolate the mid-branch value
            else:
                # get index after and before the mid branches' x coord
                xidxs = nodeStats.above_below_value_idx(distance, mid_dist)
                if xidxs is None:
                    pass  # if indexes outside of range, pass
                else:
                    point1 = [distance[xidxs[0]], height[xidxs[0]]]
                    point2 = [distance[xidxs[1]], height[xidxs[1]]]
                    y = nodeStats.lin_interp(point1, point2, xvalue=mid_dist)
                    if i == 0:
                        avg1.append([mid_dist, y])
                    else:
                        avg2.append([mid_dist, y])
    avg1 = np.asarray(avg1)
    avg2 = np.asarray(avg2)
    # ensure arrays are same length to average
    temp_x = branch_dist_norm[np.isin(branch_dist_norm, avg1[:, 0])]
    common_dists = avg2[:, 0][np.isin(avg2[:, 0], temp_x)]
    common_avg_branch_heights = branch_heights[np.isin(branch_dist_norm, common_dists)]
    common_avg1_heights = avg1[:, 1][np.isin(avg1[:, 0], common_dists)]
    common_avg2_heights = avg2[:, 1][np.isin(avg2[:, 0], common_dists)]
    average_heights = (common_avg_branch_heights + common_avg1_heights + common_avg2_heights) / 3
    return (
        common_dists,
        average_heights,
        binary,
        [[heights[0], branch_heights, heights[1]], [distances[0], branch_dist_norm, distances[1]]],
    )
[docs]
@staticmethod
def fill_holes(mask: npt.NDArray) -> npt.NDArray:
    """
    Fill all holes within a binary mask.

    The largest connected background region is assumed to be the exterior; every other
    background region is treated as an interior hole and filled.

    Parameters
    ----------
    mask : npt.NDArray
        Binary array of object.

    Returns
    -------
    npt.NDArray
        Binary array of object with any interior holes filled in.
    """
    inv_mask = np.where(mask != 0, 0, 1)
    # connectivity=1 prevents diagonally-touching background regions being merged.
    lbl_inv = label(inv_mask, connectivity=1)
    idxs, counts = np.unique(lbl_inv, return_counts=True)
    # Label 0 is the object itself (the background of the inverted image) - exclude it so
    # the exterior is selected even when the object covers more pixels than the background.
    non_zero = idxs != 0
    if not non_zero.any():  # mask covers the whole image, nothing to fill
        return np.where(mask != 0, 1, 0)
    max_idx = idxs[non_zero][np.argmax(counts[non_zero])]
    return np.where(lbl_inv != max_idx, 1, 0)
[docs]
@staticmethod
def _remove_re_entering_branches(mask: npt.NDArray, remaining_branches: int = 1) -> npt.NDArray:
    """
    Remove the smallest branches which exit and re-enter the viewing area.

    Branches are removed smallest-first until only ``remaining_branches`` remain.

    Parameters
    ----------
    mask : npt.NDArray
        Skeletonised binary mask of an object.
    remaining_branches : int, optional
        Number of objects (branches) to keep, by default 1.

    Returns
    -------
    npt.NDArray
        Mask with only ``remaining_branches`` skeletonised branches.
    """
    rtn_image = mask.copy()
    binary_image = mask.copy()
    binary_image[binary_image != 0] = 1
    labels = label(binary_image)
    if labels.max() > remaining_branches:
        # Map each label to its pixel count. A plain list of sizes would misalign list
        # positions and label values once an entry is removed, deleting the wrong branch.
        sizes = {branch_label: (labels == branch_label).sum() for branch_label in range(1, labels.max() + 1)}
        while len(sizes) > remaining_branches:
            smallest_label = min(sizes, key=sizes.get)
            rtn_image[labels == smallest_label] = 0
            del sizes[smallest_label]
    return rtn_image
[docs]
@staticmethod
def only_centre_branches(node_image: npt.NDArray, node_coordinate: npt.NDArray) -> npt.NDArray[np.int32]:
    """
    Remove all branches not connected to the current node.

    Parameters
    ----------
    node_image : npt.NDArray
        An image of the skeletonised area surrounding the node where
        the background = 0, skeleton = 1, termini = 2, nodes = 3.
    node_coordinate : npt.NDArray
        2x1 coordinate describing the position of a node.

    Returns
    -------
    npt.NDArray[np.int32]
        The initial node image but only with skeletal branches
        connected to the middle node.
    """
    node_image_cp = node_image.copy()
    # get node-only image
    nodes = node_image_cp.copy()
    nodes[nodes != 3] = 0
    labeled_nodes = label(nodes)
    # find which cluster is closest to the centre
    node_coords = np.argwhere(nodes == 3)
    min_coords = node_coords[abs(node_coords - node_coordinate).sum(axis=1).argmin()]
    centre_idx = labeled_nodes[min_coords[0], min_coords[1]]
    # get nodeless image
    nodeless = node_image_cp.copy()  # NOTE(review): immediately overwritten below - appears redundant
    nodeless = np.where(
        (node_image == 1) | (node_image == 2), 1, 0
    )  # if termini, need this in the labeled branches too
    nodeless[labeled_nodes == centre_idx] = 1  # return centre node
    labeled_nodeless = label(nodeless)
    # apply to return image - keep only the connected component containing the centre node
    for i in range(1, labeled_nodeless.max() + 1):
        if (node_image_cp[labeled_nodeless == i] == 3).any():
            node_image_cp[labeled_nodeless != i] = 0
            break
    # remove small area around other nodes
    labeled_nodes[labeled_nodes == centre_idx] = 0
    non_central_node_coords = np.argwhere(labeled_nodes != 0)
    for coord in non_central_node_coords:
        # clamp the 3x3 removal window so it stays inside the image bounds
        for j, coord_val in enumerate(coord):
            if coord_val - 1 < 0:
                coord[j] = 1
            if coord_val + 2 > node_image_cp.shape[j]:
                coord[j] = node_image_cp.shape[j] - 2
        node_image_cp[coord[0] - 1 : coord[0] + 2, coord[1] - 1 : coord[1] + 2] = 0
    return node_image_cp
[docs]
@staticmethod
def average_uniques(arr1: npt.NDArray, arr2: npt.NDArray) -> tuple:
"""
Obtain the unique values of both arrays, and the average of common values.
Parameters
----------
arr1 : npt.NDArray
An array.
arr2 : npt.NDArray
An array.
Returns
-------
tuple
The unique values of both arrays, and the averaged common values.
"""
arr1_uniq, index = np.unique(arr1, return_index=True)
arr2_new = np.zeros_like(arr1_uniq).astype(np.float64)
for i, val in enumerate(arr1[index]):
mean = arr2[arr1 == val].mean()
arr2_new[i] += mean
return arr1[index], arr2_new
[docs]
@staticmethod
def average_crossing_confs(node_dict) -> None | float:
"""
Return the average crossing confidence of all crossings in the molecule.
Parameters
----------
node_dict : dict
A dictionary containing node statistics and information.
Returns
-------
Union[None, float]
The value of minimum confidence or none if not possible.
"""
sum_conf = 0
valid_confs = 0
for _, (_, values) in enumerate(node_dict.items()):
confidence = values["confidence"]
if confidence is not None:
sum_conf += confidence
valid_confs += 1
try:
return sum_conf / valid_confs
except ZeroDivisionError:
return None
[docs]
@staticmethod
def minimum_crossing_confs(node_dict: dict) -> None | float:
"""
Return the minimum crossing confidence of all crossings in the molecule.
Parameters
----------
node_dict : dict
A dictionary containing node statistics and information.
Returns
-------
Union[None, float]
The value of minimum confidence or none if not possible.
"""
confidences = []
valid_confs = 0
for _, (_, values) in enumerate(node_dict.items()):
confidence = values["confidence"]
if confidence is not None:
confidences.append(confidence)
valid_confs += 1
try:
return min(confidences)
except ValueError:
return None
[docs]
def compile_metrics(self) -> None:
    """
    Add the number of crossings, average and minimum crossing confidence to the metrics dictionary.

    The confidence metrics are NaN when no crossing has a confidence value - passing None
    straight to np.float64 would raise a TypeError.
    """
    avg_conf = nodeStats.average_crossing_confs(self.node_dicts)
    min_conf = nodeStats.minimum_crossing_confs(self.node_dicts)
    self.metrics["num_crossings"] = np.int64((self.node_centre_mask == 3).sum())
    self.metrics["avg_crossing_confidence"] = np.float64(np.nan if avg_conf is None else avg_conf)
    self.metrics["min_crossing_confidence"] = np.float64(np.nan if min_conf is None else min_conf)
[docs]
@staticmethod
def nodestats_image(
    image: npt.NDArray,
    disordered_tracing_direction_data: dict,
    filename: str,
    pixel_to_nm_scaling: float,
    node_joining_length: float,
    node_extend_dist: float,
    branch_pairing_length: float,
    pair_odd_branches: bool,
    pad_width: int,
) -> tuple:
    """
    Calculate crossing statistics for all grains in an image.

    Runs ``nodeStats`` on each grain's pruned skeleton, compiles the per-grain
    diagnostic images back onto whole-image canvases and gathers per-grain metrics.

    Parameters
    ----------
    image : npt.NDArray
        The array of pixels.
    disordered_tracing_direction_data : dict
        The images and bbox coordinates of the pruned skeletons.
    filename : str
        The name of the file being processed. For logging purposes.
    pixel_to_nm_scaling : float
        The pixel to nm scaling factor.
    node_joining_length : float
        The length over which to join skeletal intersections to be counted as one crossing.
    node_extend_dist : float
        The distance under which to join odd-branched node regions.
    branch_pairing_length : float
        The length from the crossing point to pair and trace, obtaining FWHM's.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    pad_width : int
        The number of edge pixels to pad the image by.

    Returns
    -------
    tuple[dict, pd.DataFrame, dict, dict]
        The nodestats statistics for each crossing, crossing statistics to be added to the grain statistics,
        an image dictionary of nodestats steps for the entire image, and single grain images.
    """
    n_grains = len(disordered_tracing_direction_data)
    img_base = np.zeros_like(image)
    nodestats_data = {}
    # Each cropped grain is later remapped onto these whole-image canvases via its
    # bbox anchor coordinates so every grain appears on a single compiled image.
    all_images = {
        "convolved_skeletons": img_base.copy(),
        "node_centres": img_base.copy(),
        "connected_nodes": img_base.copy(),
    }
    nodestats_branch_images = {}
    grainstats_additions = {}
    LOGGER.info(f"[(unknown)] : Calculating NodeStats statistics for {n_grains} grains...")
    for n_grain, disordered_tracing_grain_data in disordered_tracing_direction_data.items():
        try:
            nodestats = nodeStats(
                image=disordered_tracing_grain_data["original_image"],
                mask=disordered_tracing_grain_data["original_grain"],
                smoothed_mask=disordered_tracing_grain_data["smoothed_grain"],
                skeleton=disordered_tracing_grain_data["pruned_skeleton"],
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                filename=filename,
                n_grain=n_grain,
                node_joining_length=node_joining_length,
                node_extend_dist=node_extend_dist,
                branch_pairing_length=branch_pairing_length,
                pair_odd_branches=pair_odd_branches,
            )
            nodestats_dict, node_image_dict = nodestats.get_node_stats()
            LOGGER.debug(f"[(unknown)] : Nodestats processed {n_grain} of {n_grains}")
            # compile the per-grain diagnostic images
            nodestats_images = {
                "convolved_skeletons": nodestats.conv_skelly,
                "node_centres": nodestats.node_centre_mask,
                "connected_nodes": nodestats.connected_nodes,
            }
            nodestats_branch_images[n_grain] = node_image_dict
            # compile metrics; grain number is the trailing integer of the grain key
            grainstats_additions[n_grain] = {
                "image": filename,
                "grain_number": int(n_grain.split("_")[-1]),
            }
            grainstats_additions[n_grain].update(nodestats.metrics)
            if nodestats_dict:  # if the grain's nodestats dict is not empty
                nodestats_data[n_grain] = nodestats_dict
            # remap the cropped images back onto the original image
            for image_name, full_image in all_images.items():
                crop = nodestats_images[image_name]
                bbox = disordered_tracing_grain_data["bbox"]
                # `-pad_width or None` keeps the whole crop when pad_width == 0;
                # a bare `crop[0:-0]` slice would be empty.
                full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[
                    pad_width : -pad_width or None, pad_width : -pad_width or None
                ]
        except Exception as e:  # pylint: disable=broad-exception-caught
            LOGGER.error(
                f"[(unknown)] : Nodestats for {n_grain} failed. Consider raising an issue on GitHub. Error: ",
                exc_info=e,
            )
            nodestats_data[n_grain] = {}
    # turn the grainstats additions into a dataframe, # might need to do something for when everything is empty
    grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index")
    return nodestats_data, grainstats_additions_df, all_images, nodestats_branch_images