Source code for topostats.tracing.nodestats

"""Perform Crossing Region Processing and Analysis."""

from __future__ import annotations

import logging
from itertools import combinations
from typing import TypedDict

import networkx as nx
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.ndimage import binary_dilation
from scipy.signal import argrelextrema
from skimage.morphology import label

from topostats.logs.logs import LOGGER_NAME
from topostats.measure.geometry import (
    calculate_shortest_branch_distances,
    connect_best_matches,
    find_branches_for_nodes,
)
from topostats.tracing.pruning import prune_skeleton
from topostats.tracing.skeletonize import getSkeleton
from topostats.tracing.tracingfuncs import order_branch, order_branch_from_start
from topostats.utils import ResolutionError, convolve_skeleton

LOGGER = logging.getLogger(LOGGER_NAME)

# pylint: disable=too-many-arguments
# pylint: disable=too-many-branches
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-lines
# pylint: disable=too-many-locals
# pylint: disable=too-many-nested-blocks
# pylint: disable=too-many-public-methods
# pylint: disable=too-many-statements


class NodeDict(TypedDict):
    """
    Dictionary containing the node information.

    The keys mirror the per-node entry that ``nodeStats.analyse_nodes`` stores in ``node_dicts``.
    """

    # True when the analysis of the node raised a ResolutionError.
    error: bool
    # Pixel to nanometre scaling factor copied from the nodeStats instance.
    pixel_to_nm_scaling: np.float64
    # Statistics of the branch pairs matched through the node, keyed by pair index.
    branch_stats: dict[int, MatchedBranch] | None
    # Angles of the individual (unpaired) branch vectors, keyed by branch index, each stored under "angles".
    # NOTE(review): value type assumed to be an array row returned by calc_angles - confirm.
    unmatched_branch_stats: dict[int, dict[str, npt.NDArray[np.number]]]
    # Coordinates of the pixels labelled as the node (value 3) in the reduced node area.
    node_coords: npt.NDArray[np.int32] | None
    # Confidence of the crossing; None when fewer than two branches were matched.
    confidence: np.float64 | None
class MatchedBranch(TypedDict):
    """
    Dictionary containing the statistics of one pair of branches matched through a node.

    Instances are the values of the ``matched_branches`` dictionaries, whose keys are the index of the pair.

    Keys
    ----
    ordered_coords : npt.NDArray[np.int32]
        The ordered coordinates of the branch.
    heights : npt.NDArray[np.number]
        Heights of the branch coordinates.
    distances : npt.NDArray[np.number]
        Distances of the branch coordinates.
    fwhm : dict[str, np.float64 | tuple[np.float64]]
        Result of ``nodeStats.calculate_fwhm``. Elsewhere in this module the entries "fwhm" and "half_maxs" are read
        from it.
    angles : np.float64 | None
        The initial direction angle of the branch. ``None`` when the entry is created, filled in by later steps.
    """

    ordered_coords: npt.NDArray[np.int32]
    heights: npt.NDArray[np.number]
    distances: npt.NDArray[np.number]
    fwhm: dict[str, np.float64 | tuple[np.float64]]
    angles: np.float64 | None
class ImageDict(TypedDict):
    """Dictionary containing the image information."""

    # Per-node images keyed by "node_<n>"; each entry holds "node_area_skeleton", "node_branch_mask" and
    # "node_avg_mask".
    nodes: dict[str, dict[str, npt.NDArray[np.int32]]]
    # Whole-grain images stored under "grain_image", "grain_mask" and "grain_skeleton".
    grain: dict[str, npt.NDArray[np.int32] | dict[str, npt.NDArray[np.int32]]]
class nodeStats:
    """
    Class containing methods to find and analyse the nodes/crossings within a grain.

    Parameters
    ----------
    filename : str
        The name of the file being processed. For logging purposes.
    image : npt.NDArray
        The array of pixels.
    mask : npt.NDArray
        The binary segmentation mask.
    smoothed_mask : npt.NDArray
        A smoothed version of the binary segmentation mask.
    skeleton : npt.NDArray
        A binary single-pixel wide mask of objects in the 'image'.
    pixel_to_nm_scaling : np.float64
        The pixel to nm scaling factor.
    n_grain : int
        The grain number.
    node_joining_length : float
        The length over which to join skeletal intersections to be counted as one crossing.
    node_extend_dist : float
        The distance under which to join odd-branched node regions.
    branch_pairing_length : float
        The length from the crossing point to pair and trace, obtaining FWHM's.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    """

    def __init__(
        self,
        filename: str,
        image: npt.NDArray,
        mask: npt.NDArray,
        smoothed_mask: npt.NDArray,
        skeleton: npt.NDArray,
        pixel_to_nm_scaling: np.float64,
        n_grain: int,
        node_joining_length: float,
        node_extend_dist: float,
        branch_pairing_length: float,
        pair_odd_branches: bool,
    ) -> None:
        """
        Initialise the nodeStats class.

        Parameters
        ----------
        filename : str
            The name of the file being processed. For logging purposes.
        image : npt.NDArray
            The array of pixels.
        mask : npt.NDArray
            The binary segmentation mask.
        smoothed_mask : npt.NDArray
            A smoothed version of the binary segmentation mask.
        skeleton : npt.NDArray
            A binary single-pixel wide mask of objects in the 'image'.
        pixel_to_nm_scaling : np.float64
            The pixel to nm scaling factor.
        n_grain : int
            The grain number.
        node_joining_length : float
            The length over which to join skeletal intersections to be counted as one crossing.
        node_extend_dist : float
            The distance under which to join odd-branched node regions. It is stored divided by
            ``pixel_to_nm_scaling``.
        branch_pairing_length : float
            The length from the crossing point to pair and trace, obtaining FWHM's.
        pair_odd_branches : bool
            Whether to try and pair odd-branched nodes.
        """
        self.filename = filename
        self.image = image
        self.mask = mask
        self.smoothed_mask = smoothed_mask  # only used to average traces
        self.skeleton = skeleton
        self.pixel_to_nm_scaling = pixel_to_nm_scaling
        self.n_grain = n_grain
        self.node_joining_length = node_joining_length
        # Unlike the other length settings, this one is rescaled by the pixel size on construction.
        self.node_extend_dist = node_extend_dist / self.pixel_to_nm_scaling
        self.branch_pairing_length = branch_pairing_length
        self.pair_odd_branches = pair_odd_branches

        # Working arrays, all created with the shape and dtype of the skeleton.
        self.conv_skelly = np.zeros_like(self.skeleton)
        self.connected_nodes = np.zeros_like(self.skeleton)
        self.all_connected_nodes = np.zeros_like(self.skeleton)
        # Assigned in get_node_stats() when the grain contains crossings.
        self.whole_skel_graph: nx.classes.graph.Graph | None = None
        self.node_centre_mask = np.zeros_like(self.skeleton)

        self.metrics = {
            "num_crossings": np.int64(0),
            "avg_crossing_confidence": None,
            "min_crossing_confidence": None,
        }

        # Outputs returned by get_node_stats().
        self.node_dicts: dict[str, NodeDict] = {}
        self.image_dict: ImageDict = {
            "nodes": {},
            "grain": {
                "grain_image": self.image,
                "grain_mask": self.mask,
                "grain_skeleton": self.skeleton,
            },
        }

        self.full_dict = {}
        self.mol_coords = {}
        self.visuals = {}
        self.all_visuals_img = None
def get_node_stats(self) -> tuple[dict, dict]:
    """
    Run the workflow to obtain the node statistics.

    .. code-block:: RST

        node_dicts key structure:  'node_<number>'
                                      |-> 'error'
                                      |-> 'pixel_to_nm_scaling'
                                      |-> 'branch_stats'
                                      |     └-> <branch_number>
                                      |           |-> 'ordered_coords'
                                      |           |-> 'heights'
                                      |           |-> 'distances'
                                      |           |-> 'fwhm'
                                      |           └-> 'angles'
                                      |-> 'unmatched_branch_stats'
                                      |-> 'node_coords'
                                      └-> 'confidence'

        image_dict key structure:  'nodes'
                                      └-> 'node_<number>'
                                            |-> 'node_area_skeleton'
                                            |-> 'node_branch_mask'
                                            └-> 'node_avg_mask'
                                   'grain'
                                      |-> 'grain_image'
                                      |-> 'grain_mask'
                                      └-> 'grain_skeleton'

    Returns
    -------
    tuple[dict, dict]
        Dictionaries of the node_information and images.
    """
    LOGGER.debug(f"Node Stats - Processing Grain: {self.n_grain}")
    # Convolve to label crossing (3) and end (2) points on the skeleton.
    self.conv_skelly = convolve_skeleton(self.skeleton)

    # Nothing to analyse when the convolved skeleton holds no crossing pixels.
    if not np.any(self.conv_skelly == 3):
        LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.")
        return self.node_dicts, self.image_dict

    LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.")

    # Re-derive the binary skeleton from the convolved one and keep the image dictionary in sync.
    self.skeleton = np.where(self.conv_skelly != 0, 1, 0)
    self.image_dict["grain"]["grain_skeleton"] = self.skeleton
    self.whole_skel_graph = self.skeleton_image_to_graph(self.skeleton)

    # Merge crossings that sit close together, then join the odd-branched ones.
    LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.")
    close_joined = self.connect_close_nodes(self.conv_skelly, node_width=self.node_joining_length)
    self.connected_nodes = close_joined
    self.connected_nodes = self.connect_extended_nodes_nearest(
        self.connected_nodes, node_extend_dist=self.node_extend_dist
    )

    # Reduce every node region to a single centre pixel before the crossing analysis.
    self.node_centre_mask = self.highlight_node_centres(self.connected_nodes)

    LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.")
    self.analyse_nodes(max_branch_length=self.branch_pairing_length)
    self.compile_metrics()

    return self.node_dicts, self.image_dict
