"""Generates disordered traces (pruned skeletons) and metrics."""
from __future__ import annotations
import logging
import warnings
import numpy as np
import numpy.typing as npt
import pandas as pd
import skan
import skimage.measure as skimage_measure
from scipy import ndimage
from skimage import filters
from skimage.morphology import label
from topostats.logs.logs import LOGGER_NAME
from topostats.tracing.pruning import prune_skeleton
from topostats.tracing.skeletonize import getSkeleton
from topostats.utils import convolve_skeleton
LOGGER = logging.getLogger(LOGGER_NAME)
# too-many-positional-arguments
# pylint: disable=R0917
[docs]
class disorderedTrace:  # pylint: disable=too-many-instance-attributes
    """
    Calculate disordered traces for a DNA molecule and calculates statistics from those traces.

    Parameters
    ----------
    image : npt.NDArray
        Cropped image, typically padded beyond the bounding box.
    mask : npt.NDArray
        Labelled mask for the grain, typically padded beyond the bounding box.
    filename : str
        Filename being processed.
    pixel_to_nm_scaling : float
        Pixel to nm scaling.
    min_skeleton_size : int
        Minimum skeleton size below which tracing statistics are not calculated.
    mask_smoothing_params : dict | None
        Keyword arguments forwarded to ``smooth_mask`` ('dilation_iterations', 'gaussian_sigma',
        'holearea_min_max'). ``None`` means "use the ``smooth_mask`` defaults".
    skeletonisation_params : dict | None
        Skeletonisation Parameters. Method of skeletonisation to use 'topostats' is the original TopoStats
        method. Three methods from scikit-image are available 'zhang', 'lee' and 'thin'.
    pruning_params : dict | None
        Dictionary of pruning parameters. Contains 'method', 'max_length', 'height_threshold', 'method_values' and
        'method_outlier'.
    n_grain : int | None
        Grain number being processed (only used in logging).
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        image: npt.NDArray,
        mask: npt.NDArray,
        filename: str,
        pixel_to_nm_scaling: float,
        min_skeleton_size: int = 10,
        mask_smoothing_params: dict | None = None,
        skeletonisation_params: dict | None = None,
        pruning_params: dict | None = None,
        n_grain: int | None = None,
    ):
        """
        Calculate disordered traces for a DNA molecule and calculates statistics from those traces.

        Parameters
        ----------
        image : npt.NDArray
            Cropped image, typically padded beyond the bounding box.
        mask : npt.NDArray
            Labelled mask for the grain, typically padded beyond the bounding box.
        filename : str
            Filename being processed.
        pixel_to_nm_scaling : float
            Pixel to nm scaling.
        min_skeleton_size : int
            Minimum skeleton size below which tracing statistics are not calculated.
        mask_smoothing_params : dict | None
            Keyword arguments forwarded to ``smooth_mask`` ('dilation_iterations', 'gaussian_sigma',
            'holearea_min_max'). ``None`` means "use the ``smooth_mask`` defaults".
        skeletonisation_params : dict | None
            Skeletonisation Parameters. Method of skeletonisation to use 'topostats' is the original TopoStats
            method. Three methods from scikit-image are available 'zhang', 'lee' and 'thin'.
        pruning_params : dict | None
            Dictionary of pruning parameters. Contains 'method', 'max_length', 'height_threshold', 'method_values'
            and 'method_outlier'.
        n_grain : int | None
            Grain number being processed (only used in logging).
        """
        self.image = image
        self.mask = mask
        self.filename = filename
        self.pixel_to_nm_scaling = pixel_to_nm_scaling
        self.min_skeleton_size = min_skeleton_size
        # ``trace_dna`` unpacks this with ``**``, so ``None`` must become an empty mapping.
        self.mask_smoothing_params = mask_smoothing_params if mask_smoothing_params is not None else {}
        self.skeletonisation_params = (
            skeletonisation_params if skeletonisation_params is not None else {"method": "zhang"}
        )
        self.pruning_params = pruning_params if pruning_params is not None else {"method": "topostats"}
        self.n_grain = n_grain
        # Images
        self.smoothed_mask = np.zeros_like(image)
        self.skeleton = np.zeros_like(image)
        self.pruned_skeleton = np.zeros_like(image)
        # Trace
        self.disordered_trace = None
        # suppresses scipy splining warnings
        # NOTE(review): this installs a process-wide "ignore everything" filter, not just for scipy.
        warnings.filterwarnings("ignore")
        LOGGER.debug(f"[{self.filename}] Performing Disordered Tracing")

    def trace_dna(self):
        """
        Perform the DNA skeletonisation and cleaning pipeline.

        Smooths the mask, skeletonises it, prunes the skeleton, removes skeleton clusters touching the array
        border and stores the coordinates of the remaining skeleton pixels in ``self.disordered_trace``. The trace
        is set to ``None`` when nothing is left or when it has fewer than ``min_skeleton_size`` points.
        """
        self.smoothed_mask = self.smooth_mask(self.mask, **self.mask_smoothing_params)
        # Only forward 'height_bias' when supplied so the default skeletonisation parameters
        # ({"method": "zhang"}) do not raise a KeyError.
        skeleton_kwargs = {"method": self.skeletonisation_params["method"]}
        if "height_bias" in self.skeletonisation_params:
            skeleton_kwargs["height_bias"] = self.skeletonisation_params["height_bias"]
        self.skeleton = getSkeleton(self.image, self.smoothed_mask, **skeleton_kwargs).get_skeleton()
        self.pruned_skeleton = prune_skeleton(
            self.image, self.skeleton, self.pixel_to_nm_scaling, **self.pruning_params.copy()
        )
        self.pruned_skeleton = self.remove_touching_edge(self.pruned_skeleton)
        self.disordered_trace = np.argwhere(self.pruned_skeleton == 1)

        # ``np.argwhere`` never returns None; an empty result is what "failed to skeletonise" looks like.
        if len(self.disordered_trace) == 0:
            LOGGER.warning(f"[{self.filename}] : Grain {self.n_grain} failed to Skeletonise.")
            self.disordered_trace = None
        elif len(self.disordered_trace) < self.min_skeleton_size:
            LOGGER.warning(f"[{self.filename}] : Grain {self.n_grain} skeleton < {self.min_skeleton_size}, skipping.")
            self.disordered_trace = None

    def re_add_holes(
        self,
        orig_mask: npt.NDArray,
        smoothed_mask: npt.NDArray,
        holearea_min_max: tuple[float | int | None] = (2, None),
    ) -> npt.NDArray:
        """
        Restore holes in masks that were occluded by dilation.

        As Gaussian dilation smoothing methods can close holes in the original mask, this function obtains those holes
        (based on the general background being the first due to padding) and adds them back into the smoothed mask. When
        paired with ``smooth_mask``, this essentially just smooths the outer edge of the mask.

        Parameters
        ----------
        orig_mask : npt.NDArray
            Original mask.
        smoothed_mask : npt.NDArray
            Original mask but with inner and outer edged smoothed. The smoothing operation may have closed up important
            holes in the mask.
        holearea_min_max : tuple[float | int | None]
            Pair of minimum and maximum hole area (in nanometers squared) to replace from the original mask into the
            smoothed mask. A ``None`` minimum means 0, a ``None`` maximum means unbounded; if both are ``None`` no
            holes are restored. The sequence is never modified.

        Returns
        -------
        npt.NDArray
            Smoothed mask with holes restored.
        """
        # handle none's (without mutating the argument, which may be a tuple or a shared config list)
        if all(bound is None for bound in holearea_min_max):
            return smoothed_mask
        area_min = 0 if holearea_min_max[0] is None else holearea_min_max[0]
        area_max = np.inf if holearea_min_max[1] is None else holearea_min_max[1]

        # obtain px holesizes
        pixel_area = self.pixel_to_nm_scaling**2
        holesize_min_px = area_min / pixel_area
        holesize_max_px = area_max / pixel_area

        # obtain a hole mask: every connected region of non-mask pixels gets its own label
        holes = label(1 - orig_mask)
        region_sizes = np.bincount(holes.ravel())  # region_sizes[k] is the pixel count of label k

        # keep only regions whose size is within range; small holes may be fake and are left out
        restore = (region_sizes >= holesize_min_px) & (region_sizes <= holesize_max_px)
        # label 0 is the grain itself, label 1 is assumed to be the outer background (first seen from top left)
        restore[:2] = False

        # replace correct sized holes
        return np.where(restore[holes], 0, smoothed_mask)

    @staticmethod
    def remove_touching_edge(skeleton: npt.NDArray) -> npt.NDArray:
        """
        Remove any skeleton clusters touching the border (to prevent errors later).

        Connected clusters (8-connectivity) of non-zero pixels are identified and every cluster with at least one
        pixel on the array border is set to zero. Clusters that do not reach the border are kept.

        Parameters
        ----------
        skeleton : npt.NDArray
            A 2D binary array where touching clusters of 1's become 0's if touching the edge of the array. The
            array is modified in place.

        Returns
        -------
        npt.NDArray
            Skeleton without clusters touching the border (the same array object that was passed in).
        """
        # Zeroing by pixel value would wipe every 1 in a binary skeleton, so work per connected cluster instead.
        clusters, _ = ndimage.label(skeleton != 0, structure=np.ones((3, 3), dtype=int))
        border_labels = np.unique(
            np.concatenate([clusters[0, :], clusters[-1, :], clusters[:, 0], clusters[:, -1]])
        )
        border_labels = border_labels[border_labels != 0]
        skeleton[np.isin(clusters, border_labels)] = 0
        return skeleton

    def smooth_mask(
        self,
        grain: npt.NDArray,
        dilation_iterations: int | None = 2,
        gaussian_sigma: float | int | None = 2,
        holearea_min_max: tuple[int | float | None] = (0, None),
    ) -> npt.NDArray:
        """
        Smooth a grain mask based on the lower number of binary pixels added from dilation or gaussian.

        This method ensures gaussian smoothing isn't too aggressive and covers / creates gaps in the mask.

        Parameters
        ----------
        grain : npt.NDArray
            Numpy array of the grain mask.
        dilation_iterations : int | None
            Number of times to dilate the grain to smooth it. Default is 2. ``None`` disables dilation.
        gaussian_sigma : float | int | None
            Gaussian sigma value; the blurred mask is re-binarised at 1.3 times its Otsu threshold. Default is 2.
            ``None`` disables gaussian smoothing.
        holearea_min_max : tuple[float | int | None]
            Pair of minimum and maximum hole area (in nanometers squared) to replace from the original mask into the
            smoothed mask.

        Returns
        -------
        npt.NDArray
            Numpy array of smoothed image. The input is returned unchanged when both smoothing options are ``None``.
        """
        # Option to disable the smoothing (i.e. U-Net masks are already smooth)
        if dilation_iterations is None and gaussian_sigma is None:
            LOGGER.debug(f"[{self.filename}] : no grain smoothing done")
            return grain

        # Compute whichever candidates were requested
        dilation = None
        gauss = None
        if dilation_iterations is not None:
            dilation = ndimage.binary_dilation(grain, iterations=dilation_iterations).astype(np.int32)
        if gaussian_sigma is not None:
            gauss = filters.gaussian(grain, sigma=gaussian_sigma)
            gauss = np.where(gauss > filters.threshold_otsu(gauss) * 1.3, 1, 0)
            gauss = gauss.astype(np.int32)

        # With both candidates, the one changing the mask area least wins (ties go to dilation)
        if dilation is None:
            use_gaussian = True
        elif gauss is None:
            use_gaussian = False
        else:
            use_gaussian = abs(dilation.sum() - grain.sum()) > abs(gauss.sum() - grain.sum())

        if use_gaussian:
            LOGGER.debug(f"[{self.filename}] : smoothing done by gaussian {gaussian_sigma}")
            return self.re_add_holes(grain, gauss, holearea_min_max)
        LOGGER.debug(f"[{self.filename}] : smoothing done by dilation {dilation_iterations}")
        return self.re_add_holes(grain, dilation, holearea_min_max)

    @staticmethod
    def calculate_dna_width(
        smoothed_mask: npt.NDArray, pruned_skeleton: npt.NDArray, pixel_to_nm_scaling: float = 1
    ) -> float:
        """
        Calculate the mean width of the DNA using the trace and mask.

        The width at each skeleton pixel is taken as twice the Euclidean distance from that pixel to the nearest
        background pixel of the mask.

        Parameters
        ----------
        smoothed_mask : npt.NDArray
            Smoothed mask to be measured.
        pruned_skeleton : npt.NDArray
            Pruned skeleton.
        pixel_to_nm_scaling : float
            Scaling of pixels to nanometres.

        Returns
        -------
        float
            Mean width of the grain in the units of ``pixel_to_nm_scaling`` (nanometres, not metres). ``nan`` if no
            skeleton pixel lies inside the mask.
        """
        dist_trans = ndimage.distance_transform_edt(smoothed_mask)
        on_skeleton = np.where(pruned_skeleton == 1, dist_trans, 0)
        half_widths = on_skeleton[on_skeleton != 0]
        if half_widths.size == 0:
            # same value ``mean()`` of an empty array yields, without the RuntimeWarning
            return np.nan
        return half_widths.mean() * 2 * pixel_to_nm_scaling
[docs]
def trace_image_disordered(  # pylint: disable=too-many-arguments,too-many-locals
    image: npt.NDArray,
    grains_mask: npt.NDArray,
    filename: str,
    pixel_to_nm_scaling: float,
    min_skeleton_size: int,
    mask_smoothing_params: dict,
    skeletonisation_params: dict,
    pruning_params: dict,
    pad_width: int = 1,
) -> tuple[dict, pd.DataFrame, dict, pd.DataFrame]:
    """
    Processor function for tracing image.

    Parameters
    ----------
    image : npt.NDArray
        Full image as Numpy Array.
    grains_mask : npt.NDArray
        Full image as Grains that are labelled.
    filename : str
        File being processed.
    pixel_to_nm_scaling : float
        Pixel to nm scaling.
    min_skeleton_size : int
        Minimum size of grain in pixels after skeletonisation.
    mask_smoothing_params : dict
        Dictionary of parameters to smooth the grain mask for better quality skeletonisation results. Contains
        a gaussian 'sigma' and number of dilation iterations.
    skeletonisation_params : dict
        Dictionary of options for skeletonisation, options are 'zhang' (scikit-image) / 'lee' (scikit-image) / 'thin'
        (scikitimage) or 'topostats' (original TopoStats method).
    pruning_params : dict
        Dictionary of options for pruning.
    pad_width : int
        Padding to the cropped image mask.

    Returns
    -------
    tuple[dict, pd.DataFrame, dict, pd.DataFrame]
        Four items, in order: a dictionary keyed ``grain_<index>`` holding the per-grain cropped images plus their
        'bbox' and 'pad_width'; a DataFrame of per-grain additions to the grain statistics; a dictionary of
        full-image arrays ('smoothed_grain', 'skeleton', 'pruned_skeleton', 'branch_indexes', 'branch_types') with
        every grain mapped back onto the image; and a DataFrame of per-branch Skan statistics.

    Raises
    ------
    ValueError
        If ``image`` and ``grains_mask`` have different shapes.
    """
    # Check both arrays are the same shape - should this be a test instead, why should this ever occur?
    if image.shape != grains_mask.shape:
        raise ValueError(f"Image shape ({image.shape}) and Mask shape ({grains_mask.shape}) should match.")

    cropped_images, cropped_masks, bboxs = prep_arrays(image, grains_mask, pad_width)
    n_grains = len(cropped_images)
    img_base = np.zeros_like(image)
    disordered_trace_crop_data = {}
    grainstats_additions = {}
    # per-grain Skan frames are collected here and concatenated once after the loop
    skan_frames = []

    # want to get each cropped image, use some anchor coords to match them onto the image,
    # and compile all the grain images onto a single image
    all_images = {
        "smoothed_grain": img_base.copy(),
        "skeleton": img_base.copy(),
        "pruned_skeleton": img_base.copy(),
        "branch_indexes": img_base.copy(),
        "branch_types": img_base.copy(),
    }

    LOGGER.info(f"[{filename}] : Calculating Disordered Tracing statistics for {n_grains} grains...")

    for cropped_image_index, cropped_image in cropped_images.items():
        try:
            cropped_mask = cropped_masks[cropped_image_index]
            disordered_trace_images = disordered_trace_grain(
                cropped_image=cropped_image,
                cropped_mask=cropped_mask,
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                mask_smoothing_params=mask_smoothing_params,
                skeletonisation_params=skeletonisation_params,
                pruning_params=pruning_params,
                filename=filename,
                min_skeleton_size=min_skeleton_size,
                n_grain=cropped_image_index,
            )
            LOGGER.debug(f"[{filename}] : Disordered Traced grain {cropped_image_index + 1} of {n_grains}")

            if disordered_trace_images is not None:
                # obtain segment stats
                try:
                    skan_skeleton = skan.Skeleton(
                        np.where(disordered_trace_images["pruned_skeleton"] == 1, cropped_image, 0),
                        spacing=pixel_to_nm_scaling,
                    )
                    skan_df = skan.summarize(skan_skeleton, separator="_")
                    skan_df = compile_skan_stats(skan_df, skan_skeleton, cropped_image, filename, cropped_image_index)
                    # branch distances carry the units of ``spacing``; 1e-9 presumably converts nm to m
                    total_branch_length = skan_df["branch_distance"].sum() * 1e-9
                except ValueError:
                    LOGGER.warning(
                        f"[{filename}] : Skeleton for grain {cropped_image_index} has been pruned out of existence."
                    )
                    total_branch_length = 0
                    skan_df = pd.DataFrame()
                if not skan_df.empty:
                    skan_frames.append(skan_df)

                # obtain stats
                conv_pruned_skeleton = convolve_skeleton(disordered_trace_images["pruned_skeleton"])
                grainstats_additions[cropped_image_index] = {
                    "image": filename,
                    "grain_number": cropped_image_index,
                    "grain_endpoints": np.int64((conv_pruned_skeleton == 2).sum()),
                    "grain_junctions": np.int64((conv_pruned_skeleton == 3).sum()),
                    "total_branch_lengths": total_branch_length,
                    "grain_width_mean": disorderedTrace.calculate_dna_width(
                        disordered_trace_images["smoothed_grain"],
                        disordered_trace_images["pruned_skeleton"],
                        pixel_to_nm_scaling,
                    )
                    * 1e-9,
                }

                # remap the cropped images back onto the original
                bbox = bboxs[cropped_image_index]
                for image_name, full_image in all_images.items():
                    crop = disordered_trace_images[image_name]
                    # explicit end indices: ``pad_width:-pad_width`` would select nothing when pad_width == 0
                    full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[
                        pad_width : crop.shape[0] - pad_width, pad_width : crop.shape[1] - pad_width
                    ]

                disordered_trace_crop_data[f"grain_{cropped_image_index}"] = disordered_trace_images
                disordered_trace_crop_data[f"grain_{cropped_image_index}"]["bbox"] = bboxs[cropped_image_index]
                disordered_trace_crop_data[f"grain_{cropped_image_index}"]["pad_width"] = pad_width

        # when skel too small, pruned to 0's, skan -> ValueError -> skipped
        # Deliberately broad: one failing grain must not abort the remaining grains.
        except Exception as e:  # pylint: disable=broad-exception-caught
            LOGGER.error(  # pylint: disable=logging-not-lazy
                f"[{filename}] : Disordered tracing of grain "
                f"{cropped_image_index} failed. Consider raising an issue on GitHub. Error: ",
                exc_info=e,
            )

    # convert stats dict to dataframe
    grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index")
    disordered_tracing_stats = pd.concat(skan_frames) if skan_frames else pd.DataFrame()

    return disordered_trace_crop_data, grainstats_additions_df, all_images, disordered_tracing_stats
[docs]
def compile_skan_stats(
    skan_df: pd.DataFrame, skan_skeleton: skan.Skeleton, image: npt.NDArray, filename: str, grain_number: int
) -> pd.DataFrame:
    """
    Obtain and add more stats to the resultant Skan dataframe.

    Parameters
    ----------
    skan_df : pd.DataFrame
        The statistics DataFrame produced by Skan's `summarize` function (with ``separator="_"``). Columns are
        added to this frame in place.
    skan_skeleton : skan.Skeleton
        The graphical representation of the skeleton produced by Skan.
    image : npt.NDArray
        The image the skeleton was produced from.
    filename : str
        Name of the file being processed.
    grain_number : int
        The number of the grain being processed.

    Returns
    -------
    pd.DataFrame
        A dataframe containing the columns image, grain_number, branch_distance, branch_type, connected_segments,
        mean_pixel_value, stdev_pixel_value, min_value, median_value and middle_value.
    """
    skan_df["image"] = filename
    # Cast in place: writing to a hyphenated "branch-type" column would create a stray column that the final
    # selection drops, leaving "branch_type" un-cast.
    skan_df["branch_type"] = skan_df["branch_type"].astype(np.int64)
    skan_df["grain_number"] = grain_number
    skan_df["connected_segments"] = skan_df.apply(find_connections, axis=1, skan_df=skan_df)
    skan_df["min_value"] = skan_df.apply(lambda x: segment_heights(x, skan_skeleton, image).min(), axis=1)
    skan_df["median_value"] = skan_df.apply(lambda x: np.median(segment_heights(x, skan_skeleton, image)), axis=1)
    skan_df["middle_value"] = skan_df.apply(segment_middles, skan_skeleton=skan_skeleton, image=image, axis=1)

    # remove unused skan columns
    return skan_df[
        [
            "image",
            "grain_number",
            "branch_distance",
            "branch_type",
            "connected_segments",
            "mean_pixel_value",
            "stdev_pixel_value",
            "min_value",
            "median_value",
            "middle_value",
        ]
    ]
[docs]
def segment_heights(row: pd.Series, skan_skeleton: skan.Skeleton, image: npt.NDArray) -> npt.NDArray:
    """
    Look up the image values along one Skan skeleton segment.

    Parameters
    ----------
    row : pd.Series
        A row from the Skan summarize dataframe; its ``name`` is used as the path index.
    skan_skeleton : skan.Skeleton
        The graphical representation of the skeleton produced by Skan.
    image : npt.NDArray
        The image the skeleton was produced from.

    Returns
    -------
    npt.NDArray
        Image values at the segment's coordinates, in the order Skan returns those coordinates.
    """
    path = skan_skeleton.path_coordinates(row.name)
    row_idx = path[:, 0]
    col_idx = path[:, 1]
    return image[row_idx, col_idx]
[docs]
def segment_middles(row: pd.Series, skan_skeleton: skan.csr.Skeleton, image: npt.NDArray) -> float:
    """
    Obtain the pixel value in the middle of the ordered segment.

    Parameters
    ----------
    row : pd.Series
        A row from the Skan summarize dataframe.
    skan_skeleton : skan.csr.Skeleton
        The graphical representation of the skeleton produced by Skan.
    image : npt.NDArray
        The image the skeleton was produced from.

    Returns
    -------
    float
        The central value for an odd number of points, or the mean of the two central values for an even number.
    """
    heights = segment_heights(row, skan_skeleton, image)
    n_points = len(heights)
    # For odd lengths both indices coincide on the centre; for even lengths they straddle it.
    lower = (n_points + 1) // 2 - 1
    upper = lower + (n_points + 1) % 2
    return heights[[lower, upper]].mean()
[docs]
def find_connections(row: pd.Series, skan_df: pd.DataFrame) -> str:
    """
    List the indexes of the branches that share a node with this branch.

    Parameters
    ----------
    row : pd.Series
        A row from the Skan summarize dataframe.
    skan_df : pd.DataFrame
        The statistics DataFrame produced by Skan's `summarize` function.

    Returns
    -------
    str
        String form of the list of row indexes whose 'node_id_src' or 'node_id_dst' equals either node of ``row``,
        excluding ``row`` itself. A string is used for csv compatibility since csvs can't hold lists.
    """
    src_node = row["node_id_src"]
    dst_node = row["node_id_dst"]
    all_src = skan_df["node_id_src"]
    all_dst = skan_df["node_id_dst"]

    shares_node = (all_src == src_node) | (all_dst == dst_node) | (all_src == dst_node) | (all_dst == src_node)
    neighbours = skan_df[shares_node].index.tolist()
    # The row always matches itself, so drop its own index
    neighbours.remove(row.name)
    return str(neighbours)
[docs]
def prep_arrays(
    image: npt.NDArray, labelled_grains_mask: npt.NDArray, pad_width: int
) -> tuple[dict[int, npt.NDArray], dict[int, npt.NDArray], list[list[int]]]:
    """
    Take an image and labelled mask and crops individual grains and original heights.

    A second padding is made after cropping to ensure for "edge cases" where grains are close to bounding box edges that
    they are traced correctly. This is accounted for when aligning traces to the whole image mask.

    Parameters
    ----------
    image : npt.NDArray
        Gaussian filtered image. Typically filtered_image.images["gaussian_filtered"].
    labelled_grains_mask : npt.NDArray
        2D Numpy array of labelled grain masks, with each mask being comprised solely of unique integer (not
        zero). Typically this will be output from 'grains.directions[<direction>["labelled_region_02]'.
    pad_width : int
        Cells by which to pad cropped regions by.

    Returns
    -------
    tuple[dict[int, npt.NDArray], dict[int, npt.NDArray], list[list[int]]]
        The cropped images and the cropped binary masks (both keyed by the position of the grain in the
        ``regionprops`` output) and the list of padded bounding boxes used to map crops back onto the image.
    """
    # Get bounding boxes for each grain
    region_properties = skimage_measure.regionprops(labelled_grains_mask)

    cropped_images = {}
    cropped_masks = {}
    bboxs = []
    for index, region in enumerate(region_properties):
        # crop to the (clipped) padded bounding box, then zero-pad so nothing sits on the crop border
        cropped_images[index] = np.pad(crop_array(image, region.bbox, pad_width), pad_width=pad_width)
        mask_crop = np.pad(crop_array(labelled_grains_mask, region.bbox, pad_width), pad_width=pad_width)
        # Binarise on the region's own label; ``index + 1`` is only correct when labels run 1..N without gaps
        cropped_masks[index] = np.where(mask_crop == region.label, 1, 0)
        # Get BBOX coords to remap crops to images (a fresh list is passed as it is padded in place)
        bboxs.append(pad_bounding_box(image.shape, list(region.bbox), pad_width=pad_width))

    return (cropped_images, cropped_masks, bboxs)
[docs]
def grain_anchor(array_shape: tuple, bounding_box: list, pad_width: int) -> tuple:
    """
    Extract anchor (min_row, min_col) of a padded bounding box to align a trace over the original image.

    Parameters
    ----------
    array_shape : tuple
        Shape of original array.
    bounding_box : list
        Bounding box coordinates (min_row, min_col, max_row, max_col), e.g. the ``bbox`` of a region returned by
        'skimage.measure.regionprops()'. It is not modified.
    pad_width : int
        Padding for image.

    Returns
    -------
    tuple
        The (min_row, min_col) of the padded bounding box.
    """
    # ``pad_bounding_box`` pads its argument in place, so hand it a copy to leave the caller's box untouched
    bounding_coordinates = pad_bounding_box(array_shape, list(bounding_box), pad_width)
    return (bounding_coordinates[0], bounding_coordinates[1])
[docs]
def disordered_trace_grain(  # pylint: disable=too-many-arguments
    cropped_image: npt.NDArray,
    cropped_mask: npt.NDArray,
    pixel_to_nm_scaling: float,
    mask_smoothing_params: dict,
    skeletonisation_params: dict,
    pruning_params: dict,
    filename: str | None = None,
    min_skeleton_size: int = 10,
    n_grain: int | None = None,
) -> dict | None:
    """
    Trace an individual grain.

    Tracing involves multiple steps...

    1. Smoothing of the grain mask.
    2. Skeletonisation.
    3. Pruning of side branches (artefacts from skeletonisation).

    Parameters
    ----------
    cropped_image : npt.NDArray
        Cropped array from the original image defined as the bounding box from the labelled mask.
    cropped_mask : npt.NDArray
        Cropped array from the labelled image defined as the bounding box from the labelled mask. This should have been
        converted to a binary mask.
    pixel_to_nm_scaling : float
        Pixel to nm scaling.
    mask_smoothing_params : dict
        Dictionary of parameters to smooth the grain mask for better quality skeletonisation results. Contains
        a gaussian 'sigma' and number of dilation iterations.
    skeletonisation_params : dict
        Dictionary of skeletonisation parameters, options are 'zhang' (scikit-image) / 'lee' (scikit-image) / 'thin'
        (scikitimage) or 'topostats' (original TopoStats method).
    pruning_params : dict
        Dictionary of pruning parameters.
    filename : str | None
        File being processed.
    min_skeleton_size : int
        Minimum size of grain in pixels after skeletonisation.
    n_grain : int | None
        Grain number being processed.

    Returns
    -------
    dict | None
        ``None`` if tracing produced no usable skeleton, otherwise a dictionary of images with the keys
        'original_image', 'original_grain', 'smoothed_grain', 'skeleton', 'pruned_skeleton', 'branch_types' and
        'branch_indexes'.
    """
    disorderedtrace = disorderedTrace(
        image=cropped_image,
        mask=cropped_mask,
        filename=filename,
        pixel_to_nm_scaling=pixel_to_nm_scaling,
        min_skeleton_size=min_skeleton_size,
        mask_smoothing_params=mask_smoothing_params,
        skeletonisation_params=skeletonisation_params,
        pruning_params=pruning_params,
        n_grain=n_grain,
    )
    disorderedtrace.trace_dna()

    # ``trace_dna`` leaves the trace as None when the skeleton is missing or too small
    if disorderedtrace.disordered_trace is None:
        return None

    return {
        "original_image": cropped_image,
        "original_grain": cropped_mask,
        "smoothed_grain": disorderedtrace.smoothed_mask,
        "skeleton": disorderedtrace.skeleton,
        "pruned_skeleton": disorderedtrace.pruned_skeleton,
        "branch_types": get_skan_image(cropped_image, disorderedtrace.pruned_skeleton, "branch_type"),
        "branch_indexes": get_skan_image(cropped_image, disorderedtrace.pruned_skeleton, "node_id_src"),
    }
[docs]
def get_skan_image(original_image: npt.NDArray, pruned_skeleton: npt.NDArray, skan_column: str) -> npt.NDArray:
    """
    Paint each skeleton branch with a value taken from Skan's summary, plus one.

    For ``skan_column="node_id_src"`` the branch's position in the summary is used instead of the column value,
    so each branch receives its own index. For ``"branch_type"`` the painted values (+1 compared to Skan docs)
    are:

    1 = Endpoint-to-endpoint (isolated branch)
    2 = Junction-to-endpoint
    3 = Junction-to-junction
    4 = Isolated cycle

    Parameters
    ----------
    original_image : npt.NDArray
        Height image from which the pruned skeleton is derived from.
    pruned_skeleton : npt.NDArray
        Single pixel thick skeleton mask.
    skan_column : str
        A column from Skan's summarize function to colour the branch segments with.

    Returns
    -------
    npt.NDArray
        2D array where the background is 0 and each branch pixel holds its value + 1. All zeros if Skan raises a
        ``ValueError`` because there is no skeleton to analyse.
    """
    painted = np.zeros_like(original_image)
    heights_on_skeleton = np.where(pruned_skeleton == 1, original_image, 0)
    label_by_position = skan_column == "node_id_src"

    try:
        skeleton_graph = skan.Skeleton(heights_on_skeleton, spacing=1e-9, value_is_height=True)
        summary = skan.summarize(skeleton_graph, separator="_")
        for branch_idx, column_value in enumerate(summary[skan_column]):
            coords = skeleton_graph.path_coordinates(branch_idx)
            fill_value = branch_idx if label_by_position else column_value
            painted[coords[:, 0], coords[:, 1]] = fill_value + 1
    except ValueError:  # when no skeleton to skan
        LOGGER.warning("Skeleton has been pruned out of existence.")

    return painted
[docs]
def crop_array(array: npt.NDArray, bounding_box: tuple, pad_width: int = 0) -> npt.NDArray:
    """
    Cut a padded bounding box out of an array.

    The bounding box is grown by ``pad_width`` on every side so heights outside the grain are included. Where the
    grown box would run past the array, it stops at the array boundary, so as much padding as is available is
    returned.

    Parameters
    ----------
    array : npt.NDArray
        2D Numpy array to be cropped.
    bounding_box : Tuple
        Tuple of coordinates to crop, should be of form (min_row, min_col, max_row, max_col).
    pad_width : int
        Padding to apply to bounding box.

    Returns
    -------
    npt.NDArray()
        Cropped array.
    """
    # work on a list copy: the padding helper writes into the sequence it is given
    padded = pad_bounding_box(array.shape, list(bounding_box), pad_width)
    row_span = slice(padded[0], padded[2])
    col_span = slice(padded[1], padded[3])
    return array[row_span, col_span]
[docs]
def pad_bounding_box(array_shape: tuple, bounding_box: list, pad_width: int) -> list:
    """
    Grow a bounding box by ``pad_width`` on every side, stopping at the image boundary.

    Parameters
    ----------
    array_shape : tuple
        Shape of original image (row, columns).
    bounding_box : list
        List of coordinates 'min_row', 'min_col', 'max_row', 'max_col'. Updated in place.
    pad_width : int
        Cells to pad arrays by.

    Returns
    -------
    list
        The same list object, now holding the padded coordinates.
    """
    n_rows, n_cols = array_shape[0], array_shape[1]
    # Top row and left column: never go below index 0
    bounding_box[0] = max(bounding_box[0] - pad_width, 0)
    bounding_box[1] = max(bounding_box[1] - pad_width, 0)
    # Bottom row and right column: never go past the array extent
    bounding_box[2] = min(bounding_box[2] + pad_width, n_rows)
    bounding_box[3] = min(bounding_box[3] + pad_width, n_cols)
    return bounding_box
# 2023-06-09 - Code that runs dnatracing in parallel across grains, left deliberately for use when we remodularise the
# entry-points/workflow. Will require that the gaussian filtered array is saved and passed in along with
# the labelled regions. @ns-rse
#
#
# if __name__ == "__main__":
# cropped_images, cropped_masks = prep_arrays(image, grains_mask, pad_width)
# n_grains = len(cropped_images)
# LOGGER.info(f"[{filename}] : Calculating statistics for {n_grains} grains.")
# # Process in parallel
# with Pool(processes=cores) as pool:
# results = {}
# with tqdm(total=n_grains) as pbar:
# x = 0
# for result in pool.starmap(
# trace_grain,
# zip(
# cropped_images,
# cropped_masks,
# repeat(pixel_to_nm_scaling),
# repeat(filename),
# repeat(min_skeleton_size),
# repeat(skeletonisation_method),
# ),
# ):
# LOGGER.info(f"[{filename}] : Traced grain {x + 1} of {n_grains}")
# results[x] = result
# x += 1
# pbar.update()
# try:
# results = pd.DataFrame.from_dict(results, orient="index")
# results.index.name = "molecule_number"
# except ValueError as error:
# LOGGER.error("No grains found in any images, consider adjusting your thresholds.")
# LOGGER.error(error)
# return results