Calibrating the Dice Loss to Handle Neural Network Overconfidence for Biomedical Image Segmentation

In this section, we first introduce the CE loss and its variant, the Focal loss, followed by the DSC loss. We then identify the cause of the poor calibration observed with the DSC loss, and use this insight to derive the DSC++ loss. After introducing softmax thresholding, the section concludes with details on the experimental setup and implementation.

CE Loss

CE measures the difference between two probability distributions y and p. The CE loss is among the most widely used loss functions in machine learning; in the context of image segmentation, y and p represent the true and predicted distributions over class labels for a given pixel, respectively. The CE loss, \(\mathcal{L}_\text{CE}\), is defined as:

$$\begin{aligned} \mathcal{L}_\text{CE}(y, p) = -\frac{1}{N} \sum _{i=1}^{N} \sum _{c=1}^{C} y_{i,c} \cdot \log \left( p_{i,c}\right) , \end{aligned}$$

(1)

where \(y_{i,c}\) uses a one-hot encoding scheme corresponding to the ground truth labels, \(p_{i,c}\) is a matrix of predicted values generated by the model for each class, and the indices i and c iterate over all pixels and classes, respectively. The CE loss is a strictly proper scoring rule, equivalent to the negative log-likelihood (NLL), and therefore yields consistent probabilistic predictions [22].
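For concreteness, a minimal NumPy sketch of Eq. (1), with predictions flattened to N pixels and C classes (the eps clipping is a numerical safeguard, not part of the definition):

```python
import numpy as np

def ce_loss(y, p, eps=1e-7):
    """Cross entropy of Eq. (1), averaged over N pixels.

    y: one-hot ground truth, shape (N, C).
    p: predicted class probabilities, shape (N, C).
    """
    p = np.clip(p, eps, 1.0)  # guard against log(0)
    return -np.mean(np.sum(y * np.log(p), axis=1))
```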

Focal Loss

The Focal loss (\(\mathcal{L}_\text{F}\)) is an extension of the cross entropy loss developed to address the issue of class imbalance in classification tasks [16].

The Focal loss uses a modulating factor \(\gamma\) to reduce the contribution of easy examples to the overall loss:

$$\begin{aligned} \mathcal{L}_\text{F} = \boldsymbol{\alpha } \left( 1 - p_{i,c}\right)^{\gamma } \cdot \mathcal{L}_\text{CE}, \end{aligned}$$

(2)

where \(\boldsymbol{\alpha }\) is a vector of class weights, \(p_{i,c}\) is the matrix of predicted probabilities for each class as in Eq. (1), and \(\mathcal{L}_\text{CE}\) is the cross entropy loss as defined in Eq. (1). The Focal loss reduces to the cross entropy loss when \(\gamma = 0\) and \(\boldsymbol{\alpha } = 1\).
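A corresponding sketch of Eq. (2), extending the CE sketch above (the default \(\gamma = 2\) is illustrative only):

```python
import numpy as np

def focal_loss(y, p, alpha=1.0, gamma=2.0, eps=1e-7):
    """Focal loss of Eq. (2): the (1 - p)^gamma modulator shrinks the
    contribution of easy, confidently predicted pixels.

    alpha may be a scalar or a length-C vector of class weights.
    """
    p = np.clip(p, eps, 1.0)
    ce = -y * np.log(p)                 # per-pixel, per-class CE terms
    modulator = (1.0 - p) ** gamma      # ~0 for confident correct predictions
    return np.mean(np.sum(alpha * modulator * ce, axis=1))
```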

DSC Loss

The CE and Focal losses are based on pixel-wise error; in class imbalanced situations, CE-based losses therefore over-represent larger objects in the loss, resulting in under-segmentation of smaller objects. The segmentation target in biomedical imaging tasks often occupies a small area relative to the size of the image, limiting the use of the CE loss as a segmentation quality metric or loss function [10].

In contrast, the DSC is a spatial overlap index and is therefore robust to class imbalance. It is defined as:

$$\begin{aligned} \text{DSC} = \frac{1}{C} \sum _{c=1}^{C} \frac{2 \sum _{i=1}^{N} p_{i,c}\, y_{i,c}}{\sum _{i=1}^{N} p_{i,c} + \sum _{i=1}^{N} y_{i,c}}, \end{aligned}$$

(3)

where the DSC loss (\(\mathcal{L}_\text{DSC}\)) is:

$$\begin{aligned} \mathcal{L}_\text{DSC} = 1 - \text{DSC}. \end{aligned}$$

(4)
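The DSC loss translates directly into code; a minimal sketch of Eqs. (3) and (4), where the smooth term is a common numerical safeguard rather than part of the definition:

```python
import numpy as np

def dsc_loss(y, p, smooth=1e-7):
    """Soft Dice loss of Eqs. (3)-(4), averaged over C classes.

    y, p: shape (N, C); smooth avoids division by zero when a
    class is absent from both the prediction and the ground truth.
    """
    intersection = np.sum(p * y, axis=0)                 # per-class overlap
    denominator = np.sum(p, axis=0) + np.sum(y, axis=0)
    dsc = np.mean((2.0 * intersection + smooth) / (denominator + smooth))
    return 1.0 - dsc
```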

DSC++ Loss

The optimisation goal, for both the CE and the DSC loss, is for the neural network to produce confident, correct predictions matching the ground truth labels. However, neural network overconfidence is a well-known phenomenon associated with the DSC loss, but not with the CE loss. To understand this difference, we provide an equivalent definition of the DSC loss \(\mathcal{L}_\text{DSC}\) (Eq. (4)), in terms of true positive (TP), false negative (FN) and false positive (FP) predictions:

$$\begin{aligned} \mathcal{L}_\text{DSC} = 1 - \frac{2\,\text{TP}}{2\,\text{TP} + \text{FN} + \text{FP}}, \end{aligned}$$

(5)

noting that the DSC score is the harmonic mean of precision and recall, where:

$$\begin{aligned} \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}, \end{aligned}$$

(6)

$$\begin{aligned} \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}. \end{aligned}$$

(7)

When both classes are present with equal frequency, the errors associated with the FP and FN predictions are not biased towards a particular class. However, in class imbalanced situations, high precision, low recall solutions are favoured, with over-prediction of the dominant class [23]. Combined with an optimisation goal that favours confident predictions, this results in networks producing extremely confident, and often incorrect, predictions of the dominant class in regions of uncertainty.

To overcome this issue, we reformulate the DSC loss to more heavily penalise overconfident predictions. First, we define another equivalent formulation of \(\mathcal{L}_\text{DSC}\), identical in structure to Eq. (5):

$$\begin{aligned} \mathcal{L}_\text{DSC} = 1 - \frac{1}{C} \sum _{c=1}^{C} \frac{2 \sum _{i=1}^{N} p_{i,c}\, y_{i,c}}{2 \sum _{i=1}^{N} p_{i,c}\, y_{i,c} + \sum _{i=1}^{N} p_{i,\bar{c}}\, y_{i,c} + \sum _{i=1}^{N} p_{i,c}\, y_{i,\bar{c}}}, \end{aligned}$$

(8)

where \(p_{i,c}\) is the probability of pixel i belonging to class c, and \(p_{i,\bar{c}}\) is the probability of pixel i not belonging to class c. Similarly, \(y_{i,c}\) is 1 for class c and 0 for all other classes, and conversely \(y_{i,\bar{c}}\) takes the value 0 for class c and 1 for all other classes.

To penalise overconfidence for uncertain regions, we apply the focal parameter, \(\gamma\), directly to the FP and FN predictions, defining the DSC++ loss (\(\mathcal{L}_\text{DSC++}\)):

$$\begin{aligned} \mathcal{L}_\text{DSC++} = 1 - \frac{1}{C} \sum _{c=1}^{C} \frac{2 \sum _{i=1}^{N} p_{i,c}\, y_{i,c}}{2 \sum _{i=1}^{N} p_{i,c}\, y_{i,c} + \sum _{i=1}^{N} \left( p_{i,\bar{c}}\, y_{i,c}\right)^{\gamma } + \sum _{i=1}^{N} \left( p_{i,c}\, y_{i,\bar{c}}\right)^{\gamma }}. \end{aligned}$$

(9)

The DSC++ loss achieves selective penalisation of overconfident predictions by transforming the penalty from a linear to an exponentially weighted penalty. When \(\gamma = 1\), the DSC++ loss is identical to the DSC loss. When \(\gamma > 1\), overconfident predictions are more heavily penalised, with increasing values of \(\gamma\) resulting in successively larger penalties. Higher \(\gamma\) values therefore favour low confidence predictions. The optimal \(\gamma\) value balances the maintenance of confident, correct predictions while simultaneously suppressing confident but incorrect predictions.
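To make the reformulation concrete, a minimal sketch of Eq. (9), writing the class complements as \(p_{i,\bar{c}} = 1 - p_{i,c}\) and \(y_{i,\bar{c}} = 1 - y_{i,c}\) (the smooth term is a numerical safeguard; \(\gamma = 2\) matches the value used in our substitution experiments):

```python
import numpy as np

def dsc_pp_loss(y, p, gamma=2.0, smooth=1e-7):
    """DSC++ loss of Eq. (9); gamma = 1 recovers the DSC loss.

    y: one-hot ground truth, shape (N, C); p: softmax outputs, shape (N, C).
    """
    tp = np.sum(p * y, axis=0)                      # per-class true positives
    fn = np.sum(((1.0 - p) * y) ** gamma, axis=0)   # focal-weighted FN terms
    fp = np.sum((p * (1.0 - y)) ** gamma, axis=0)   # focal-weighted FP terms
    dsc_pp = np.mean((2.0 * tp + smooth) / (2.0 * tp + fn + fp + smooth))
    return 1.0 - dsc_pp
```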

Softmax Thresholding

While the softmax function is not a true proxy for uncertainty, the distribution of well calibrated softmax outputs is closely related to the underlying uncertainty, even for out-of-distribution data [24, 25]. To generate a class labelled segmentation output, the argmax function assigns each pixel the class with the highest softmax value. Rather than using the argmax function, we use a variable threshold that enables manual adjustment of model outputs to favour either precision or recall. Here, we define the output of the model using an indicator function, describing a per-pixel operation that compares the softmax output for the segmentation target, s, to a given softmax threshold \(\mathcal{T}\):

$$\begin{aligned} I_{i} = \begin{cases} 0 &{} \text{if } s < \mathcal{T} \\ 1 &{} \text{otherwise.} \end{cases} \end{aligned}$$

(10)

With this generalisation, the argmax function may be restated as the special case where \(\mathcal{T} = 0.5\). Higher values of \(\mathcal{T}\) favour precision, while lower values favour recall.
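As a sketch, the per-pixel operation of Eq. (10) for a binary foreground softmax map:

```python
import numpy as np

def threshold_predictions(s, T=0.5):
    """Eq. (10): binarise the foreground softmax map s at threshold T.

    T = 0.5 recovers argmax behaviour for the binary case; raising T
    favours precision, lowering it favours recall.
    """
    return (s >= T).astype(np.uint8)
```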

Dataset Descriptions and Evaluation Metrics

To evaluate our proposed loss function, we select six public, well-validated biomedical image segmentation datasets. For retinal vessel segmentation, we use the Digital Retinal Images for Vessel Extraction (DRIVE) dataset [26]. The DRIVE dataset consists of 40 coloured fundus photographs obtained from diabetic retinopathy screening in the Netherlands, with an image resolution of \(768 \times 584\) pixels. The Breast UltraSound 2017 (BUS2017) dataset consists of 163 ultrasound images of breast lesions with an average image size of \(760 \times 570\) pixels, collected from the UDIAT Diagnostic Centre of the Parc Taulí Corporation, Sabadell, Spain [27]. Furthermore, we include the 2018 Data Science Bowl (2018DSB) dataset, which contains 670 light microscopy images for nuclei segmentation [28]. For skin lesion segmentation, we use the ISIC2018: Skin Lesion Analysis Towards Melanoma Detection grand challenge dataset. This dataset contains 2,594 images of skin lesions with an average size of \(2166 \times 3188\) pixels [29]. For colorectal polyp segmentation, we use the CVC-ClinicDB dataset, which consists of 612 frames containing polyps with an image resolution of \(288 \times 368\) pixels, generated from 23 video sequences from 13 different patients using standard colonoscopy interventions with white light [30]. Finally, for 3D multi-class segmentation, we use the Kidney Tumour Segmentation 2019 (KiTS19) dataset [31]. This dataset contains 300 arterial phase abdominal CT scans, with voxel-level kidney and kidney tumour annotations. We exclude the 90 scans without associated segmentation masks, and further exclude another 6 scans (cases 15, 23, 37, 68, 125 and 133) due to issues with the ground truth quality [32].

Except for the DRIVE dataset, which is pre-partitioned into 20 training and 20 test images, we randomly partitioned each of the other five datasets into an 80% development set and a 20% test set. For all datasets, we further partitioned the development set into an 80% training set and a 20% validation set. Except for the CVC-ClinicDB and KiTS19 datasets, images were downsampled using bilinear interpolation. For KiTS19, we performed on-the-fly random sampling of patches of size \(80 \times 160 \times 160\), with a patch-wise overlap of \(40 \times 80 \times 80\). A summary of the datasets, image resolutions and data partitions is presented in Table 1.
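A sketch of this nested partitioning using scikit-learn's train_test_split (the dataset size and random seed shown are illustrative):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Illustrative: 670 image indices, as in the 2018DSB dataset.
indices = np.arange(670)
dev_idx, test_idx = train_test_split(indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(dev_idx, test_size=0.2, random_state=42)
# Result: 64% training, 16% validation, 20% test.
```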

Table 1 Summary of the dataset details and training setup used in these experiments

To assess the loss functions, we select two evaluation metrics each for calibration and performance. For calibration, we use the NLL and Brier score, both strictly proper scoring rules. The NLL is equivalent to the CE loss in Eq. (1), while the Brier score (Brier) computes the mean squared error between predicted probability scores and the true class labels:

$$\begin{aligned} \text{Brier} = \frac{1}{N} \frac{1}{C} \sum _{i=1}^{N} \sum _{c=1}^{C} \left( y_{i,c} - p_{i,c}\right)^{2}. \end{aligned}$$

(11)

For both metrics, a lower score corresponds to better calibration.

For performance, we use the DSC as previously defined, and the Intersection over Union (IoU), also known as the Jaccard Index:

$$\begin{aligned} \text{IoU} = \frac{\text{TP}}{\text{TP} + \text{FP} + \text{FN}}. \end{aligned}$$

(12)

In contrast to the calibration metrics, a higher DSC or IoU score corresponds to better performance.
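Both metrics reduce to a few lines of NumPy; a sketch assuming one-hot labels y and probabilities p of shape (N, C) for the Brier score, and binary 0/1 masks for the IoU:

```python
import numpy as np

def brier_score(y, p):
    """Brier score of Eq. (11): mean squared error over pixels and classes."""
    return np.mean((y - p) ** 2)

def iou_score(y_true, y_pred):
    """IoU of Eq. (12) for binary masks (0/1 arrays)."""
    tp = np.sum(y_pred * y_true)
    fp = np.sum(y_pred * (1 - y_true))
    fn = np.sum((1 - y_pred) * y_true)
    return tp / (tp + fp + fn)
```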

Implementation Details

For our experiments, we leveraged the Medical Image Segmentation with Convolutional Neural Networks (MIScnn) open-source Python library [33]. MIScnn is built on the Keras library with the TensorFlow backend, and all experiments were carried out using NVIDIA P100 GPUs.

Images were resized as described previously and normalised per-image using the z-score. We applied on-the-fly data augmentation with probability 0.15, including scaling (\(0.85\)–\(1.25\times\)), rotation (\(-15^{\circ}\) to \(+15^{\circ}\)), mirroring (vertical and horizontal axes), elastic deformation (\(\alpha \in [0, 900]\) and \(\sigma \in [9.0, 13.0]\)) and brightness (\(0.5\)–\(2\times\)).

To investigate the effect of altering \(\gamma\) on the DSC++ loss, we perform a grid search, evaluating values \(\gamma \in [0.5, 5]\).

To evaluate the loss functions, we trained the standard U-Net, with model parameters initialised using Xavier initialisation [34]. We trained each model with instance normalisation, using the stochastic gradient descent optimiser with a batch size of 1 and an initial learning rate of 0.1 [35]. For convergence criteria, we used ReduceLROnPlateau to reduce the learning rate by a factor of 0.1 if the validation loss did not improve after 25 epochs, and the EarlyStopping callback to terminate training if the validation loss did not improve after 50 epochs. To compensate for the large patch size used for training on the KiTS19 dataset, we used stricter convergence criteria of 5 epochs and 10 epochs for the ReduceLROnPlateau and EarlyStopping callbacks, respectively.
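These convergence criteria correspond directly to the standard Keras callbacks; a sketch with the patience values stated above (other arguments left at their defaults):

```python
import tensorflow as tf

callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.1,
                                         patience=25, verbose=1),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=50,
                                     verbose=1),
]
# model.fit(x, y, batch_size=1, callbacks=callbacks, ...)
```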

To evaluate the effect of substituting the DSC++ loss for the DSC loss in several DSC-based variants commonly used to achieve state-of-the-art results, we selected the Tversky loss, Focal Tversky loss, Combo loss and Unified Focal loss [10, 23, 36, 37].

The Combo loss (\(\mathcal{L}_\text{Combo}\)) is a compound loss function defined as the weighted sum of the DSC and a modified CE loss (\(\mathcal{L}_\text{mCE}\)) [37]:

$$\begin{aligned} \mathcal{L}_\text{Combo} = \alpha \left( \mathcal{L}_\text{mCE}\right) - (1-\alpha ) \cdot \text{DSC}, \end{aligned}$$

(13)

where:

$$\begin{aligned} \mathcal{L}_\text{mCE} = -\frac{1}{N} \sum _{i=1}^{N} \left[ \beta \left( y_{i} \ln \left( p_{i}\right) \right) + (1-\beta ) \left( 1-y_{i}\right) \ln \left( 1-p_{i}\right) \right] . \end{aligned}$$

(14)

The parameters \(\alpha\) and \(\beta\) take values in the range [0, 1], controlling the relative contribution of the DSC and CE terms to the loss, and the relative weights assigned to false positives and false negatives, respectively. Optimising models with the Combo loss has been observed to improve performance, as well as to produce visually more consistent segmentations, compared with models trained using the component losses alone [38].
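A minimal sketch of Eqs. (13) and (14) for the binary case (the values \(\alpha = \beta = 0.5\) are illustrative; Table 2 lists the settings used in our experiments):

```python
import numpy as np

def combo_loss(y, p, alpha=0.5, beta=0.5, eps=1e-7):
    """Combo loss of Eqs. (13)-(14), binary case.

    y, p: flattened foreground labels and probabilities, shape (N,).
    """
    p = np.clip(p, eps, 1.0 - eps)
    mce = -np.mean(beta * y * np.log(p)
                   + (1.0 - beta) * (1.0 - y) * np.log(1.0 - p))
    dsc = (2.0 * np.sum(p * y) + eps) / (np.sum(p) + np.sum(y) + eps)
    return alpha * mce - (1.0 - alpha) * dsc
```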

To overcome the high precision, low recall bias associated with the DSC loss, the Tversky loss (\(\mathcal{L}_\text{T}\)) modifies the weights associated with the FP and FN predictions [23]:

$$\begin{aligned} \mathcal{L}_\text{T} = \sum _{c=1}^{C} (1 - \text{TI}), \end{aligned}$$

(15)

where the Tversky index (TI) is defined as:

$$\begin{aligned} \text{TI} = \frac{\sum _{i=1}^{N} p_{i,c}\, y_{i,c}}{\sum _{i=1}^{N} p_{i,c}\, y_{i,c} + \alpha \sum _{i=1}^{N} p_{i,c}\, y_{i,\bar{c}} + \beta \sum _{i=1}^{N} p_{i,\bar{c}}\, y_{i,c}}, \end{aligned}$$

(16)

where \(\alpha\) and \(\beta\) control the FP and FN weightings, respectively.
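A sketch of Eqs. (15) and (16), using the \(\alpha = 0.3\), \(\beta = 0.7\) weighting reported in [23] (the smooth term is a numerical safeguard):

```python
import numpy as np

def tversky_index(y, p, alpha=0.3, beta=0.7, smooth=1e-7):
    """Tversky index of Eq. (16), computed per class; alpha weights FPs,
    beta weights FNs."""
    tp = np.sum(p * y, axis=0)
    fp = np.sum(p * (1.0 - y), axis=0)
    fn = np.sum((1.0 - p) * y, axis=0)
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)

def tversky_loss(y, p, alpha=0.3, beta=0.7):
    """Tversky loss of Eq. (15): (1 - TI) summed over classes."""
    return np.sum(1.0 - tversky_index(y, p, alpha, beta))
```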

To handle class imbalanced data, the Focal Tversky loss (\(\mathcal{L}_\text{FT}\)) applies a focal parameter \(\gamma\) to alter the weights associated with difficult-to-classify examples [36]:

$$\begin{aligned} \mathcal{L}_\text{FT} = \sum _{c=1}^{C} (1-\text{TI})^{\gamma }, \end{aligned}$$

(17)

where \(\gamma < 1\) increases the degree of focusing on harder examples.
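Reusing the tversky_index helper from the sketch above, the Focal Tversky loss of Eq. (17) becomes (the value \(\gamma = 0.75\) is illustrative):

```python
import numpy as np

def focal_tversky_loss(y, p, alpha=0.3, beta=0.7, gamma=0.75):
    """Focal Tversky loss of Eq. (17); gamma < 1 increases the focus on
    harder examples."""
    ti = tversky_index(y, p, alpha, beta)  # helper from the sketch above
    return np.sum((1.0 - ti) ** gamma)
```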

Finally, the Unified Focal loss (\(\mathcal{L}_\text{UF}\)) generalises distribution-based and region-based loss functions into a single framework [10], and is defined as the weighted sum of the Asymmetric Focal loss (\(\mathcal{L}_\text{AF}\)) and Asymmetric Focal Tversky loss (\(\mathcal{L}_\text{AFT}\)):

$$\begin{aligned} \mathcal{L}_\text{UF} = \lambda \mathcal{L}_\text{AF} + (1-\lambda ) \mathcal{L}_\text{AFT}, \end{aligned}$$

(18)

where:

$$\begin{aligned} \mathcal{L}_\text{AF} = -\frac{\delta }{N} \sum _{i=1}^{N} y_{i,r} \log \left( p_{i,r}\right) - \frac{1-\delta }{N} \sum _{i=1}^{N} \sum _{c \ne r} y_{i,c} \left( 1-p_{i,c}\right)^{\gamma } \log \left( p_{i,c}\right) , \end{aligned}$$

(19)

$$\begin{aligned} \mathcal{L}_\text{AFT} = \sum _{c \ne r} (1-\text{mTI}) + \sum _{c = r} (1-\text{mTI})^{1-\gamma }, \end{aligned}$$

(20)

where r denotes the rare (foreground) class, and the TI is redefined as the modified Tversky index (mTI):

$$\begin{aligned} \text{mTI} = \frac{\sum _{i=1}^{N} p_{i,c}\, y_{i,c}}{\sum _{i=1}^{N} p_{i,c}\, y_{i,c} + (1-\delta ) \sum _{i=1}^{N} p_{i,c}\, y_{i,\bar{c}} + \delta \sum _{i=1}^{N} p_{i,\bar{c}}\, y_{i,c}}. \end{aligned}$$

(21)

The three hyperparameters are \(\lambda\), which controls the relative weights of the two component losses, \(\delta\), which controls the relative weighting of positive and negative examples, and \(\gamma\), which controls the relative weighting of easy and difficult examples.
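A compact sketch of Eqs. (18)-(21) for the binary case, again reusing the tversky_index helper; the class layout (column 0 background, column 1 the rare foreground class) and the hyperparameter values \(\lambda = 0.5\), \(\delta = 0.6\), \(\gamma = 0.5\) are illustrative assumptions:

```python
import numpy as np

def asymmetric_focal_loss(y, p, delta=0.6, gamma=0.5, eps=1e-7):
    """Eq. (19), binary case: the rare class (column 1) keeps plain CE,
    while the background term is focally suppressed."""
    p = np.clip(p, eps, 1.0)
    fore = -delta * np.mean(y[:, 1] * np.log(p[:, 1]))
    back = -(1.0 - delta) * np.mean(
        (1.0 - p[:, 0]) ** gamma * y[:, 0] * np.log(p[:, 0]))
    return fore + back

def asymmetric_focal_tversky_loss(y, p, delta=0.6, gamma=0.5):
    """Eq. (20), binary case: focal enhancement (exponent 1 - gamma)
    applied only to the rare class term."""
    mti = tversky_index(y, p, alpha=1.0 - delta, beta=delta)  # Eq. (21)
    return (1.0 - mti[0]) + (1.0 - mti[1]) ** (1.0 - gamma)

def unified_focal_loss(y, p, lam=0.5, delta=0.6, gamma=0.5):
    """Eq. (18): weighted sum of the two asymmetric components."""
    return (lam * asymmetric_focal_loss(y, p, delta, gamma)
            + (1.0 - lam) * asymmetric_focal_tversky_loss(y, p, delta, gamma))
```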

We used the optimal hyperparameters described in the original papers, detailed in Table 2. For each loss function, we substituted the DSC++ loss for the DSC component of the loss, setting \(\gamma = 2\).

Table 2 Hyperparameter settings used in these experiments for the DSC and cross entropy-based loss functions

To test for statistical significance, we used the Wilcoxon rank sum test, with a statistically significant difference defined as \(p < 0.05\). We used bootstrapping to calculate the standard errors for each metric. To evaluate the effect of softmax thresholding, we selected thresholds \(\mathcal{T} \in [0.05, 0.95]\) using the DSC and DSC++ losses on the DRIVE dataset.
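A sketch of both procedures with SciPy and NumPy, using hypothetical per-image scores:

```python
import numpy as np
from scipy.stats import ranksums

# Hypothetical per-image DSC scores for two loss functions.
rng = np.random.default_rng(0)
scores_a = rng.uniform(0.70, 0.90, size=30)
scores_b = rng.uniform(0.75, 0.95, size=30)

stat, p_value = ranksums(scores_a, scores_b)  # Wilcoxon rank sum test
print("significant:", p_value < 0.05)

# Bootstrap standard error of the mean score (1000 resamples).
resamples = rng.choice(scores_b, size=(1000, scores_b.size), replace=True)
print("bootstrap SE:", resamples.mean(axis=1).std())
```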