# self.all_visuals_img = dnaTrace.concat_images_in_dict(self.image.shape, self.visuals)
@staticmethod
def skeleton_image_to_graph(skeleton: npt.NDArray) -> nx.classes.graph.Graph:
    """
    Convert a skeletonised mask into a Graph representation.

    Every non-zero pixel is linked to each of its non-zero 8-connected neighbours. The pixel coordinates are used
    as the node labels, so the graph conserves the geometry of the skeleton.

    Parameters
    ----------
    skeleton : npt.NDArray
        A binary single-pixel wide mask, or result from conv_skelly().

    Returns
    -------
    nx.classes.graph.Graph
        A networkX graph connecting the pixels in the skeleton to their neighbours.
    """
    pixel_coords = np.argwhere(skeleton)
    # Horizontal and vertical neighbours first, diagonals last.
    offsets = np.array([[0, 1], [0, -1], [1, 0], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]])

    graph = nx.Graph()
    for pixel in pixel_coords:
        here = pixel[0], pixel[1]
        for offset in offsets:
            neighbour = pixel + offset
            # Ignore neighbours that fall outside of the image.
            if np.any(neighbour < 0) or np.any(neighbour >= skeleton.shape):
                continue
            if not skeleton[neighbour[0], neighbour[1]] > 0:
                continue
            there = neighbour[0], neighbour[1]
            # Edges between two node pixels (value 3) cost nothing; every other edge has unit weight.
            inside_node = skeleton[here] == 3 and skeleton[there] == 3
            graph.add_edge(here, there, weight=0 if inside_node else 1)

    graph.graph["physicalPos"] = pixel_coords
    return graph
