Source code for topostats.unet_masking

"""Segment grains using a U-Net model."""

from __future__ import annotations

import logging

import keras
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from PIL import Image

from topostats.logs.logs import LOGGER_NAME

LOGGER = logging.getLogger(LOGGER_NAME)

# pylint: disable=too-many-locals


[docs] def dice_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor: """ DICE loss function. Expects y_true and y_pred to be of shape (batch_size, height, width, 1). Parameters ---------- y_true : npt.NDArray[np.float32] True values. y_pred : npt.NDArray[np.float32] Predicted values. smooth : float Smoothing factor to prevent division by zero. Returns ------- tf.Tensor The DICE loss. """ # Ensure the tensors are of the same shape y_true = tf.squeeze(y_true, axis=-1) if y_true.shape[-1] == 1 else y_true y_pred = tf.squeeze(y_pred, axis=-1) if y_pred.shape[-1] == 1 else y_pred # Ensure floats not bool y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2)) sum_of_squares_pred = tf.reduce_sum(tf.square(y_pred), axis=(1, 2)) sum_of_squares_true = tf.reduce_sum(tf.square(y_true), axis=(1, 2)) return 1 - (2 * intersection + smooth) / (sum_of_squares_pred + sum_of_squares_true + smooth)
# IoU Loss
[docs] def iou_loss(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32], smooth: float = 1e-5) -> tf.Tensor: """ Intersection over Union loss function. Expects y_true and y_pred to be of shape (batch_size, height, width, 1). Parameters ---------- y_true : npt.NDArray[np.float32] True values. y_pred : npt.NDArray[np.float32] Predicted values. smooth : float Smoothing factor to prevent division by zero. Returns ------- tf.Tensor The IoU loss. """ # Ensure the tensors are of the same shape y_true = tf.squeeze(y_true, axis=-1) if y_true.shape[-1] == 1 else y_true y_pred = tf.squeeze(y_pred, axis=-1) if y_pred.shape[-1] == 1 else y_pred # Ensure floats not bool y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2)) sum_of_squares_pred = tf.reduce_sum(tf.square(y_pred), axis=(1, 2)) sum_of_squares_true = tf.reduce_sum(tf.square(y_true), axis=(1, 2)) return 1 - (intersection + smooth) / (sum_of_squares_pred + sum_of_squares_true - intersection + smooth)
[docs] def mean_iou(y_true: npt.NDArray[np.float32], y_pred: npt.NDArray[np.float32]): """ Mean Intersection Over Union metric, ignoring the background class. Parameters ---------- y_true : npt.NDArray[np.float32] True values. y_pred : npt.NDArray[np.float32] Predicted values. Returns ------- tf.Tensor The mean IoU. """ # Ensure the tensors are of the same shape, and ignore the background class # The [-1] here is to flatten the tensor into a 1D array, allowing for the calculation of the IoU # The 1: is to use all channels except channel 0. Since this would be the background class and if # we include it, the IoU would be very low since it is going to be mostly zero and so highly accurate? y_true_f = tf.reshape(y_true[:, :, :, 1:], [-1]) # ignore background class y_pred_f = tf.round(tf.reshape(y_pred[:, :, :, 1:], [-1])) # ignore background class # Calculate the IoU, using all channels except the background class intersect = tf.reduce_sum(y_true_f * y_pred_f) union = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) - intersect smooth = tf.ones(tf.shape(intersect)) # Smoothing factor to prevent division by zero return tf.reduce_mean((intersect + smooth) / (union - intersect + smooth))
[docs] def predict_unet( image: npt.NDArray[np.float32], model: keras.Model, confidence: float, model_input_shape: tuple[int | None, int, int, int], upper_norm_bound: float, lower_norm_bound: float, ) -> npt.NDArray[np.bool_]: """ Predict cats segmentation from a flattened image. Parameters ---------- image : npt.NDArray[np.float32] The image to predict the mask for. model : keras.Model The U-Net model. confidence : float The confidence threshold for the mask. model_input_shape : tuple[int | None, int, int, int] The shape of the model input, including the batch and channel dimensions. upper_norm_bound : float The upper bound for normalising the image. lower_norm_bound : float The lower bound for normalising the image. Returns ------- npt.NDArray[np.bool_] The predicted mask. """ # Strip the batch dimension from the model input shape image_shape: tuple[int, int] = model_input_shape[1:3] LOGGER.info(f"Model input shape: {model_input_shape}") # Make a copy of the original image original_image = image.copy() # Run the model on a single image LOGGER.info("Preprocessing image for Unet prediction...") # Normalise the image image = np.clip(image, lower_norm_bound, upper_norm_bound) image = image - lower_norm_bound image = image / (upper_norm_bound - lower_norm_bound) # Resize the image to the model input shape image_resized = Image.fromarray(image) image_resized = image_resized.resize(image_shape) image_resized_np: npt.NDArray[np.float32] = np.array(image_resized) # Predict the mask LOGGER.info("Running Unet & predicting mask") prediction: npt.NDArray[np.float32] = model.predict(np.expand_dims(image_resized_np, axis=(0, 3))) LOGGER.info(f"Unet finished predicted mask. Prediction shape: {prediction.shape}") # Threshold the predicted mask predicted_mask: npt.NDArray[np.bool_] = prediction > confidence # Remove the batch dimension since we are predicting single images at a time predicted_mask = predicted_mask[0] # Note that this predicted mask can have any number of channels, depending on the number of classes for the model # Check if the output is a single channel mask and convert it to a two-channel mask since the output is # designed to be categorical, where even the background has a channel if predicted_mask.shape[2] == 1: predicted_mask = np.concatenate((1 - predicted_mask, predicted_mask), axis=2) assert len(predicted_mask.shape) == 3, f"Predicted mask shape is not 3D: {predicted_mask.shape}" assert ( predicted_mask.shape[2] >= 2 ), f"Predicted mask has less than 2 channels: {predicted_mask.shape[2]}, needs separate background channel" # Resize each channel of the predicted mask to the original image size, also remove the batch # dimension but keep channel resized_predicted_mask: npt.NDArray[np.bool_] = np.zeros( (original_image.shape[0], original_image.shape[1], predicted_mask.shape[2]) ).astype(bool) for channel_index in range(predicted_mask.shape[2]): # Note that uint8 is required to allow PIL to load the array into an image channel_mask = predicted_mask[:, :, channel_index].astype(np.uint8) channel_mask_PIL = Image.fromarray(channel_mask) # Resize the channel mask to the original image size, but we want boolean so use nearest neighbour channel_mask_PIL = channel_mask_PIL.resize( (original_image.shape[1], original_image.shape[0]), Image.Resampling.NEAREST ) resized_predicted_mask[:, :, channel_index] = np.array(channel_mask_PIL).astype(bool) return resized_predicted_mask
[docs] def make_bounding_box_square( crop_min_row: int, crop_min_col: int, crop_max_row: int, crop_max_col: int, image_shape: tuple[int, int] ) -> tuple[int, int, int, int]: """ Make a bounding box square. Parameters ---------- crop_min_row : int The minimum row index of the crop. crop_min_col : int The minimum column index of the crop. crop_max_row : int The maximum EXCLUSIVE row index of the crop. crop_max_col : int The maximum EXCLUSIVE column index of the crop. image_shape : tuple[int, int] The shape of the image. Returns ------- tuple[int, int, int, int] The new crop indices. Notes ----- The crop indices are inclusive on the minimum side and exclusive on the maximum side. So an object like this: 0 0 0 0 0 0 1 1 1 0 0 1 1 1 0 0 1 1 1 0 0 0 0 0 0 With no padding, would have a bounding box of (1, 1, 4, 4) and not (1, 1, 3, 3). Hence the maximum index for the bbox is allowed to exceed the maximum index of the image by 1. """ crop_height = crop_max_row - crop_min_row crop_width = crop_max_col - crop_min_col axes_diff: int new_crop_min_row: int new_crop_min_col: int new_crop_max_row: int new_crop_max_col: int if crop_height > crop_width: # The crop is taller than it is wide axes_diff = crop_height - crop_width # Check if we can expand equally in each direction # Note that we allow the crop max col to be equal to the image shape since the max is exclusive if crop_min_col - axes_diff // 2 >= 0 and crop_max_col + axes_diff - axes_diff // 2 <= image_shape[1]: new_crop_min_col = crop_min_col - axes_diff // 2 new_crop_max_col = crop_max_col + axes_diff - axes_diff // 2 # If we can't expand uniformly, expand as much as possible in that direction else: # Set crop edge be at the edge of the image if crop_min_col - axes_diff // 2 <= 0: new_crop_min_col = 0 # Remaining size is equal to the axes_diff minus the size we just added. # The size we added is simply the original crop_min_col since the new crop_min_col is 0. # Add that remaining size to the crop_max_col new_crop_max_col = crop_max_col + (axes_diff - crop_min_col) # Crop expansion beyond image size else: new_crop_max_col = image_shape[1] # Remaining size is equal to the axes_diff minus the size we just added. # The size we added is the size of the image minus the crop_max_col. Note that since # the max is exclusive, we use the size of the image, not the maximum image column. # The remaining size will be the axes diff minus the difference between original and new # crop_max_col. # So we subtract the remaining size from the crop_min_col to get the new crop_min_col. new_crop_min_col = crop_min_col - (axes_diff - (new_crop_max_col - crop_max_col)) # Set the new crop height to the original crop height since we are just updating the width new_crop_min_row = crop_min_row new_crop_max_row = crop_max_row else: # The crop is wider than it is tall axes_diff = crop_width - crop_height # Check if we can expand equally in each direction # Note that we allow the crop max row to be equal to the image shape since the max is exclusive if crop_min_row - axes_diff // 2 >= 0 and crop_max_row + axes_diff - axes_diff // 2 <= image_shape[0]: new_crop_min_row = crop_min_row - axes_diff // 2 new_crop_max_row = crop_max_row + axes_diff - axes_diff // 2 # If we can't expand uniformly, expand as much as possible in that dir else: # Set crop edge be at the edge of the image if crop_min_row - axes_diff // 2 <= 0: new_crop_min_row = 0 # Remaining size is equal to the axes_diff minus the size we just added. # The size we added is simply the original crop_min_row since the new crop_min_row is 0. # Add that remaining size to the crop_max_row new_crop_max_row = crop_max_row + (axes_diff - crop_min_row) # Crop expansion beyond image size else: new_crop_max_row = image_shape[0] # Remaining size is equal to the axes_diff minus the size we just added. # The size we added is the size of the image minus the crop_max_row. Note that since # the max is exclusive, we use the size of the image, not the maximum image row. # The remaining size will be the axes diff minus the difference between original and new # crop_max_row. new_crop_min_row = crop_min_row - (axes_diff - (new_crop_max_row - crop_max_row)) # Set the new crop width to the original crop width since we are just updating the height new_crop_min_col = crop_min_col new_crop_max_col = crop_max_col return new_crop_min_row, new_crop_min_col, new_crop_max_row, new_crop_max_col
[docs] def pad_bounding_box( crop_min_row: int, crop_min_col: int, crop_max_row: int, crop_max_col: int, image_shape: tuple[int, int], padding: int, ) -> tuple[int, int, int, int]: """ Pad a bounding box. Parameters ---------- crop_min_row : int The minimum row index of the crop. crop_min_col : int The minimum column index of the crop. crop_max_row : int The maximum row index of the crop. crop_max_col : int The maximum column index of the crop. image_shape : tuple[int, int] The shape of the image. padding : int The padding to apply to the bounding box. Returns ------- tuple[int, int, int, int] The new crop indices. """ new_crop_min_row: int = max(0, crop_min_row - padding) new_crop_min_col: int = max(0, crop_min_col - padding) new_crop_max_row: int = min(image_shape[0], crop_max_row + padding) new_crop_max_col: int = min(image_shape[1], crop_max_col + padding) return new_crop_min_row, new_crop_min_col, new_crop_max_row, new_crop_max_col
[docs] def pad_crop( crop: npt.NDArray, bbox: tuple[int, int, int, int], image_shape: tuple[int, int], padding: int, ) -> np.NDArray: """ Pad a crop. Parameters ---------- crop : npt.NDArray The crop to pad. bbox : tuple[int, int, int, int] The bounding box of the crop. image_shape : tuple[int, int] The shape of the image. padding : int The padding to apply to the crop. Returns ------- np.NDArray The padded crop. """ new_bounding_box = pad_bounding_box( crop_min_row=bbox[0], crop_min_col=bbox[1], crop_max_row=bbox[2], crop_max_col=bbox[3], image_shape=image_shape, padding=padding, ) # Pad the crop with the new bounding box padded_crop = np.zeros((new_bounding_box[2] - new_bounding_box[0], new_bounding_box[3] - new_bounding_box[1])) padded_crop[ bbox[0] - new_bounding_box[0] : bbox[2] - new_bounding_box[0], bbox[1] - new_bounding_box[1] : bbox[3] - new_bounding_box[1], ] = crop return padded_crop, new_bounding_box
[docs] def make_crop_square( crop: npt.NDArray, bbox: tuple[int, int, int, int], image_shape: tuple[int, int], ) -> np.NDArray: """ Make a crop square. Parameters ---------- crop : npt.NDArray The crop to make square. bbox : tuple[int, int, int, int] The bounding box of the crop. image_shape : tuple[int, int] The shape of the image. Returns ------- np.NDArray The square crop. """ new_bounding_box = make_bounding_box_square( crop_min_row=bbox[0], crop_min_col=bbox[1], crop_max_row=bbox[2], crop_max_col=bbox[3], image_shape=image_shape, ) # Make the crop square square_crop = np.zeros((new_bounding_box[2] - new_bounding_box[0], new_bounding_box[3] - new_bounding_box[1])) square_crop[ bbox[0] - new_bounding_box[0] : bbox[2] - new_bounding_box[0], bbox[1] - new_bounding_box[1] : bbox[3] - new_bounding_box[1], ] = crop return square_crop, new_bounding_box