"""Find grains in an image."""
# pylint: disable=no-name-in-module
from collections import defaultdict
import logging
from typing import List, Dict
import numpy as np
from skimage.segmentation import clear_border
from skimage import morphology
from skimage.measure import regionprops
from skimage.color import label2rgb
from topostats.logs.logs import LOGGER_NAME
from topostats.thresholds import threshold
from topostats.utils import _get_mask, get_thresholds
LOGGER = logging.getLogger(LOGGER_NAME)
# pylint: disable=fixme
# pylint: disable=line-too-long
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-arguments
# pylint: disable=bare-except
# pylint: disable=dangerous-default-value
[docs]
class Grains:
"""Find grains in an image."""
def __init__(
self,
image: np.ndarray,
filename: str,
pixel_to_nm_scaling: float,
threshold_method: str = None,
otsu_threshold_multiplier: float = None,
threshold_std_dev: dict = None,
threshold_absolute: dict = None,
absolute_area_threshold: dict = {
"above": [None, None],
"below": [None, None],
},
direction: str = None,
smallest_grain_size_nm2: float = None,
):
"""Initialise the class.
Parameters
----------
image: np.ndarray
2D Numpy array of image
filename: str
File being processed
pixel_to_nm_scaling: float
Sacling of pixels to nanometre.
threshold_multiplier : Union[int, float]
Factor by which below threshold is to be scaled prior to masking.
threshold_method: str
Method for determining threshold to mask values, default is 'otsu'.
threshold_std_dev: dict
Dictionary of 'below' and 'above' factors by which standard deviation is multiplied to derive the threshold if threshold_method is 'std_dev'.
threshold_absolute: dict
Dictionary of absolute 'below' and 'above' thresholds for grain finding.
absolute_area_threshold: dict
Dictionary of above and below grain's area thresholds
direction: str
Direction for which grains are to be detected, valid values are above, below and both.
"""
self.image = image
self.filename = filename
self.pixel_to_nm_scaling = pixel_to_nm_scaling
self.threshold_method = threshold_method
self.otsu_threshold_multiplier = otsu_threshold_multiplier
self.threshold_std_dev = threshold_std_dev
self.threshold_absolute = threshold_absolute
self.absolute_area_threshold = absolute_area_threshold
# Only detect grains for the desired direction
self.direction = [direction] if direction != "both" else ["above", "below"]
self.smallest_grain_size_nm2 = smallest_grain_size_nm2
self.thresholds = None
self.images = {
"mask_grains": None,
"tidied_border": None,
"tiny_objects_removed": None,
"objects_removed": None,
# "labelled_regions": None,
# "coloured_regions": None,
}
self.directions = defaultdict()
self.minimum_grain_size = None
self.region_properties = defaultdict()
self.bounding_boxes = defaultdict()
self.grainstats = None
[docs]
def tidy_border(self, image: np.array, **kwargs) -> np.array:
"""Remove grains touching the border.
Parameters
----------
image: np.array
Numpy array representing image.
Returns
-------
np.array
Numpy array of image with borders tidied.
"""
LOGGER.info(f"[{self.filename}] : Tidying borders")
return clear_border(image, **kwargs)
[docs]
def label_regions(self, image: np.array) -> np.array:
"""Label regions.
This method is used twice, once prior to removal of small regions, and again afterwards, hence requiring an
argument of what image to label.
Parameters
----------
image: np.array
Numpy array representing image.
Returns
-------
np.array
Numpy array of image with objects coloured.
"""
LOGGER.info(f"[{self.filename}] : Labelling Regions")
return morphology.label(image, background=0)
[docs]
def calc_minimum_grain_size(self, image: np.ndarray) -> float:
"""Calculate the minimum grain size in pixels squared.
Very small objects are first removed via thresholding before calculating the below extreme.
"""
region_properties = self.get_region_properties(image)
grain_areas = np.array([grain.area for grain in region_properties])
if len(grain_areas > 0):
# Exclude small objects less than a given threshold first
grain_areas = grain_areas[
grain_areas >= threshold(grain_areas, method="otsu", otsu_threshold_multiplier=1.0)
]
self.minimum_grain_size = np.median(grain_areas) - (
1.5 * (np.quantile(grain_areas, 0.75) - np.quantile(grain_areas, 0.25))
)
else:
self.minimum_grain_size = -1
[docs]
def remove_noise(self, image: np.ndarray, **kwargs) -> np.ndarray:
"""Removes noise which are objects smaller than the 'smallest_grain_size_nm2'.
This ensures that the smallest objects ~1px are removed regardless of the size distribution of the grains.
Parameters
----------
image: np.ndarray
2D Numpy image to be cleaned.
Returns
-------
np.ndarray
2D Numpy array of image with objects < smallest_grain_size_nm2 removed.
"""
LOGGER.info(
f"[{self.filename}] : Removing noise (< {self.smallest_grain_size_nm2} nm^2"
"{self.smallest_grain_size_nm2 / (self.pixel_to_nm_scaling**2):.2f} px^2)"
)
return morphology.remove_small_objects(
image, min_size=self.smallest_grain_size_nm2 / (self.pixel_to_nm_scaling**2), **kwargs
)
[docs]
def remove_small_objects(self, image: np.array, **kwargs):
"""Remove small objects from the input image. Threshold determined by the minimum_grain_size variable of the
Grains class which is in pixels squared.
Parameters
----------
image: np.ndarray
2D Numpy image to remove small objects from.
Returns
-------
np.ndarray
2D Numpy array of image with objects < minimum_grain_size removed.
"""
# If self.minimum_grain_size is -1, then this means that
# there were no grains to calculate the minimum grian size from.
if self.minimum_grain_size != -1:
small_objects_removed = morphology.remove_small_objects(
image,
min_size=self.minimum_grain_size, # minimum_grain_size is in pixels squared
**kwargs,
)
LOGGER.info(
f"[{self.filename}] : Removed small objects (< \
{self.minimum_grain_size} px^2 / {self.minimum_grain_size / (self.pixel_to_nm_scaling)**2} nm^2)"
)
return small_objects_removed > 0.0
return image
[docs]
def area_thresholding(self, image: np.ndarray, area_thresholds: list):
"""Removes objects larger and smaller than the specified thresholds.
Parameters
----------
image: np.ndarray
Image array where the background == 0 and grains are labelled as integers > 0.
area_thresholds: list
List of area thresholds (in nanometres squared, not pixels squared), first should be
the below (smaller) threshold, second above (larger) threshold.
Returns
-------
np.ndarray
Image where grains outside the thresholds have been removed, as a re-numbered labeled image.
"""
image_cp = image.copy()
below, above = area_thresholds
# if one value is None adjust for comparison
if above is None:
above = image.size * self.pixel_to_nm_scaling**2
if below is None:
below = 0
# Get array of grain numbers (discounting zero)
uniq = np.delete(np.unique(image), 0)
grain_count = 0
LOGGER.info(
f"[{self.filename}] : Area thresholding grains | Thresholds: L: {(below / self.pixel_to_nm_scaling**2):.2f},"
f"U: {(above / self.pixel_to_nm_scaling**2):.2f} px^2, L: {below:.2f}, U: {above:.2f} nm^2."
)
for grain_no in uniq: # Calculate grian area in nm^2
grain_area = np.sum(image_cp == grain_no) * (self.pixel_to_nm_scaling**2)
# Compare area in nm^2 to area thresholds
if grain_area > above or grain_area < below:
image_cp[image_cp == grain_no] = 0
else:
grain_count += 1
image_cp[image_cp == grain_no] = grain_count
return image_cp
[docs]
def colour_regions(self, image: np.array, **kwargs) -> np.array:
"""Colour the regions.
Parameters
----------
image: np.array
Numpy array representing image.
Returns
-------
np.array
Numpy array of image with objects coloured.
"""
coloured_regions = label2rgb(image, **kwargs)
LOGGER.info(f"[{self.filename}] : Coloured regions")
return coloured_regions
[docs]
@staticmethod
def get_region_properties(image: np.array, **kwargs) -> List:
"""Extract the properties of each region.
Parameters
----------
image: np.array
Numpy array representing image
Returns
-------
List
List of region property objects.
"""
return regionprops(image, **kwargs)
[docs]
def get_bounding_boxes(self, direction) -> Dict:
"""Derive a list of bounding boxes for each region from the derived region_properties
Parameters
----------
direction: str
Direction of threshold for which bounding boxes are being calculated.
Returns
-------
dict
Dictionary of bounding boxes indexed by region area.
"""
return {region.area: region.area_bbox for region in self.region_properties[direction]}
[docs]
def find_grains(self):
"""Find grains."""
LOGGER.info(f"[{self.filename}] : Thresholding method (grains) : {self.threshold_method}")
self.thresholds = get_thresholds(
image=self.image,
threshold_method=self.threshold_method,
otsu_threshold_multiplier=self.otsu_threshold_multiplier,
threshold_std_dev=self.threshold_std_dev,
absolute=self.threshold_absolute,
)
for direction in self.direction:
LOGGER.info(f"[{self.filename}] : Finding {direction} grains, threshold: ({self.thresholds[direction]})")
self.directions[direction] = {}
self.directions[direction]["mask_grains"] = _get_mask(
self.image,
thresh=self.thresholds[direction],
threshold_direction=direction,
img_name=self.filename,
)
self.directions[direction]["labelled_regions_01"] = self.label_regions(
self.directions[direction]["mask_grains"]
)
self.directions[direction]["tidied_border"] = self.tidy_border(
self.directions[direction]["labelled_regions_01"]
)
LOGGER.info(f"[{self.filename}] : Removing noise ({direction})")
self.directions[direction]["removed_noise"] = self.area_thresholding(
self.directions[direction]["tidied_border"],
[self.smallest_grain_size_nm2, None],
)
LOGGER.info(f"[{self.filename}] : Removing small / large grains ({direction})")
# if no area thresholds specified, use otsu
if self.absolute_area_threshold[direction].count(None) == 2:
self.calc_minimum_grain_size(self.directions[direction]["removed_noise"])
self.directions[direction]["removed_small_objects"] = self.remove_small_objects(
self.directions[direction]["removed_noise"]
)
else:
self.directions[direction]["removed_small_objects"] = self.area_thresholding(
self.directions[direction]["removed_noise"],
self.absolute_area_threshold[direction],
)
self.directions[direction]["labelled_regions_02"] = self.label_regions(
self.directions[direction]["removed_small_objects"]
)
self.region_properties[direction] = self.get_region_properties(
self.directions[direction]["labelled_regions_02"]
)
LOGGER.info(f"[{self.filename}] : Region properties calculated ({direction})")
self.directions[direction]["coloured_regions"] = self.colour_regions(
self.directions[direction]["labelled_regions_02"]
)
self.bounding_boxes[direction] = self.get_bounding_boxes(direction=direction)
LOGGER.info(f"[{self.filename}] : Extracted bounding boxes ({direction})")