[docs] @staticmethod def graph_to_skeleton_image(g: nx.Graph, im_shape: tuple[int]) -> npt.NDArray: """ Convert the skeleton graph back to a binary image. Parameters ---------- g : nx.Graph Graph with coordinates as node labels. im_shape : tuple[int] The shape of the image to dump. Returns ------- npt.NDArray Skeleton binary image from the graph representation. """ im = np.zeros(im_shape) for node in g: im[node] = 1 return im
def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> npt.NDArray:
    """
    Wrangle distant connected nodes back towards the main cluster.

    Works by filling and reskeletonising solely the node areas: a window around every node region is replaced by
    the grain mask, the largest object is kept, and the result is skeletonised and pruned again.

    Parameters
    ----------
    connect_node_mask : npt.NDArray
        The connected node mask - a skeleton where node regions = 3, endpoints = 2, and skeleton = 1.
    image : npt.NDArray
        The intensity image.

    Returns
    -------
    npt.NDArray
        The wrangled connected_node_mask.
    """
    new_skeleton = np.where(connect_node_mask != 0, 1, 0)
    labeled_nodes = label(np.where(connect_node_mask == 3, 1, 0))

    # Margin added around each node window; loop invariant, at least one pixel.
    # NOTE(review): the constant 10 is presumably a length in nm - confirm.
    overflow = int(10 / self.pixel_to_nm_scaling) or 1

    for node_num in range(1, labeled_nodes.max() + 1):
        coords = np.argwhere(labeled_nodes == node_num)
        node_centre = coords.mean(axis=0).astype(np.int32)
        node_wid = coords[:, 0].max() - coords[:, 0].min() + 2  # +2 so always 2 by default
        node_len = coords[:, 1].max() - coords[:, 1].min() + 2  # +2 so always 2 by default

        # Clamp the window start at 0: a negative start would be read as an offset from the end of the axis and
        # select an empty or unrelated region for nodes close to the image border. A stop beyond the axis length
        # is already truncated by numpy.
        rows = slice(
            max(node_centre[0] - node_wid // 2 - overflow, 0),
            node_centre[0] + node_wid // 2 + overflow,
        )
        cols = slice(
            max(node_centre[1] - node_len // 2 - overflow, 0),
            node_centre[1] + node_len // 2 + overflow,
        )
        # grain mask fill
        new_skeleton[rows, cols] = self.mask[rows, cols]

    # remove any artifacts of the grain caught in the overflow areas
    new_skeleton = self.keep_biggest_object(new_skeleton)
    # Re-skeletonise
    new_skeleton = getSkeleton(image, new_skeleton, method="topostats", height_bias=0.6).get_skeleton()
    new_skeleton = prune_skeleton(
        image, new_skeleton, self.pixel_to_nm_scaling, **{"method": "topostats", "max_length": -1}
    )
    # cleanup around nibs
    new_skeleton = getSkeleton(image, new_skeleton, method="zhang").get_skeleton()
    # might also need to remove segments that have squares connected

    return convolve_skeleton(new_skeleton)
@staticmethod
def keep_biggest_object(mask: npt.NDArray) -> npt.NDArray:
    """
    Retain the largest object in a binary mask.

    Parameters
    ----------
    mask : npt.NDArray
        Binary mask.

    Returns
    -------
    npt.NDArray
        A binary mask with only one object. The input mask is returned unchanged when it contains no object.
    """
    labelled_mask = label(mask)
    # Count the pixels of each *labelled* object. Counting the values of the raw mask instead would always
    # select label 1, regardless of which object is the largest.
    idxs, counts = np.unique(labelled_mask, return_counts=True)
    is_object = idxs != 0  # label 0 is the background
    if not is_object.any():
        LOGGER.debug("mask is empty.")
        return mask
    max_idx = idxs[is_object][np.argmax(counts[is_object])]
    return np.where(labelled_mask == max_idx, 1, 0)
def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray:
    """
    Connect nodes within the 'node_width' boundary distance.

    Skeleton segments left over once the crossing and end points are removed, and whose pixel count is below
    ``node_width / pixel_to_nm_scaling``, are relabelled as node pixels. This labels the nodes they connect as
    part of the same node.

    Parameters
    ----------
    conv_skelly : npt.NDArray
        A labeled skeleton image with skeleton = 1, endpoints = 2, crossing points =3.
    node_width : float
        The width of the dna in the grain, used to connect close nodes.

    Returns
    -------
    np.ndarray
        The skeleton (label=1) with close nodes connected (label=3).
    """
    # Keep only the plain skeleton pixels by dropping the node and termini points.
    branch_only = conv_skelly.copy()
    branch_only[np.isin(branch_only, (2, 3))] = 0
    segment_labels = label(branch_only)

    # Pixel count of every labelled segment (index 0 is the background).
    segment_sizes = np.bincount(segment_labels.ravel())
    size_limit = node_width / self.pixel_to_nm_scaling
    short_segments = [
        segment for segment in range(1, segment_labels.max() + 1) if segment_sizes[segment] < size_limit
    ]

    self.connected_nodes = conv_skelly.copy()
    self.connected_nodes[np.isin(segment_labels, short_segments)] = 3
    return self.connected_nodes
def highlight_node_centres(self, mask: npt.NDArray) -> npt.NDArray:
    """
    Calculate the node centres based on height and re-plot on the mask.

    Every connected region of node pixels is reduced to the single pixel of that region with the largest value
    in ``self.image``.

    Parameters
    ----------
    mask : npt.NDArray
        2D array with background = 0, skeleton = 1, endpoints = 2, node_centres = 3.

    Returns
    -------
    npt.NDArray
        2D array with the highest node coordinate for each node labeled as 3.
    """
    small_node_mask = mask.copy()
    small_node_mask[mask == 3] = 1  # remap nodes to skeleton

    big_node_mask = label(np.where(mask == 3, 1, 0))  # one label per node region
    for node_label in range(1, big_node_mask.max() + 1):
        # Restrict the search to the node itself. Multiplying the image by the node mask would let a background
        # pixel (value 0) win whenever all heights inside the node are zero or negative.
        node_heights = np.where(big_node_mask == node_label, self.image, -np.inf)
        centre = np.unravel_index(node_heights.argmax(), self.image.shape)
        small_node_mask[centre] = 3

    return small_node_mask
def connect_extended_nodes_nearest(
    self, connected_nodes: npt.NDArray, node_extend_dist: float = -1
) -> npt.NDArray[np.int32]:
    """
    Extend the odd branched nodes to other odd branched nodes within the 'extend_dist' threshold.

    Parameters
    ----------
    connected_nodes : npt.NDArray
        A 2D array representing the network with background = 0, skeleton = 1, endpoints = 2,
        node_centres = 3.
    node_extend_dist : int | float, optional
        The distance over which to connect odd-branched nodes, by default -1 for no-limit.

    Returns
    -------
    npt.NDArray[np.int32]
        Connected nodes array with odd-branched nodes connected.
    """
    # Label the node regions and, separately, the branches left once nodes and termini are removed.
    node_labels = label(np.where(connected_nodes == 3, 1, 0))
    branch_seed = np.where(connected_nodes == 1, 1, 0)
    branch_seed[connected_nodes == 1] = node_labels.max() + 1
    branch_labels = label(branch_seed)

    branch_starts_by_node = find_branches_for_nodes(
        network_array_representation=connected_nodes,
        labelled_nodes=node_labels,
        labelled_branches=branch_labels,
    )

    # A lone node has nothing to be connected to, so the input is only modified when several nodes exist.
    if len(branch_starts_by_node) > 1:
        assert self.whole_skel_graph is not None, "Whole skeleton graph is not defined."  # for type safety
        node_distances, branch_idxs, _ = calculate_shortest_branch_distances(
            nodes_with_branch_starting_coords=branch_starts_by_node,
            whole_skeleton_graph=self.whole_skel_graph,
        )

        # Nx2 array of indexes of the best matching nodes, e.g. np.array([[1, 0], [2, 3]]) pairs node 1 with
        # node 0 and node 2 with node 3.
        best_pairs: npt.NDArray[np.int32] = self.best_matches(node_distances, max_weight_matching=False)

        # Join the matched nodes along the shortest route between their branch starts.
        connected_nodes = connect_best_matches(
            network_array_representation=connected_nodes,
            whole_skeleton_graph=self.whole_skel_graph,
            match_indexes=best_pairs,
            shortest_distances_between_nodes=node_distances,
            shortest_distances_branch_indexes=branch_idxs,
            emanating_branch_starts_by_node=branch_starts_by_node,
            extend_distance=node_extend_dist,
        )

    self.connected_nodes = connected_nodes
    return self.connected_nodes
[docs] @staticmethod def find_branch_starts(reduced_node_image: npt.NDArray) -> npt.NDArray: """ Find the coordinates where the branches connect to the node region through binary dilation of the node. Parameters ---------- reduced_node_image : npt.NDArray A 2D numpy array containing a single node region (=3) and its connected branches (=1). Returns ------- npt.NDArray Coordinate array of pixels next to crossing points (=3 in input). """ node = np.where(reduced_node_image == 3, 1, 0) nodeless = np.where(reduced_node_image == 1, 1, 0) thick_node = binary_dilation(node, structure=np.ones((3, 3))) return np.argwhere(thick_node * nodeless == 1)
# pylint: disable=too-many-locals
def analyse_nodes(self, max_branch_length: float = 20) -> None:
    """
    Obtain the main analyses for the nodes of a single molecule along the 'max_branch_length'(nm) from the node.

    For every node centre the branches joined to the node are isolated, paired and measured. The results are
    written to ``self.node_dicts`` and ``self.image_dict["nodes"]`` under the key ``"node_<n>"``. Nodes with two
    or fewer branch starts are skipped.

    Parameters
    ----------
    max_branch_length : float
        The side length of the box around the node to analyse (in nm).
    """
    assert self.node_centre_mask is not None, "Node centre mask is not defined."
    # Nx2 array with the coordinates of the node centres.
    node_centres: npt.NDArray[np.int32] = np.argwhere(self.node_centre_mask == 3)

    # Averaging three parallel traces is only advised when the skeleton, dilated twice, still fits entirely
    # inside the smoothed grain mask.
    dilate = binary_dilation(self.skeleton, iterations=2)
    average_trace_advised = dilate[self.smoothed_mask == 1].sum() == dilate.sum()
    LOGGER.debug(f"[{self.filename}] : Branch height traces will be averaged: {average_trace_advised}")

    # Loop invariants.
    max_length_px = max_branch_length / (self.pixel_to_nm_scaling * 1)
    image_shape = (self.image.shape[0], self.image.shape[1])

    real_node_count = 0
    for node_no, (node_x, node_y) in enumerate(node_centres):
        # Get branches relevant to the node, and the graph of that reduced area.
        reduced_node_area: npt.NDArray[np.int32] = nodeStats.only_centre_branches(
            self.connected_nodes, np.array([node_x, node_y])
        )
        reduced_skel_graph: nx.classes.graph.Graph = nodeStats.skeleton_image_to_graph(reduced_node_area)
        node_coords = np.argwhere(reduced_node_area == 3)

        # Find the starting coordinates of any branches connected to the node
        branch_start_coords = self.find_branch_starts(reduced_node_area)

        # Stop processing if nib (node has 2 branches)
        if branch_start_coords.shape[0] <= 2:
            LOGGER.debug(
                f"node {node_no} has only two branches - skipped & nodes removed.{len(node_coords)}"
                "pixels in nib node."
            )
        else:
            # Start every node from a clean slate, so a ResolutionError records empty results for this node
            # instead of failing on an unbound name or re-using the values of the previous node.
            error = False
            unmatched_branches = {}
            matched_branches = None
            confidence = None
            branch_image: npt.NDArray[np.int32] = np.zeros(image_shape, dtype=np.int32)
            avg_image: npt.NDArray[np.int32] = np.zeros(image_shape, dtype=np.int32)
            pairing_attempted = len(branch_start_coords) % 2 == 0 or self.pair_odd_branches

            try:
                real_node_count += 1
                LOGGER.debug(f"Node: {real_node_count}")

                # Analyse the node branches
                (
                    pairs,
                    matched_branches,
                    ordered_branches,
                    masked_image,
                    branch_under_over_order,
                    confidence,
                    singlet_branch_vectors,
                ) = nodeStats.analyse_node_branches(
                    p_to_nm=self.pixel_to_nm_scaling,
                    reduced_node_area=reduced_node_area,
                    branch_start_coords=branch_start_coords,
                    max_length_px=max_length_px,
                    reduced_skeleton_graph=reduced_skel_graph,
                    image=self.image,
                    average_trace_advised=average_trace_advised,
                    node_coord=(node_x, node_y),
                    pair_odd_branches=self.pair_odd_branches,
                    filename=self.filename,
                    resolution_threshold=np.float64(1000 / 512),
                )

                # Add the analysed branches to the labelled image
                branch_image, avg_image = nodeStats.add_branches_to_labelled_image(
                    branch_under_over_order=branch_under_over_order,
                    matched_branches=matched_branches,
                    masked_image=masked_image,
                    branch_start_coords=branch_start_coords,
                    ordered_branches=ordered_branches,
                    pairs=pairs,
                    average_trace_advised=average_trace_advised,
                    image_shape=image_shape,
                )

                # Calculate crossing angles of unpaired branches and add to stats dict
                singlet_angles: npt.NDArray[np.float64] = nodeStats.calc_angles(
                    np.asarray(singlet_branch_vectors)
                )[0]
                for branch_index, angle in enumerate(singlet_angles):
                    unmatched_branches[branch_index] = {"angles": angle}

                if pairing_attempted:
                    # Get the vector of each branch based on ordered_coords. Ordered_coords is only the first
                    # N nm of the branch so this is just a general vibe on what direction a branch is going.
                    vectors: list[npt.NDArray[np.float64]] = [
                        nodeStats.get_vector(values["ordered_coords"], np.array([node_x, node_y]))
                        for values in matched_branches.values()
                    ]
                    # Calculate angles between the vectors
                    # Eg: length 2 array: [array([ nan, 79.00]), array([79.00, 0.0])]
                    paired_angles: npt.NDArray[np.float64] = nodeStats.calc_angles(np.asarray(vectors))[0]
                    for branch_index, angle in enumerate(paired_angles):
                        matched_branches[branch_index]["angles"] = angle
                else:
                    # Odd-branched node that is not being paired: blank its pixels in the stored skeleton.
                    self.image_dict["grain"]["grain_skeleton"][node_coords[:, 0], node_coords[:, 1]] = 0

            except ResolutionError:
                LOGGER.debug(f"Node stats skipped as resolution too low: {self.pixel_to_nm_scaling}nm per pixel")
                error = True

            self.node_dicts[f"node_{real_node_count}"] = {
                "error": error,
                "pixel_to_nm_scaling": self.pixel_to_nm_scaling,
                "branch_stats": matched_branches,
                "unmatched_branch_stats": unmatched_branches,
                "node_coords": node_coords,
                "confidence": confidence,
            }

            node_images_dict: dict[str, npt.NDArray[np.int32]] = {
                "node_area_skeleton": reduced_node_area,
                "node_branch_mask": branch_image,
                "node_avg_mask": avg_image,
            }
            self.image_dict["nodes"][f"node_{real_node_count}"] = node_images_dict

        self.all_connected_nodes[self.connected_nodes != 0] = self.connected_nodes[self.connected_nodes != 0]
# pylint: disable=too-many-arguments
[docs] @staticmethod def add_branches_to_labelled_image( branch_under_over_order: npt.NDArray[np.int32], matched_branches: dict[int, MatchedBranch], masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]], branch_start_coords: npt.NDArray[np.int32], ordered_branches: list[npt.NDArray[np.int32]], pairs: npt.NDArray[np.int32], average_trace_advised: bool, image_shape: tuple[int, int], ) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]: """ Add branches to a labelled image. Parameters ---------- branch_under_over_order : npt.NDArray[np.int32] The order of the branches. matched_branches : dict[int, dict[str, MatchedBranch]] Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "ordered_coords" : npt.NDArray[np.int32]. - "heights" : npt.NDArray[np.number]. Heights of the branches. - "distances" : - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches. masked_image : dict[int, dict[str, npt.NDArray[np.bool_]]] Dictionary where the key is the index of the pair and the value is a dictionary containing the following keys: - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches. branch_start_coords : npt.NDArray[np.int32] An Nx2 numpy array of the coordinates of the branches connected to the node. ordered_branches : list[npt.NDArray[np.int32]] List of numpy arrays of ordered branch coordinates. pairs : npt.NDArray[np.int32] Nx2 numpy array of pairs of branches that are matched through a node. average_trace_advised : bool Flag to determine whether to use the average trace. image_shape : tuple[int] The shape of the image, to create a mask from. Returns ------- tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]] The branch image and the average image. 
""" branch_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32) avg_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32) for i, branch_index in enumerate(branch_under_over_order): branch_coords = matched_branches[branch_index]["ordered_coords"] # Add the matched branch to the image, starting at index 1 branch_image[branch_coords[:, 0], branch_coords[:, 1]] = i + 1 if average_trace_advised: # For type safety, check if avg_image is None and skip if so. # This is because the type hinting does not allow for None in the array. avg_image[masked_image[branch_index]["avg_mask"] != 0] = i + 1 # Determine branches that were not able to be paired unpaired_branches = np.delete(np.arange(0, branch_start_coords.shape[0]), pairs.flatten()) LOGGER.debug(f"Unpaired branches: {unpaired_branches}") # Ensure that unpaired branches start at index I where I is the number of paired branches. branch_label = branch_image.max() # Add the unpaired branches back to the branch image for i in unpaired_branches: branch_label += 1 branch_image[ordered_branches[i][:, 0], ordered_branches[i][:, 1]] = branch_label return branch_image, avg_image
@staticmethod
def analyse_node_branches(
    p_to_nm: np.float64,
    reduced_node_area: npt.NDArray[np.int32],
    branch_start_coords: npt.NDArray[np.int32],
    max_length_px: np.float64,
    reduced_skeleton_graph: nx.classes.graph.Graph,
    image: npt.NDArray[np.number],
    average_trace_advised: bool,
    node_coord: tuple[np.int32, np.int32],
    pair_odd_branches: bool,
    filename: str,
    resolution_threshold: np.float64,
) -> tuple[
    npt.NDArray[np.int32],
    dict[int, MatchedBranch],
    list[npt.NDArray[np.int32]],
    dict[int, dict[str, npt.NDArray[np.bool_]]],
    npt.NDArray[np.int32],
    np.float64 | None,
    list[npt.NDArray[np.int32]],
]:
    """
    Analyse the branches of a single node.

    Parameters
    ----------
    p_to_nm : np.float64
        The pixel to nm scaling factor.
    reduced_node_area : npt.NDArray[np.int32]
        An NxM numpy array of the node in question and the branches connected to it.
        Node is marked by 3, and branches by 1.
    branch_start_coords : npt.NDArray[np.int32]
        An Nx2 numpy array of the coordinates of the branches connected to the node.
    max_length_px : np.float64
        The maximum length in pixels to traverse along while ordering.
    reduced_skeleton_graph : nx.classes.graph.Graph
        The graph representation of the reduced node area.
    image : npt.NDArray[np.number]
        The full image of the grain.
    average_trace_advised : bool
        Flag to determine whether to use the average trace.
    node_coord : tuple[np.int32, np.int32]
        The node coordinates.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    filename : str
        The filename of the image.
    resolution_threshold : np.float64
        The resolution threshold below which to warn the user that the node is difficult to analyse.

    Returns
    -------
    pairs: npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node. Empty when the node has an odd
        number of branches and ``pair_odd_branches`` is False.
    matched_branches: dict[int, MatchedBranch]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32].
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch.
        - "fwhm" : dict. Full width half maximum result of the branches.
        - "angles" : np.float64. The angle of the branch, added in later steps.
    ordered_branches: list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches.
    branch_under_over_order: npt.NDArray[np.int32]
        The order of the branches based on the FWHM.
    confidence: np.float64 | None
        The confidence of the crossing. ``None`` when fewer than two branches were matched.
    singlet_branch_vectors: list[npt.NDArray[np.int32]]
        The starting vector of every individual branch leaving the node.
    """
    if not p_to_nm <= resolution_threshold:
        LOGGER.debug(f"Resolution {p_to_nm} is below suggested {resolution_threshold}, node difficult to analyse.")

    # Pixel-wise order the branches coming from the node and calculate the starting vector for each branch
    ordered_branches, singlet_branch_vectors = nodeStats.get_ordered_branches_and_vectors(
        reduced_node_area, branch_start_coords, max_length_px
    )

    # Pair the singlet branch vectors based on their suitability using vector orientation.
    if len(branch_start_coords) % 2 == 0 or pair_odd_branches:
        pairs = nodeStats.pair_vectors(np.asarray(singlet_branch_vectors))
    else:
        pairs = np.array([], dtype=np.int32)

    # Match the branches up
    matched_branches, masked_image = nodeStats.join_matching_branches_through_node(
        pairs,
        ordered_branches,
        reduced_skeleton_graph,
        image,
        average_trace_advised,
        node_coord,
        filename,
    )

    # Redo the FWHMs after the processing for more accurate determination of under/overs: every branch is
    # re-measured against the largest value found in the third "half_maxs" entry across the branches.
    if matched_branches:
        shared_hm = max(values["fwhm"]["half_maxs"][2] for values in matched_branches.values())
        for values in matched_branches.values():
            values["fwhm"] = nodeStats.calculate_fwhm(values["heights"], values["distances"], hm=shared_hm)

    # Get the confidence of the crossing
    crossing_fwhms = [values["fwhm"]["fwhm"] for values in matched_branches.values()]
    if len(crossing_fwhms) <= 1:
        confidence = None
    else:
        crossing_fwhm_combinations = list(combinations(crossing_fwhms, 2))
        confidence = np.float64(nodeStats.cross_confidence(crossing_fwhm_combinations))

    # Order the branch indexes based on the FWHM of the branches. The explicit integer dtype keeps the result an
    # index array even when nothing was matched.
    branch_under_over_order = np.array(list(matched_branches), dtype=int)[np.argsort(np.array(crossing_fwhms))]

    return (
        pairs,
        matched_branches,
        ordered_branches,
        masked_image,
        branch_under_over_order,
        confidence,
        singlet_branch_vectors,
    )
