"""Utilities."""
from __future__ import annotations
import logging
from argparse import Namespace
from collections import defaultdict
from pathlib import Path
from pprint import pformat
import numpy as np
import numpy.typing as npt
import pandas as pd
from topostats.logs.logs import LOGGER_NAME
from topostats.thresholds import threshold
LOGGER = logging.getLogger(LOGGER_NAME)
ALL_STATISTICS_COLUMNS = (
"image",
"basename",
"molecule_number",
"area",
"area_cartesian_bbox",
"aspect_ratio",
"bending_angle",
"centre_x",
"centre_y",
"circular",
"contour_length",
"end_to_end_distance",
"height_max",
"height_mean",
"height_median",
"height_min",
"max_feret",
"min_feret",
"radius_max",
"radius_mean",
"radius_median",
"radius_min",
"smallest_bounding_area",
"smallest_bounding_length",
"smallest_bounding_width",
"threshold",
"volume",
)
[docs]
def convert_path(path: str | Path) -> Path:
"""
Ensure path is Path object.
Parameters
----------
path : str | Path
Path to be converted.
Returns
-------
Path
Pathlib object of path.
"""
return Path().cwd() if path == "./" else Path(path).expanduser()
[docs]
def update_config(config: dict, args: dict | Namespace) -> dict:
"""
Update the configuration with any arguments.
Parameters
----------
config : dict
Dictionary of configuration (typically read from YAML file specified with '-c/--config <filename>').
args : Namespace
Command line arguments.
Returns
-------
dict
Dictionary updated with command arguments.
"""
args = vars(args) if isinstance(args, Namespace) else args
config_keys = config.keys()
for arg_key, arg_value in args.items():
if isinstance(arg_value, dict):
update_config(config, arg_value)
else:
if arg_key in config_keys and arg_value is not None:
original_value = config[arg_key]
config[arg_key] = arg_value
LOGGER.info(f"Updated config config[{arg_key}] : {original_value} > {arg_value} ")
if "base_dir" in config.keys():
config["base_dir"] = convert_path(config["base_dir"])
if "output_dir" in config.keys():
config["output_dir"] = convert_path(config["output_dir"])
return config
[docs]
def update_plotting_config(plotting_config: dict) -> dict:
"""
Update the plotting config for each of the plots in plot_dict.
Ensures that each entry has all the plotting configuration values that are needed.
Parameters
----------
plotting_config : dict
Plotting configuration to be updated.
Returns
-------
dict
Updated plotting configuration.
"""
main_config = plotting_config.copy()
for opt in ["plot_dict", "run"]:
main_config.pop(opt)
LOGGER.debug(
f"Main plotting options that need updating/adding to plotting dict :\n{pformat(main_config, indent=4)}"
)
for image, options in plotting_config["plot_dict"].items():
LOGGER.debug(f"Dictionary for image : {image}")
LOGGER.debug(f"{pformat(options, indent=4)}")
# First update options with values that exist in main_config
plotting_config["plot_dict"][image] = update_config(options, main_config)
LOGGER.debug(f"Updated values :\n{pformat(plotting_config['plot_dict'][image])}")
# Then combine the remaining key/values we need from main_config that don't already exist
for key_main, value_main in main_config.items():
if key_main not in plotting_config["plot_dict"][image]:
plotting_config["plot_dict"][image][key_main] = value_main
LOGGER.debug(f"After adding missing configuration options :\n{pformat(plotting_config['plot_dict'][image])}")
# Make it so that binary images do not have the user-defined z-scale
# applied, but non-binary images do.
if plotting_config["plot_dict"][image]["image_type"] == "binary":
plotting_config["plot_dict"][image]["zrange"] = [None, None]
return plotting_config
[docs]
def _get_mask(image: npt.NDArray, thresh: float, threshold_direction: str, img_name: str = None) -> npt.NDArray:
"""
Calculate a mask for pixels that exceed the threshold.
Parameters
----------
image : np.array
Numpy array representing image.
thresh : float
A float representing the threshold.
threshold_direction : str
A string representing the direction that should be thresholded. ("above", "below").
img_name : str
Name of image being processed.
Returns
-------
npt.NDArray
Numpy array of image with objects coloured.
"""
if threshold_direction == "above":
LOGGER.info(f"[{img_name}] : Masking (above) Threshold: {thresh}")
return image > thresh
LOGGER.info(f"[{img_name}] : Masking (below) Threshold: {thresh}")
return image < thresh
# LOGGER.fatal(f"[{img_name}] : Threshold direction invalid: {threshold_direction}")
[docs]
def get_mask(image: npt.NDArray, thresholds: dict, img_name: str = None) -> npt.NDArray:
"""
Mask data that should not be included in flattening.
Parameters
----------
image : npt.NDArray
2D Numpy array of the image to have a mask derived for.
thresholds : dict
Dictionary of thresholds, at a bare minimum must have key 'below' with an associated value, second key is
to have an 'above' threshold.
img_name : str
Image name that is being masked.
Returns
-------
npt.NDArray
2D Numpy boolean array of points to mask.
"""
# Both thresholds are applicable
if "below" in thresholds and "above" in thresholds:
mask_above = _get_mask(image, thresh=thresholds["above"], threshold_direction="above", img_name=img_name)
mask_below = _get_mask(image, thresh=thresholds["below"], threshold_direction="below", img_name=img_name)
# Masks are combined to remove both the extreme high and extreme low data points.
return mask_above + mask_below
# Only below threshold is applicable
if "below" in thresholds:
return _get_mask(image, thresh=thresholds["below"], threshold_direction="below", img_name=img_name)
# Only above threshold is applicable
return _get_mask(image, thresh=thresholds["above"], threshold_direction="above", img_name=img_name)
# pylint: disable=unused-argument
[docs]
def get_thresholds( # noqa: C901
image: npt.NDArray,
threshold_method: str,
otsu_threshold_multiplier: float = None,
threshold_std_dev: dict = None,
absolute: dict = None,
**kwargs,
) -> dict:
"""
Obtain thresholds for masking data points.
Parameters
----------
image : npt.NDArray
2D Numpy array of image to be masked.
threshold_method : str
Method for thresholding, 'otsu', 'std_dev' or 'absolute' are valid options.
otsu_threshold_multiplier : float
Scaling value for Otsu threshold.
threshold_std_dev : dict
Dict of above and below thresholds for the standard deviation method.
absolute : tuple
Dict of below and above thresholds.
**kwargs :
Dictionary passed to 'topostats.threshold(**kwargs)'.
Returns
-------
Dict
Dictionary of thresholds, contains keys 'below' and optionally 'above'.
"""
thresholds = defaultdict()
if threshold_method == "otsu":
thresholds["above"] = threshold(image, method="otsu", otsu_threshold_multiplier=otsu_threshold_multiplier)
elif threshold_method == "std_dev":
try:
if threshold_std_dev["below"] is not None:
thresholds["below"] = threshold(image, method="mean") - threshold_std_dev["below"] * np.nanstd(image)
if threshold_std_dev["above"] is not None:
thresholds["above"] = threshold(image, method="mean") + threshold_std_dev["above"] * np.nanstd(image)
except TypeError as typeerror:
raise typeerror
elif threshold_method == "absolute":
if absolute["below"] is not None:
thresholds["below"] = absolute["below"]
if absolute["above"] is not None:
thresholds["above"] = absolute["above"]
else:
if not isinstance(threshold_method, str):
raise TypeError(
f"threshold_method ({threshold_method}) should be a string. Valid values : 'otsu' 'std_dev' 'absolute'"
)
if threshold_method not in ["otsu", "std_dev", "absolute"]:
raise ValueError(
f"threshold_method ({threshold_method}) is invalid. Valid values : 'otsu' 'std_dev' 'absolute'"
)
return thresholds
[docs]
def create_empty_dataframe(columns: set = ALL_STATISTICS_COLUMNS, index: str = "molecule_number") -> pd.DataFrame:
"""
Create an empty data frame for returning when no results are found.
Parameters
----------
columns : list
Columns of the empty dataframe.
index : str
Column to set as index of empty dataframe.
Returns
-------
pd.DataFrame
Empty Pandas DataFrame.
"""
empty_df = pd.DataFrame(columns=columns)
return empty_df.set_index(index)
[docs]
def bound_padded_coordinates_to_image(coordinates: npt.NDArray, padding: int, image_shape: tuple) -> tuple:
"""
Ensure the padding of coordinates points does not fall outside of the image shape.
This function is primarily used in the dnaTrace.get_fitted_traces() method which aims to adjust the points of a
skeleton to sit on the highest points of a traced molecule. In order to do so it takes the ordered skeleton, which
may not lie on the highest points as it is generated from a binary mask that is unaware of the heights, and then
defines a padded boundary of 3nm profile perpendicular to the backbone of the DNA (which at this point is the
skeleton based on a mask). Each point along the skeleton therefore needs padding by a minimum of 2 pixels (in this
case each pixel equates to a cell in a NumPy array). If a point is within 2 pixels (i.e. 2 cells) of the border then
we can not pad beyond this region, we have to stop at the edge of the image and so the coordinates is adjusted such
that the padding will lie on the edge of the image/array.
Parameters
----------
coordinates : npt.NDArray
Coordinates of a point on the mask based skeleton.
padding : int
Number of pixels/cells to pad around the point.
image_shape : tuple
The shape of the original image from which the pixel is obtained.
Returns
-------
tuple
Returns a tuple of coordinates that ensure that when the point is padded by the noted padding width in
subsequent calculations it will not be outside of the image shape.
"""
# Calculate the maximum row and column indexes
max_row = image_shape[0] - 1
max_col = image_shape[1] - 1
row_coord, col_coord = coordinates
def check(coord: npt.NDArray, max_val: int, padding: int) -> npt.NDArray:
"""
Check coordinates are within the bounds of the padding.
Parameters
----------
coord : npt.NDArray
Coordinates (length = 2).
max_val : int
Maximum width in the dimension being checked (max_row or max_col).
padding : int
Padding used in the image.
Returns
-------
npt.NDArray
Coordinates adjusted for padding.
"""
if coord - padding < 0:
coord = padding
elif coord + padding > max_val:
coord = max_val - padding
return coord
return check(row_coord, max_row, padding), check(col_coord, max_col, padding)