@staticmethod
def join_matching_branches_through_node(
    pairs: npt.NDArray[np.int32],
    ordered_branches: list[npt.NDArray[np.int32]],
    reduced_skeleton_graph: nx.classes.graph.Graph,
    image: npt.NDArray[np.number],
    average_trace_advised: bool,
    node_coords: tuple[np.int32, np.int32],
    filename: str,
) -> tuple[dict[int, MatchedBranch], dict[int, dict[str, npt.NDArray[np.bool_]]]]:
    """
    Join branches that are matched through a node.

    Parameters
    ----------
    pairs : npt.NDArray[np.int32]
        Nx2 numpy array of pairs of branches that are matched through a node.
    ordered_branches : list[npt.NDArray[np.int32]]
        List of numpy arrays of ordered branch coordinates.
    reduced_skeleton_graph : nx.classes.graph.Graph
        Graph representation of the skeleton.
    image : npt.NDArray[np.number]
        The full image of the grain.
    average_trace_advised : bool
        Flag to determine whether to use the average trace. Once the averaged trace fails for one pair, the
        remaining pairs use the single trace as well.
    node_coords : tuple[np.int32, np.int32]
        The node coordinates.
    filename : str
        The filename of the image.

    Returns
    -------
    matched_branches: dict[int, MatchedBranch]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "ordered_coords" : npt.NDArray[np.int32].
        - "heights" : npt.NDArray[np.number]. Heights of the branches.
        - "distances" : npt.NDArray[np.number]. Distances along the branch.
        - "fwhm" : dict. Full width half maximum result of the branches.
        - "angles" : None at this stage.
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]]
        Dictionary where the key is the index of the pair and the value is a dictionary containing the following
        keys:
        - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches. Only present for pairs whose
          averaged trace succeeded.
    """
    matched_branches: dict[int, MatchedBranch] = {}
    # Masked image is a dictionary of pairs of branches
    masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = {}

    for i, (branch_1, branch_2) in enumerate(pairs):
        masked_image[i] = {}

        # find close ends by rearranging branch coords
        branch_1_coords, branch_2_coords = nodeStats.order_branches(
            ordered_branches[branch_1], ordered_branches[branch_2]
        )
        # Get graphical shortest path between branch ends on the skeleton
        crossing = nx.shortest_path(
            reduced_skeleton_graph,
            source=tuple(branch_1_coords[-1]),
            target=tuple(branch_2_coords[0]),
            weight="weight",
        )
        crossing = np.asarray(crossing[1:-1])  # remove start and end points & turn into array
        # Branch coords and crossing
        if crossing.shape == (0,):
            branch_coords = np.vstack([branch_1_coords, branch_2_coords])
        else:
            branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords])

        # make an image of the single joined branch and order its coordinates image-wide
        single_branch_img: npt.NDArray[np.bool_] = np.zeros_like(image).astype(bool)
        single_branch_img[branch_coords[:, 0], branch_coords[:, 1]] = True
        single_branch_coords = order_branch(single_branch_img.astype(bool), [0, 0])

        # get heights and trace distance of branch. The decision is an explicit test rather than an ``assert``,
        # so it behaves the same when Python runs with assertions disabled (-O).
        averaged = False
        if average_trace_advised:
            try:
                distances, heights, mask, _ = nodeStats.average_height_trace(
                    image, single_branch_img, single_branch_coords, [node_coords[0], node_coords[1]]
                )
                masked_image[i]["avg_mask"] = mask
                averaged = True
            except IndexError as e:  # wiggy branches
                LOGGER.debug(f"[{filename}] : avg trace failed with {e}, single trace only.")
                average_trace_advised = False
        else:
            LOGGER.debug(f"[{filename}] : avg trace failed with avg trace not advised, single trace only.")

        if not averaged:
            distances = nodeStats.coord_dist_rad(single_branch_coords, np.array([node_coords[0], node_coords[1]]))
            # Shift the distances so that the branch coordinate closest to the node sits at zero.
            zero_dist = distances[
                np.argmin(
                    np.sqrt(
                        (single_branch_coords[:, 0] - node_coords[0]) ** 2
                        + (single_branch_coords[:, 1] - node_coords[1]) ** 2
                    )
                )
            ]
            heights = image[single_branch_coords[:, 0], single_branch_coords[:, 1]]
            distances = distances - zero_dist
            # needs to be paired with coord_dist_rad
            distances, heights = nodeStats.average_uniques(distances, heights)

        matched_branches[i] = MatchedBranch(
            ordered_coords=single_branch_coords,
            heights=heights,
            distances=distances,
            fwhm=nodeStats.calculate_fwhm(heights, distances),  # identify over/under
            angles=None,
        )

    return matched_branches, masked_image
@staticmethod
def get_ordered_branches_and_vectors(
    reduced_node_area: npt.NDArray[np.int32],
    branch_start_coords: npt.NDArray[np.int32],
    max_length_px: np.float64,
) -> tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]:
    """
    Order each branch leaving a node and compute its direction vector.

    Every branch is traversed from its start coordinate so its pixels form an ordered path, and a vector
    describing the branch's overall direction is derived from that path for later alignment matching.

    Parameters
    ----------
    reduced_node_area : npt.NDArray[np.int32]
        An NxM numpy array of the node in question and the branches connected to it.
        Node is marked by 3, and branches by 1.
    branch_start_coords : npt.NDArray[np.int32]
        An Px2 numpy array of coordinates representing the start of branches where P is the number of branches.
    max_length_px : np.float64
        The maximum length in pixels to traverse along while ordering.

    Returns
    -------
    tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]
        A tuple containing a list of ordered branches and a list of vectors.
    """
    # keep only the branch pixels (value 1); the node pixels are dropped
    branch_only = np.where(reduced_node_area == 1, 1, 0)
    branch_vector_pairs = []
    for start in branch_start_coords:
        # each traversal works on its own copy of the branch image
        branch = order_branch_from_start(branch_only.copy(), start, max_length=max_length_px)
        branch_vector_pairs.append((branch, nodeStats.get_vector(branch, start)))
    ordered_branches = [branch for branch, _ in branch_vector_pairs]
    vectors = [vector for _, vector in branch_vector_pairs]
    return ordered_branches, vectors
@staticmethod
def cross_confidence(pair_combinations: list) -> float:
    """
    Obtain the average confidence of the combinations using a reciprocal function.

    Parameters
    ----------
    pair_combinations : list
        List of length 2 combinations of FWHM values.

    Returns
    -------
    float
        The average crossing confidence, i.e. the mean of ``nodeStats.recip`` over all combinations.
    """
    total = sum(nodeStats.recip(pair) for pair in pair_combinations)
    return total / len(pair_combinations)
@staticmethod
def recip(vals: list) -> float:
    """
    Compute 1 - (min / max) of the two values provided.

    Parameters
    ----------
    vals : list
        List of 2 values.

    Returns
    -------
    float
        Result of applying the 1 - (min / max) function to the two values, or 0 when the smallest value
        is 0 or the division fails with a ``ZeroDivisionError``.
    """
    smallest = min(vals)
    if smallest == 0:  # means fwhm variation hasn't worked
        return 0
    try:
        return 1 - smallest / max(vals)
    except ZeroDivisionError:
        return 0
@staticmethod
def get_vector(coords: npt.NDArray, origin: npt.NDArray) -> npt.NDArray:
    """
    Calculate the normalised vector from an origin to the mean coordinate of a branch.

    Parameters
    ----------
    coords : npt.NDArray
        Nx2 array of coordinates (the mean is taken over the first axis).
    origin : npt.NDArray
        Length-2 coordinate the vector starts from.

    Returns
    -------
    npt.NDArray
        Unit vector from origin to the mean coordinate, or the unnormalised (zero) vector when the mean
        coordinate coincides with the origin.
    """
    direction = coords.mean(axis=0) - origin
    length = np.sqrt(direction @ direction)
    if length == 0:
        return direction  # nothing to normalise
    return direction / length
@staticmethod
def calc_angles(vectors: npt.NDArray) -> npt.NDArray[np.float64]:
    """
    Calculate the angles between vectors in an array.

    Uses the formula:

    .. code-block:: RST

        cos(theta) = a•b / (|a||b|)

    Parameters
    ----------
    vectors : npt.NDArray
        NxD array holding one vector per row.

    Returns
    -------
    npt.NDArray
        NxN array of the angles, in degrees, between every pair of vectors. The diagonal is 0.
    """
    dot = vectors @ vectors.T
    norm = np.diag(dot) ** 0.5
    cos_angles = dot / (norm.reshape(-1, 1) @ norm.reshape(1, -1))
    np.fill_diagonal(cos_angles, 1)  # ensures vector_x • vector_x angles are 0
    # Floating point rounding can leave (anti)parallel vectors with |cos| marginally above 1, for which
    # arccos returns NaN; clipping keeps those at exactly 0 / 180 degrees (NaNs from zero vectors remain).
    cos_angles = np.clip(cos_angles, -1.0, 1.0)
    return abs(np.arccos(cos_angles) / np.pi * 180)  # angles in degrees
@staticmethod
def pair_vectors(vectors: npt.NDArray) -> npt.NDArray[np.int32]:
    """
    Take a list of vectors and pair them based on the angle between them.

    Parameters
    ----------
    vectors : npt.NDArray
        Array of vectors to be paired, one per row.

    Returns
    -------
    npt.NDArray
        An array of the matching pair indices.
    """
    # pairwise angles become the weights handed to the matching step
    pairwise_angles = nodeStats.calc_angles(vectors)
    return nodeStats.best_matches(pairwise_angles)
@staticmethod
def best_matches(arr: npt.NDArray, max_weight_matching: bool = True) -> npt.NDArray:
    """
    Turn a matrix into a graph and calculate the best matching index pairs.

    Parameters
    ----------
    arr : npt.NDArray
        Transpose symmetric MxM array where the value of index i, j represents a weight between i and j.
        The array is not modified.
    max_weight_matching : bool
        Whether to obtain best matching pairs via maximum weight, or minimum weight matching.

    Returns
    -------
    npt.NDArray
        Array of pairs of indexes.
    """
    if max_weight_matching:
        graph = nodeStats.create_weighted_graph(arr)
        matching = nx.max_weight_matching(graph, maxcardinality=True)
    else:
        # Work on a copy so that raising the diagonal does not overwrite the caller's array.
        weights = np.array(arr, copy=True)
        np.fill_diagonal(weights, weights.max() + 1)
        graph = nodeStats.create_weighted_graph(weights)
        matching = nx.min_weight_matching(graph)
    return np.array(list(matching))
@staticmethod
def create_weighted_graph(matrix: npt.NDArray) -> nx.Graph:
    """
    Create a weighted graph connecting i <-> j from a square matrix of weights matrix[i, j].

    Only the upper triangle (i < j) of the matrix is read, so the diagonal and lower triangle are ignored.

    Parameters
    ----------
    matrix : npt.NDArray
        Square array of weights between rows and columns.

    Returns
    -------
    nx.Graph
        Undirected graph with one edge per index pair i < j whose weight is matrix[i, j].
    """
    graph = nx.Graph()
    for row, col in combinations(range(len(matrix)), 2):
        graph.add_edge(row, col, weight=matrix[row, col])
    return graph
@staticmethod
def pair_angles(angles: npt.NDArray) -> list:
    """
    Greedily pair indices by largest angle, removing each chosen pair before selecting the next.

    Parameters
    ----------
    angles : npt.NDArray
        Square array (i,j) of angles between i and j.

    Returns
    -------
    list
        The paired indexes, returned as a numpy array with one (i, j) pair per row.
    """
    remaining = angles.copy()
    pairs = []
    for _ in range(angles.shape[0] // 2):
        row, col = np.unravel_index(np.argmax(remaining), angles.shape)
        pairs.append((row, col))
        chosen = [row, col]
        # blank the rows and columns of both members so neither can be picked again
        remaining[chosen, :] = 0
        remaining[:, chosen] = 0
    return np.asarray(pairs)
@staticmethod
def gaussian(x: npt.NDArray, h: float, mean: float, sigma: float):
    """
    Evaluate a gaussian at the given x values.

    Parameters
    ----------
    x : npt.NDArray
        X values to be passed into the gaussian.
    h : float
        The peak height of the gaussian.
    mean : float
        The centre of the gaussian.
    sigma : float
        The standard deviation of the gaussian.

    Returns
    -------
    npt.NDArray
        The y-values of the gaussian evaluated at the x values.
    """
    squared_offset = (x - mean) ** 2
    exponent = -squared_offset / (2 * sigma**2)
    return h * np.exp(exponent)
@staticmethod
def interpolate_between_yvalue(x: npt.NDArray, y: npt.NDArray, yvalue: float) -> float:
    """
    Calculate the x value between the two points either side of yvalue in y.

    The first pair of neighbouring samples whose y values bracket ``yvalue`` is used.

    Parameters
    ----------
    x : npt.NDArray
        An array of length y.
    y : npt.NDArray
        An array of length x.
    yvalue : float
        A value within the bounds of the y array.

    Returns
    -------
    float
        The linearly interpolated x value between the arrays, or 0 if no neighbouring samples bracket
        ``yvalue``.
    """
    for left in range(len(y) - 1):
        right = left + 1
        rising = y[left] <= yvalue <= y[right]
        falling = y[right] <= yvalue <= y[left]
        if rising or falling:  # the segment crosses the requested y value
            return nodeStats.lin_interp([x[left], y[left]], [x[right], y[right]], yvalue=yvalue)
    return 0
@staticmethod
def calculate_fwhm(
    heights: npt.NDArray, distances: npt.NDArray, hm: float | None = None
) -> dict[str, np.float64 | list[np.float64 | float | None]]:
    """
    Calculate the FWHM value.

    First identifies the HM, then finds the neighbouring values on each side of the peak and uses linear
    interpolation to obtain the distances at which the trace crosses the HM. The FWHM is the absolute
    difference between those two distances.

    Parameters
    ----------
    heights : npt.NDArray
        Array of heights.
    distances : npt.NDArray
        Array of distances.
    hm : Union[None, float], optional
        The halfmax value to match (if wanting the same HM between curves), by default None.

    Returns
    -------
    dict
        Dictionary with the keys:
        - "fwhm" : the FWHM value.
        - "half_maxs" : [distance at hm for 1st half of trace, distance at hm for 2nd half of trace, HM value].
        - "peaks" : [index of the highest point, distance at highest point, height at highest point].
    """
    # in case zone approaches another node, only the central 60% of the trace is searched for the max
    centre_fraction = int(len(heights) * 0.2)
    if centre_fraction == 0:
        # fewer than 5 samples: search the whole trace
        high_idx = np.argmax(heights)
    else:
        high_idx = np.argmax(heights[centre_fraction:-centre_fraction]) + centre_fraction

    # get array halves to find first points that cross hm
    # the first half is reversed so that both halves run outwards from the peak
    arr1 = heights[:high_idx][::-1]
    dist1 = distances[:high_idx][::-1]
    arr2 = heights[high_idx:]
    dist2 = distances[high_idx:]

    if hm is None:
        # Get half max: midpoint between the minimum and maximum height
        hm = (heights.max() - heights.min()) / 2 + heights.min()
        # half max value -> try to make it the same as other crossing branch?
        # raise hm to the lowest point of the peak if one side never drops to it
        # NOTE(review): when high_idx is 0 (only possible for short traces) arr1 is empty and np.min
        # raises ValueError - confirm callers never pass such traces.
        if np.min(arr1) > hm:
            arr1_local_min = argrelextrema(arr1, np.less)[-1]  # indices of the local minima of arr1
            try:
                hm = arr1[arr1_local_min][0]  # first local minimum walking away from the peak
            except IndexError:  # index error when no local minima
                hm = np.min(arr1)
        elif np.min(arr2) > hm:
            arr2_local_min = argrelextrema(arr2, np.less)[0]  # indices of the local minima of arr2
            try:
                hm = arr2[arr2_local_min][0]  # first local minimum walking away from the peak
            except IndexError:  # index error when no local minima
                hm = np.min(arr2)

    # distance at which each half crosses hm
    arr1_hm = nodeStats.interpolate_between_yvalue(x=dist1, y=arr1, yvalue=hm)
    arr2_hm = nodeStats.interpolate_between_yvalue(x=dist2, y=arr2, yvalue=hm)

    fwhm = np.float64(abs(arr2_hm - arr1_hm))

    return {
        "fwhm": fwhm,
        "half_maxs": [arr1_hm, arr2_hm, hm],
        "peaks": [high_idx, distances[high_idx], heights[high_idx]],
    }
@staticmethod
def lin_interp(point_1: list, point_2: list, xvalue: float | None = None, yvalue: float | None = None) -> float:
    """
    Linearly interpolate between 2 points by finding the line equation and substituting.

    Parameters
    ----------
    point_1 : list
        List of an x and y coordinate.
    point_2 : list
        List of an x and y coordinate.
    xvalue : Union[float, None], optional
        Value at which to interpolate to get a y coordinate, by default None. Takes precedence over
        ``yvalue`` when both are given.
    yvalue : Union[float, None], optional
        Value at which to interpolate to get an x coordinate, by default None.

    Returns
    -------
    float
        Value of x or y linear interpolation.

    Raises
    ------
    ValueError
        If neither ``xvalue`` nor ``yvalue`` is provided.
    """
    x1, y1 = point_1[0], point_1[1]
    x2, y2 = point_2[0], point_2[1]
    gradient = (y1 - y2) / (x1 - x2)
    intercept = y1 - (gradient * x1)
    if xvalue is not None:
        return gradient * xvalue + intercept  # interp_y
    if yvalue is None:
        raise ValueError
    return (yvalue - intercept) / gradient  # interp_x
@staticmethod
def order_branches(branch1: npt.NDArray, branch2: npt.NDArray) -> tuple:
    """
    Orient two ordered coordinate arrays so that the first ends where the second begins.

    The pair of endpoints (one from each branch) with the smallest Manhattan distance is found and the
    branches are reversed as needed so that this pair is the end of ``branch1`` and the start of ``branch2``.

    Parameters
    ----------
    branch1 : npt.NDArray
        An Nx2 array describing coordinates.
    branch2 : npt.NDArray
        An Nx2 array describing coordinates.

    Returns
    -------
    tuple
        A tuple with each coordinate array ordered to follow on from one another.
    """
    start1, end1 = np.asarray(branch1[0]), np.asarray(branch1[-1])
    start2, end2 = np.asarray(branch2[0]), np.asarray(branch2[-1])

    # Manhattan distances between each combination of endpoints
    start_to_start = abs(start1 - start2).sum()
    end_to_end = abs(end1 - end2).sum()
    end_to_start = abs(end1 - start2).sum()
    start_to_end = abs(start1 - end2).sum()

    if min(start_to_start, end_to_end) < min(end_to_start, start_to_end):
        # like-for-like endpoints are closest, so exactly one branch must be flipped
        if start_to_start <= end_to_end:
            return branch1[::-1], branch2
        return branch1, branch2[::-1]
    # opposite endpoints are closest: flip neither, or both
    if end_to_start <= start_to_end:
        return branch1, branch2
    return branch1[::-1], branch2[::-1]
@staticmethod
def binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray:
    """
    Create a binary path following the straight line between 2 points.

    The line is stepped one pixel at a time along whichever axis changes fastest, so no pixels are
    skipped. Identical start and end points give a single coordinate, and lines with no change along the
    first axis are handled without dividing by zero.

    Parameters
    ----------
    start : npt.NDArray
        A coordinate.
    end : npt.NDArray
        Another coordinate.

    Returns
    -------
    npt.NDArray
        An Nx2 coordinate array that the line passes through, running from ``start`` to ``end``.
    """
    start = np.asarray(start)
    end = np.asarray(end)
    delta = end - start

    if delta[0] == 0 and delta[1] == 0:
        # degenerate line: the slope is undefined, the path is the single point
        return start.reshape(1, 2).astype(int)

    if delta[0] == 0:
        # no change along the first axis: step along the second axis with a flat slope
        steep = True
        slope = 0.0
    else:
        slope = delta[1] / delta[0]
        steep = abs(slope) > 1
        if steep:
            slope = 1 / slope

    if steep:  # swap the axes so that the stepping axis never skips pixels
        start, end = start[::-1], end[::-1]
    reverse = start[0] > end[0]
    if reverse:  # always step in the increasing direction, flip the result afterwards
        start, end = end, start

    steps = np.arange(start[0], end[0] + 1)
    others = np.round(slope * (steps - start[0]) + start[1])
    line = np.column_stack((steps, others)).astype(int)

    if steep:  # undo the axis swap
        line = line[:, ::-1]
    if reverse:  # undo the direction swap
        line = line[::-1]
    return line
@staticmethod
def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray:
    """
    Calculate the straight-line distance from the centre coordinate to each of the ordered coordinates.

    This differs to traversal along the coordinates taken. Distances of the coordinates that come before
    the centre (or, when the centre is not one of the coordinates, before the coordinate nearest to it)
    are made negative. Distances are always measured from the ``centre`` that was passed in.

    Parameters
    ----------
    coords : npt.NDArray
        Nx2 array of branch coordinates.
    centre : npt.NDArray
        A 1x2 array of the centre coordinates to identify a 0 point for the node.
    pixel_to_nm_scaling : float, optional
        The pixel to nanometer scaling factor to provide real units, by default 1.

    Returns
    -------
    npt.NDArray
        A Nx1 array of the distance from the node centre.
    """
    offsets = coords - centre
    rad_dist = np.sqrt(offsets[:, 0] ** 2 + offsets[:, 1] ** 2)
    on_centre = np.all(coords == centre, axis=1)
    if not on_centre.any():
        # centre is not on the branch: use the closest branch coordinate as the crossing point
        centre = coords[np.argmin(rad_dist)]
        on_centre = np.all(coords == centre, axis=1)
    crossing_index = np.flatnonzero(on_centre)[0]
    rad_dist[:crossing_index] *= -1
    return rad_dist * pixel_to_nm_scaling
@staticmethod
def above_below_value_idx(array: npt.NDArray, value: float) -> list | None:
    """
    Identify indices of the array neighbouring the specified value.

    The element closest to ``value`` is located and paired with whichever adjacent element lies on the
    other side of ``value``. Only strictly increasing neighbours are matched.

    Parameters
    ----------
    array : npt.NDArray
        Array of values.
    value : float
        Value to identify indices between.

    Returns
    -------
    list | None
        Sorted list of the lower index and higher index around the value, or ``None`` when the value is
        in the array or is not bracketed by the closest element and one of its neighbours.
    """
    closest = int(abs(array - value).argmin())
    last = len(array) - 1
    if closest < last and array[closest] < value < array[closest + 1]:
        return [closest, closest + 1]
    # ``closest > 0`` stops ``closest - 1`` from wrapping around to the last element of the array
    if closest > 0 and array[closest - 1] < value < array[closest]:
        return [closest - 1, closest]
    return None
@staticmethod
def average_height_trace(  # noqa: C901
    img: npt.NDArray, branch_mask: npt.NDArray, branch_coords: npt.NDArray, centre=(0, 0)
) -> tuple:
    """
    Average the height trace of a branch with those of two side-by-side parallel traces.

    Dilate the original branch to create two additional side-by-side branches in order to get a more
    accurate average of the height traces. This function produces the common distances between these 3
    branches, and their averaged heights.

    Parameters
    ----------
    img : npt.NDArray
        An array of numbers pertaining to an image.
    branch_mask : npt.NDArray
        A binary array of the branch, must share the same dimensions as the image.
    branch_coords : npt.NDArray
        Ordered coordinates of the branch mask.
    centre : tuple
        The (row, column) coordinates to centre the branch around, by default (0, 0).

    Returns
    -------
    tuple
        A 4-tuple of:
        - the distances common to the middle branch and both parallel traces,
        - the heights averaged over the three traces at those distances,
        - a mask of the parallel traces with ``branch_mask`` added on top,
        - ``[[heights of trace 1, middle branch heights, heights of trace 2],
          [distances of trace 1, middle branch distances, distances of trace 2]]``.
    """
    # get heights and dists of the original (middle) branch
    branch_dist = nodeStats.coord_dist_rad(branch_coords, centre)
    branch_heights = img[branch_coords[:, 0], branch_coords[:, 1]]
    branch_dist, branch_heights = nodeStats.average_uniques(
        branch_dist, branch_heights
    )  # needs to be paired with coord_dist_rad
    # NOTE(review): the argmin index is computed over branch_coords while branch_dist has already been
    # passed through average_uniques - confirm the two index spaces are meant to line up.
    dist_zero_point = branch_dist[
        np.argmin(np.sqrt((branch_coords[:, 0] - centre[0]) ** 2 + (branch_coords[:, 1] - centre[1]) ** 2))
    ]
    branch_dist_norm = branch_dist - dist_zero_point

    # want to get a 3 pixel line trace, one on each side of orig
    dilate = binary_dilation(branch_mask, iterations=1)
    dilate = nodeStats.fill_holes(dilate)
    dilate_minus = np.where(dilate != branch_mask, 1, 0)  # first dilation ring without the branch itself
    dilate2 = binary_dilation(dilate, iterations=1)
    dilate2[(dilate == 1) | (branch_mask == 1)] = 0  # keep only the second dilation ring
    labels = label(dilate2)
    # Cleanup stages - re-entering, early terminating, closer traces
    #   if parallel trace out and back in zone, can get > 2 labels
    labels = nodeStats._remove_re_entering_branches(labels, remaining_branches=2)
    #   if parallel trace doesn't exit window, can get 1 label
    #       occurs when skeleton has poor connections (extra branches which cut corners)
    if labels.max() == 1:
        conv = convolve_skeleton(branch_mask)
        endpoints = np.argwhere(conv == 2)
        for endpoint in endpoints:  # may be >1 endpoint
            # remove the ring pixels closest (Manhattan distance) to this endpoint
            para_trace_coords = np.argwhere(labels == 1)
            abs_diff = np.absolute(para_trace_coords - endpoint).sum(axis=1)
            min_idxs = np.where(abs_diff == abs_diff.min())
            trace_coords_remove = para_trace_coords[min_idxs]
            labels[trace_coords_remove[:, 0], trace_coords_remove[:, 1]] = 0
        labels = label(labels)
    #   reduce binary dilation distance
    parallel = np.zeros_like(branch_mask).astype(np.int32)
    for i in range(1, labels.max() + 1):
        single = labels.copy()
        single[single != i] = 0
        single[single == i] = 1
        sing_dil = binary_dilation(single)
        # pixels of the first dilation ring that touch this labelled trace take its label
        parallel[(sing_dil == dilate_minus) & (sing_dil == 1)] = i
    labels = parallel.copy()

    binary = labels.copy()
    binary[binary != 0] = 1
    binary += branch_mask

    # get and order coords, then get heights and distances relative to node centre / highest point
    heights = []
    distances = []
    for i in np.unique(labels)[1:]:
        trace_img = np.where(labels == i, 1, 0)
        trace_img = getSkeleton(img, trace_img, method="zhang").get_skeleton()
        trace = order_branch(trace_img, branch_coords[0])
        height_trace = img[trace[:, 0], trace[:, 1]]
        dist = nodeStats.coord_dist_rad(trace, centre)
        dist, height_trace = nodeStats.average_uniques(dist, height_trace)  # needs to be paired with coord_dist_rad
        heights.append(height_trace)
        distances.append(
            dist - dist_zero_point
        )  # shifted by the same zero point as the middle branch
    # Make like coord system using original branch
    avg1 = []
    avg2 = []
    for mid_dist in branch_dist_norm:
        for i, (distance, height) in enumerate(zip(distances, heights)):
            # check if distance already in traces array
            if (mid_dist == distance).any():
                idx = np.where(mid_dist == distance)
                if i == 0:
                    avg1.append([mid_dist, height[idx][0]])
                else:
                    avg2.append([mid_dist, height[idx][0]])
            # if not, linearly interpolate the mid-branch value
            else:
                # get index after and before the mid branches' x coord
                xidxs = nodeStats.above_below_value_idx(distance, mid_dist)
                if xidxs is None:
                    pass  # if indexes outside of range, pass
                else:
                    point1 = [distance[xidxs[0]], height[xidxs[0]]]
                    point2 = [distance[xidxs[1]], height[xidxs[1]]]
                    y = nodeStats.lin_interp(point1, point2, xvalue=mid_dist)
                    if i == 0:
                        avg1.append([mid_dist, y])
                    else:
                        avg2.append([mid_dist, y])
    avg1 = np.asarray(avg1)
    avg2 = np.asarray(avg2)
    # ensure arrays are same length to average
    temp_x = branch_dist_norm[np.isin(branch_dist_norm, avg1[:, 0])]
    common_dists = avg2[:, 0][np.isin(avg2[:, 0], temp_x)]

    common_avg_branch_heights = branch_heights[np.isin(branch_dist_norm, common_dists)]
    common_avg1_heights = avg1[:, 1][np.isin(avg1[:, 0], common_dists)]
    common_avg2_heights = avg2[:, 1][np.isin(avg2[:, 0], common_dists)]

    average_heights = (common_avg_branch_heights + common_avg1_heights + common_avg2_heights) / 3
    # heights[1] / distances[1] are indexed directly, so fewer than two parallel traces raises IndexError
    return (
        common_dists,
        average_heights,
        binary,
        [[heights[0], branch_heights, heights[1]], [distances[0], branch_dist_norm, distances[1]]],
    )
@staticmethod
def fill_holes(mask: npt.NDArray) -> npt.NDArray:
    """
    Fill all holes within a binary mask.

    The background is labelled with 1-connectivity and everything outside the largest labelled region is
    set to 1.

    Parameters
    ----------
    mask : npt.NDArray
        Binary array of object.

    Returns
    -------
    npt.NDArray
        Binary array of object with any interior holes filled in.
    """
    background = np.where(mask != 0, 0, 1)
    labelled_background = label(background, connectivity=1)
    region_ids, region_sizes = np.unique(labelled_background, return_counts=True)
    # the most frequent label is treated as the exterior and is the only region left unfilled
    exterior_id = region_ids[np.argmax(region_sizes)]
    return np.where(labelled_background != exterior_id, 1, 0)
@staticmethod
def _remove_re_entering_branches(mask: npt.NDArray, remaining_branches: int = 1) -> npt.NDArray:
    """
    Remove the smallest branches, which arise when branches exit and re-enter the viewing area.

    Connected objects are removed smallest-first until only ``remaining_branches`` remain. Objects of equal
    size are removed in label order.

    Parameters
    ----------
    mask : npt.NDArray
        Skeletonised binary mask of an object.
    remaining_branches : int, optional
        Number of objects (branches) to keep, by default 1.

    Returns
    -------
    npt.NDArray
        Copy of the mask in which only the ``remaining_branches`` largest objects are kept.
    """
    rtn_image = mask.copy()
    binary_image = mask.copy()
    binary_image[binary_image != 0] = 1
    labels = label(binary_image)

    n_objects = labels.max()
    if n_objects <= remaining_branches:
        return rtn_image

    # Pixel count of every labelled object (index 0 of bincount is the background).
    sizes = np.bincount(labels.ravel())[1:]
    # Rank label ids by size once, instead of popping from a list whose positions shift after each removal
    # and therefore stop matching the label ids.
    labels_smallest_first = np.argsort(sizes, kind="stable") + 1
    labels_to_remove = labels_smallest_first[: n_objects - remaining_branches]
    rtn_image[np.isin(labels, labels_to_remove)] = 0

    return rtn_image
@staticmethod
def only_centre_branches(node_image: npt.NDArray, node_coordinate: npt.NDArray) -> npt.NDArray[np.int32]:
    """
    Remove all branches not connected to the current node.

    Parameters
    ----------
    node_image : npt.NDArray
        An image of the skeletonised area surrounding the node where
        the background = 0, skeleton = 1, termini = 2, nodes = 3.
    node_coordinate : npt.NDArray
        2x1 coordinate describing the position of a node.

    Returns
    -------
    npt.NDArray[np.int32]
        The initial node image but only with skeletal branches connected to the middle node.
    """
    trimmed = node_image.copy()

    # label the node-only pixels and find the node cluster closest to the requested coordinate
    nodes = node_image.copy()
    nodes[nodes != 3] = 0
    labeled_nodes = label(nodes)
    node_pixels = np.argwhere(nodes == 3)
    nearest = node_pixels[abs(node_pixels - node_coordinate).sum(axis=1).argmin()]
    centre_label = labeled_nodes[nearest[0], nearest[1]]

    # branches (and termini) plus the central node only, so other nodes split the skeleton apart
    branches = np.where((node_image == 1) | (node_image == 2), 1, 0)
    branches[labeled_nodes == centre_label] = 1
    labeled_branches = label(branches)

    # keep only the first connected component that contains node pixels
    for component in range(1, labeled_branches.max() + 1):
        in_component = labeled_branches == component
        if (trimmed[in_component] == 3).any():
            trimmed[~in_component] = 0
            break

    # remove small area around other nodes
    other_node_pixels = np.argwhere((labeled_nodes != 0) & (labeled_nodes != centre_label))
    limits = trimmed.shape
    for coord in other_node_pixels:
        for axis, position in enumerate(coord):
            # shift positions on the image border inwards so the 3x3 window stays inside the image
            if position < 1:
                coord[axis] = 1
            if position + 2 > limits[axis]:
                coord[axis] = limits[axis] - 2
        row, col = coord[0], coord[1]
        trimmed[row - 1 : row + 2, col - 1 : col + 2] = 0

    return trimmed
@staticmethod
def average_uniques(arr1: npt.NDArray, arr2: npt.NDArray) -> tuple:
    """
    Obtain the unique values of the first array and the mean of the second array at each of them.

    Parameters
    ----------
    arr1 : npt.NDArray
        An array whose (sorted) unique values are returned.
    arr2 : npt.NDArray
        An array, paired element-wise with ``arr1``, that is averaged wherever ``arr1`` repeats a value.

    Returns
    -------
    tuple
        The unique values of ``arr1`` and, for each, the mean of the matching ``arr2`` entries.
    """
    _, first_positions = np.unique(arr1, return_index=True)
    unique_values = arr1[first_positions]
    averaged = np.zeros(unique_values.shape, dtype=np.float64)
    for position, value in enumerate(unique_values):
        averaged[position] += arr2[arr1 == value].mean()
    return unique_values, averaged
@staticmethod
def average_crossing_confs(node_dict) -> None | float:
    """
    Return the average crossing confidence of all crossings in the molecule.

    Parameters
    ----------
    node_dict : dict
        A dictionary containing node statistics and information. Each value is read via its
        ``"confidence"`` key; entries whose confidence is ``None`` are ignored.

    Returns
    -------
    Union[None, float]
        The value of average confidence or none if not possible.
    """
    confidences = [node["confidence"] for node in node_dict.values() if node["confidence"] is not None]
    if not confidences:
        return None
    return sum(confidences) / len(confidences)
@staticmethod
def minimum_crossing_confs(node_dict: dict) -> None | float:
    """
    Return the minimum crossing confidence of all crossings in the molecule.

    Parameters
    ----------
    node_dict : dict
        A dictionary containing node statistics and information. Each value is read via its
        ``"confidence"`` key; entries whose confidence is ``None`` are ignored.

    Returns
    -------
    Union[None, float]
        The value of minimum confidence or none if not possible.
    """
    confidences = [node["confidence"] for node in node_dict.values() if node["confidence"] is not None]
    try:
        return min(confidences)
    except ValueError:  # raised by min() when there are no valid confidences
        return None
def compile_metrics(self) -> None:
    """
    Add the number of crossings, average and minimum crossing confidence to the metrics dictionary.

    The number of crossings is the count of pixels equal to 3 in ``self.node_centre_mask``. Both
    confidences are computed from ``self.node_dicts`` and converted with ``np.float64``.
    """
    crossing_count = (self.node_centre_mask == 3).sum()
    self.metrics["num_crossings"] = np.int64(crossing_count)
    average_confidence = nodeStats.average_crossing_confs(self.node_dicts)
    self.metrics["avg_crossing_confidence"] = np.float64(average_confidence)
    minimum_confidence = nodeStats.minimum_crossing_confs(self.node_dicts)
    self.metrics["min_crossing_confidence"] = np.float64(minimum_confidence)
def nodestats_image(
    image: npt.NDArray,
    disordered_tracing_direction_data: dict,
    filename: str,
    pixel_to_nm_scaling: float,
    node_joining_length: float,
    node_extend_dist: float,
    branch_pairing_length: float,
    pair_odd_branches: bool,
    pad_width: int,
) -> tuple:
    """
    Run nodeStats on every grain of an image and compile the results.

    Parameters
    ----------
    image : npt.NDArray
        The array of pixels.
    disordered_tracing_direction_data : dict
        The images and bbox coordinates of the pruned skeletons.
    filename : str
        The name of the file being processed. For logging purposes.
    pixel_to_nm_scaling : float
        The pixel to nm scaling factor.
    node_joining_length : float
        The length over which to join skeletal intersections (nearby odd-branched nodes) to be counted as one
        crossing.
    node_extend_dist : float
        The distance under which to join odd-branched node regions.
    branch_pairing_length : float
        The length from the crossing point to pair and trace, obtaining FWHM's.
    pair_odd_branches : bool
        Whether to try and pair odd-branched nodes.
    pad_width : int
        The number of edge pixels to pad the image by. May be 0.

    Returns
    -------
    tuple[dict, pd.DataFrame, dict, dict]
        The nodestats statistics for each crossing, crossing statistics to be added to the grain statistics,
        an image dictionary of nodestats steps for the entire image, and single grain images.
    """
    n_grains = len(disordered_tracing_direction_data)
    img_base = np.zeros_like(image)
    nodestats_data = {}

    # want to get each cropped image, use some anchor coords to match them onto the image,
    #   and compile all the grain images onto a single image
    all_images = {
        "convolved_skeletons": img_base.copy(),
        "node_centres": img_base.copy(),
        "connected_nodes": img_base.copy(),
    }
    nodestats_branch_images = {}
    grainstats_additions = {}

    LOGGER.info(f"[{filename}] : Calculating NodeStats statistics for {n_grains} grains...")

    for n_grain, disordered_tracing_grain_data in disordered_tracing_direction_data.items():
        nodestats = None  # reset the nodestats variable
        try:
            nodestats = nodeStats(
                image=disordered_tracing_grain_data["original_image"],
                mask=disordered_tracing_grain_data["original_grain"],
                smoothed_mask=disordered_tracing_grain_data["smoothed_grain"],
                skeleton=disordered_tracing_grain_data["pruned_skeleton"],
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                filename=filename,
                n_grain=n_grain,
                node_joining_length=node_joining_length,
                node_extend_dist=node_extend_dist,
                branch_pairing_length=branch_pairing_length,
                pair_odd_branches=pair_odd_branches,
            )
            nodestats_dict, node_image_dict = nodestats.get_node_stats()
            LOGGER.debug(f"[{filename}] : Nodestats processed {n_grain} of {n_grains}")

            # compile images
            nodestats_images = {
                "convolved_skeletons": nodestats.conv_skelly,
                "node_centres": nodestats.node_centre_mask,
                "connected_nodes": nodestats.connected_nodes,
            }
            nodestats_branch_images[n_grain] = node_image_dict

            # compile metrics
            grainstats_additions[n_grain] = {
                "image": filename,
                "grain_number": int(n_grain.split("_")[-1]),
            }
            grainstats_additions[n_grain].update(nodestats.metrics)
            if nodestats_dict:  # if the grain's nodestats dict is not empty
                nodestats_data[n_grain] = nodestats_dict

            # remap the cropped images back onto the original
            bbox = disordered_tracing_grain_data["bbox"]
            for image_name, full_image in all_images.items():
                crop = nodestats_images[image_name]
                # Explicit end indices are used because ``crop[0:-0]`` is an empty slice, which made every
                # grain fail when ``pad_width`` was 0.
                unpadded = crop[
                    pad_width : crop.shape[0] - pad_width,
                    pad_width : crop.shape[1] - pad_width,
                ]
                full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += unpadded

        except Exception as e:  # pylint: disable=broad-exception-caught
            LOGGER.error(
                f"[{filename}] : Nodestats for {n_grain} failed. Consider raising an issue on GitHub. Error: ",
                exc_info=e,
            )
            nodestats_data[n_grain] = {}

    # turn the grainstats additions into a dataframe,
    # might need to do something for when everything is empty
    grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index")

    return nodestats_data, grainstats_additions_df, all_images, nodestats_branch_images