An explainable self-attention deep neural network for detecting mild cognitive impairment using multi-input digital drawing tasks

Data collection

Under institutional review board approval, a digital version of the MoCA test was administered on a tablet with a digital pen by trained psychologists at King Chulalongkorn Memorial Hospital, Bangkok, Thailand, to a total of 918 subjects who provided informed consent. The population came from a healthy elderly cohort focused on preventive care for Thai citizens without major medical conditions (such as organ failure). The median age was 67 years (range 55 to 89 years); 77% were female, 44% held a bachelor’s degree, and 20% had received higher education. For the clock drawing task, the subjects were instructed to “draw a circular clock face with all the numbers and clock hands indicating the time of 10 min past 11 o’clock.” In the cube-copying test, the subjects were instructed to copy the Necker cube image in an empty space. In the trail-making test, the subjects were instructed to “draw a line that goes from a number to a letter in ascending order, starting at number 1 (pointing to the number 1), to this letter (pointing to the letter A), then go to the next number (pointing to the number 2), and so on.”

For each subject, we extracted the drawn clock drawing, cube-copying, and trail-making images along with the MoCA score. We then categorized the subjects into healthy aging controls and MCI patients based on their MoCA scores. In particular, a subject was categorized as having MCI if the MoCA score was below 25, the cutoff typically used in clinical routine [20], resulting in 651 healthy subjects and 267 MCI patients in our dataset. The collected data were randomly split into three groups in a stratified fashion: 70% for training, 15% for validation, and 15% for testing. All images were resized to 256 × 256 pixels in all experiments.
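For illustration, the stratified split could be implemented as follows (a minimal sketch, not the authors' code; the `images` and `labels` arrays and the random seed are hypothetical):

```python
# Illustrative sketch of a stratified 70/15/15 train/validation/test split
# using scikit-learn. `images` and `labels` are hypothetical arrays of the
# resized drawings and the binary MCI labels derived from the MoCA cutoff.
from sklearn.model_selection import train_test_split

def stratified_split(images, labels, seed=0):
    # First carve out 70% for training, stratified on the MCI label.
    x_train, x_rest, y_train, y_rest = train_test_split(
        images, labels, train_size=0.70, stratify=labels, random_state=seed)
    # Split the remaining 30% in half: 15% validation, 15% test.
    x_val, x_test, y_val, y_test = train_test_split(
        x_rest, y_rest, train_size=0.50, stratify=y_rest, random_state=seed)
    return (x_train, y_train), (x_val, y_val), (x_test, y_test)
```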

Proposed method: multi-input Conv-Att model with soft labels

We developed a multi-input deep learning method for MCI vs. healthy aging control classification that is a cascade of CNN backbones and self-attention layers [22, 23] trained with soft labels, as shown in Fig. 3. As opposed to existing models, which take a clock drawing image as the only input [13, 14], our proposed multi-input model takes clock drawing, cube-copying, and trail-making images simultaneously as inputs, exploiting the complementary information offered by the three neuropsychological tests. Incorporating the self-attention layers into the model leads to more efficient image representations than those of typical CNNs, which can later be used to support the model’s classification decision through heat map visualization. The soft-label component of our method is designed to aid model training by taking into account the uncertainty of the diagnostic labels (i.e., MCI vs. healthy aging control) near the designated MoCA score cutoff. An overview of the training process of the proposed method is presented in Algorithm 1. In the following subsections, we describe each of the components in detail.

Fig. 3 Overview of our proposed multi-input Conv-Att model. Our model simultaneously takes clock drawing, cube-copying, and trail-making images as its inputs and processes them using a cascade of CNNs and a stack of self-attention layers

Conv-Att model architecture

As shown in Fig. 3, clock drawing, cube-copying, and trail-making images are used as inputs to our model. Each of the three images is passed into a separate CNN backbone (VGG16 [24] pretrained on the ImageNet dataset [25]) followed by a stack of self-attention layers, resulting in a vectorized image representation. Including the self-attention layers in the model leads to not only an efficient image representation, but also improved visual explanation for MCI vs. healthy aging control classification. Then, the resulting vectors from the three tasks are concatenated and processed by a two-node fully connected layer with the softmax function \(\sigma_i\left(\vec{z}\right)=\frac{e^{z_i}}{\sum_{j=1}^{2}e^{z_j}}\).
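A minimal PyTorch sketch of this architecture, reconstructed from the description above (not the authors' implementation; `SelfAttentionStack` is sketched in the next subsection), might look as follows:

```python
# Sketch of the multi-input Conv-Att architecture: three parallel branches,
# each a pretrained VGG16 backbone followed by self-attention aggregation,
# whose [CLS] vectors are concatenated and classified by a two-node layer.
import torch
import torch.nn as nn
from torchvision.models import vgg16

class ConvAttBranch(nn.Module):
    """One branch: VGG16 features -> self-attention -> [CLS] vector."""
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.backbone = vgg16(weights="IMAGENET1K_V1").features  # CNN backbone
        self.attention = SelfAttentionStack(in_channels=512, dim=hidden_dim)

    def forward(self, x):                # x: (B, 3, 256, 256)
        fmap = self.backbone(x)          # (B, 512, 8, 8) last feature map
        return self.attention(fmap)      # (B, hidden_dim) [CLS] representation

class MultiInputConvAtt(nn.Module):
    def __init__(self, hidden_dim=128):
        super().__init__()
        # Separate branch per task: clock drawing, cube copying, trail making.
        self.branches = nn.ModuleList(ConvAttBranch(hidden_dim) for _ in range(3))
        self.classifier = nn.Linear(3 * hidden_dim, 2)  # two-node layer

    def forward(self, clock, cube, trail):
        feats = [b(img) for b, img in zip(self.branches, (clock, cube, trail))]
        return torch.softmax(self.classifier(torch.cat(feats, dim=-1)), dim=-1)
```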

Self-attention

Unlike in a standard image classification model, we employ self-attention instead of a pooling layer to aggregate the output of the CNN backbone, \(\widetilde{X}\in\mathbb{R}^{H\times W\times C}\), where H, W, and C are the height, width, and number of filters, respectively. First, we initialize a random classification token vector \([\mathrm{CLS}]\in\mathbb{R}^{D}\), where D is the hidden dimension in the self-attention mechanism, as used in BERT [26]. The [CLS] vector is used to aggregate the visual representation from all pixels in \(\widetilde{X}\). Second, a 1 × 1 convolution with D output filters is applied to \(\widetilde{X}\) to adjust its last dimension to match the hidden dimension D of the [CLS] vector. After that, it is reshaped into a matrix of shape HW × D and concatenated with [CLS], resulting in \(\hat{X}\in\mathbb{R}^{\left(HW+1\right)\times D}\). Reshaping the data this way enables us to investigate how much each pixel contributes to the final classification decision through the attention rollout method [27]. The self-attention is defined as:

$$\mathrm{Attention}\left(Q,K,V\right)=\mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{D}}\right)V$$

where \(Q,K,V\in\mathbb{R}^{\left(HW+1\right)\times D}\) are the query, key, and value, respectively, obtained by projecting \(\hat{X}\in\mathbb{R}^{\left(HW+1\right)\times D}\) with different linear functions: \(Q=\hat{X}W_{Q}^{T}\), \(K=\hat{X}W_{K}^{T}\), and \(V=\hat{X}W_{V}^{T}\), where \(W_{Q},W_{K},W_{V}\in\mathbb{R}^{D\times D}\). At the final layer of self-attention, the vector at the [CLS] position is used as the final image representation.
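The aggregation described above can be sketched as follows (a single layer with one head for brevity, whereas the model stacks several such layers with a feedforward sublayer; this is our reconstruction, not the authors' code):

```python
# Sketch of self-attention aggregation with a learnable [CLS] token: a 1x1
# convolution maps the C backbone filters to hidden dimension D, the [CLS]
# token is prepended, and its output vector is returned as the image feature.
import torch
import torch.nn as nn

class SelfAttentionStack(nn.Module):
    def __init__(self, in_channels=512, dim=128):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=1)  # 1x1 conv to D
        self.cls = nn.Parameter(torch.randn(1, 1, dim))         # [CLS] token
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.w_k = nn.Linear(dim, dim, bias=False)
        self.w_v = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** 0.5

    def forward(self, fmap):                        # fmap: (B, C, H, W)
        x = self.proj(fmap)                         # (B, D, H, W)
        b = x.size(0)
        x = x.flatten(2).transpose(1, 2)            # (B, HW, D)
        x = torch.cat([self.cls.expand(b, -1, -1), x], dim=1)  # (B, HW+1, D)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.scale, dim=-1)
        out = attn @ v                              # (B, HW+1, D)
        return out[:, 0]                            # [CLS] as final representation
```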

Soft-label

As explained in the “Data collection” section, we assigned the label of 0 (healthy control) to a subject with a MoCA score of 25 or higher, and the label of 1 (MCI) otherwise. Such a labeling approach is typically referred to as hard labeling. While training a model with hard labels is the most common approach to binary classification, we instead train our model with soft labels derived from the MoCA scores, in order to take into account the uncertainty of the diagnostic labels (MCI vs. healthy aging control) near the MoCA score cutoff of 25. Specifically, we define a soft label y according to the following equation:

$$y=1-\sigma \left(m-24.5\right)$$

where m is the MoCA score and σ denotes the sigmoid function. Since a subject with a MoCA score of 24 is labeled as an MCI patient and a subject with a MoCA score of 25 is labeled as a healthy control, we subtract 24.5 from m so that the center of the sigmoid lies at 24.5, exactly between the two labels.

The hard threshold of 25, below which a subject is considered an MCI patient, is a man-made criterion rather than a number established through rigorous statistical testing over a large number of trials, and it can vary with context, such as education or culture [20, 28, 29]. By assigning a soft label instead, the uncertainty in the classification is expressed through the sigmoidal probability function. In this way, the soft-label approach relaxes the strong classification bias inherent in the hard-label approach. We trained the proposed model by minimizing the binary cross-entropy loss:

$$L=-\frac{1}{M}\sum_{i=1}^{M}\left(y_{i}\log p_{i}+\left(1-y_{i}\right)\log\left(1-p_{i}\right)\right)$$

where M is the number of training examples, \(y_i\) is the soft label of example i, and \(p_i\) is the output of the model, which can be interpreted as the predicted probability that example i is associated with MCI.
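As a concrete sketch, the soft label and the corresponding loss could be computed as follows (our reading of the equations above, not the authors' released code):

```python
# Sketch of the soft-label construction and the binary cross-entropy loss
# with soft targets, following the two equations above.
import torch

def soft_label(moca_score: torch.Tensor) -> torch.Tensor:
    # y = 1 - sigmoid(m - 24.5); e.g., m = 24 -> y ~ 0.62, m = 25 -> y ~ 0.38
    return 1.0 - torch.sigmoid(moca_score - 24.5)

def soft_bce_loss(p_mci: torch.Tensor, moca_score: torch.Tensor) -> torch.Tensor:
    y = soft_label(moca_score)
    eps = 1e-7  # numerical stability near log(0)
    return -(y * torch.log(p_mci + eps)
             + (1 - y) * torch.log(1 - p_mci + eps)).mean()
```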

Attention rollout

To visualize how self-attention combines the pixels of the last feature map computed by the CNN backbone, \(\widetilde{X}\), into the final image representation, we used attention rollout [27]. In the self-attention layers, there exist residual connections between consecutive layers. Therefore, the output of the self-attention layer l + 1 is defined as:

$$V_{l+1}=W_{att}V_{l}+V_{l}$$

where \(W_{att}\) is the attention weight, and \(V_{l}\) is the output of the self-attention layer l. To compensate for the effect of the residual connections, the raw attention A is computed as \(A=0.5W_{att}+0.5I\), where I is the identity matrix. The attention from the self-attention layer \(l_i\) to layer \(l_j\) is computed by recursively multiplying the attention weights as follows:

$$\widetilde{A}\left(l_{i}\right)=\begin{cases}A\left(l_{i}\right)\widetilde{A}\left(l_{i-1}\right), & i>j\\ A\left(l_{i}\right), & i=j\end{cases}$$

where \(\widetilde{A}\left(l_{i}\right)\) is the attention rollout at the self-attention layer \(l_i\), and \(A\left(l_{i}\right)\) is the raw attention at the self-attention layer \(l_i\). The interpretability of our model comes from how the self-attention layers combine the last feature maps into the final image representation, which is equivalent to the attention rollout of [CLS] over all the pixels of the last feature map. The attention rollout for each [CLS] is reshaped back to the size of the last feature map and then resized to match the size of the original input image. The resulting heat map from each [CLS] serves as the visual explanation for the corresponding input image.
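A minimal sketch of the rollout recursion, assuming a list of per-layer attention matrices of shape (HW + 1) × (HW + 1) with the [CLS] token at index 0 (hypothetical names, not the authors' code):

```python
# Sketch of attention rollout: recursively multiply the residual-compensated
# raw attentions A(l) = 0.5*W_att + 0.5*I from the first layer upward, then
# take the [CLS] row over all pixel positions and reshape to the feature map.
import torch

def attention_rollout(attn_weights, fmap_h, fmap_w):
    rollout = None
    for w_att in attn_weights:                      # list of (HW+1, HW+1) mats
        a = 0.5 * w_att + 0.5 * torch.eye(w_att.size(-1))
        rollout = a if rollout is None else a @ rollout
    cls_attention = rollout[0, 1:]                  # [CLS] row, pixels only
    return cls_attention.reshape(fmap_h, fmap_w)    # e.g., 8 x 8, then upsample
```

The returned map can then be resized (e.g., bilinearly) to 256 × 256 to overlay on the original drawing.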

Experiments

We compared our proposed method to four VGG16-based models: three single-input VGG16 models, each taking an image from only one of the three tasks (i.e., clock drawing, cube-copying, or trail-making) as input, and a multi-input VGG16 model that simultaneously takes clock drawing, cube-copying, and trail-making images as inputs. For the multi-input version, a separate VGG16 was used to process each input image. At the end of each VGG16, a global average pooling layer was applied. The average-pooled image features from each task were concatenated and then passed into a two-node fully connected layer with the softmax function.

We also compared the proposed method to ablated versions of itself with some components removed. In particular, we recorded the performance of the single-input Conv-Att models and the multi-input Conv-Att model, both trained with hard labels.

Model training

The Adam optimizer [30] with a learning rate of \(10^{-5}\), \(\beta_1 = 0.9\), \(\beta_2 = 0.99\), and \(\epsilon = 10^{-7}\) was used in all experiments. The models were trained for 100 epochs with a batch size of 64. We adopted image augmentation to increase the effective size of the training data. Specifically, we first zero-padded each image to a size of 280 × 280 and then cropped it back to 256 × 256 with the center at a random location in the padded image. For the models that included stacked self-attention layers, we used self-attention layers with a single attention head, a hidden dimension of 128, and a feedforward hidden dimension of 512.
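The augmentation and optimizer settings could be set up as follows (a sketch using torchvision, assuming the hypothetical `MultiInputConvAtt` model from the architecture sketch above):

```python
# Sketch of the pad-then-random-crop augmentation and optimizer configuration.
import torch
from torchvision import transforms

augment = transforms.Compose([
    transforms.Pad(12, fill=0),     # 256 + 2*12 = 280: zero-pad to 280 x 280
    transforms.RandomCrop(256),     # crop back to 256 x 256 at a random location
])

model = MultiInputConvAtt()         # from the architecture sketch above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5,
                             betas=(0.9, 0.99), eps=1e-7)
```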

Evaluation

We performed 5 random training-validation-test splits and reported the mean and standard deviation of the classification accuracy and F1-score obtained by each method. Since all the methods were trained to predict the probability p of having MCI for each input, we needed to convert the model prediction into a diagnostic label (i.e., a hard label) so that the accuracy and F1-score could be computed meaningfully. In this case, we categorized all subjects with p ≥ 0.5 as MCI patients and those with p < 0.5 as healthy controls. We also reported the AUCs for all the models under consideration.
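A minimal sketch of this evaluation step using scikit-learn (hypothetical array names):

```python
# Sketch: threshold predicted MCI probabilities at 0.5, then compute
# accuracy and F1 on the hard labels, and AUC on the raw probabilities.
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

def evaluate(p_mci: np.ndarray, y_true: np.ndarray) -> dict:
    y_pred = (p_mci >= 0.5).astype(int)   # p >= 0.5 -> MCI, otherwise healthy
    return {"accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred),
            "auc": roc_auc_score(y_true, p_mci)}
```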

We also assessed the ability of the proposed method to provide a visual explanation supporting its diagnostic decision. Specifically, we compared the heat maps generated by the proposed model to those generated by the multi-input VGG16 model with Grad-CAM [19], one of the most commonly used methods for visual explanation, based on two metrics: (1) the interpretability scores given by 3 experts and (2) the intersection over union (IoU) between the heat maps obtained from each method and the corresponding ground truth regions of interest (ROIs).

For each subject, we obtained the heat maps from both the proposed method and the multi-input VGG16 model. Then, we displayed them side by side and separately asked 1 neurologist and 2 licensed neuropsychologists to give a score between 1 and 5 to each set of heat maps (1 being the worst and 5 being the best in terms of providing a visual explanation that aligned with their experience and knowledge). To avoid potential bias, we randomly shuffled the display locations so that the heat maps of the proposed method appeared on the left or the right of those of the VGG16 model with equal probability. We ensured that all the evaluators had sufficient clinical experience in administering and evaluating these drawing tasks while still allowing them to rate the interpretability of the heat map results using their own judgment. The rationale is that the interpretability of a heat map can be evaluated from more than one perspective, and our proposed model should generally perform better than the Grad-CAM baseline across multiple perspectives.

In addition to the ratings provided by the experts, the interpretability of the heat maps was assessed based on the IoUs between the heat maps and ground truth ROIs, where the IoU between two arbitrary shapes A and B is computed as \(IoU\left(A,B\right)=\frac{\left|A\cap B\right|}{\left|A\cup B\right|}\). Prior to computing IoUs, each heat map was converted into a binary matrix by assigning 1 to the top k percent of all pixels with the highest values and setting the remaining pixels to 0. As shown in Fig. 1, two types of ROIs were used as the ground truth in our evaluations: whole-drawing ROIs and expert ROIs.
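The IoU computation could be sketched as follows (binarization by top-k percentile, assuming NumPy arrays; an illustration rather than the exact implementation):

```python
# Sketch: binarize a heat map by keeping its top-k percent of pixel values,
# then compute the IoU against a binary ground-truth ROI mask.
import numpy as np

def heatmap_iou(heatmap: np.ndarray, roi_mask: np.ndarray, k: float = 10.0):
    thresh = np.percentile(heatmap, 100 - k)   # cutoff for the top k percent
    binary = heatmap >= thresh                 # 1 for the highest-valued pixels
    inter = np.logical_and(binary, roi_mask).sum()
    union = np.logical_or(binary, roi_mask).sum()
    return inter / union
```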

For the whole-drawing ROIs, the goal was to check whether the heat maps highlighted regions in the vicinity of the regions drawn by the human subjects. For each image, we defined an initial ROI as the smallest region of a simple shape (e.g., an ellipse, a circle, or a polygon) that enclosed all the non-zero pixels in the image. Then, we enlarged the resulting ROI by a few pixels and used it as the ground truth ROI. Under this ROI type, heat maps that do not highlight regions far away from the drawn regions achieve relatively high IoU. For example, for the clock drawing test, if a heat map focuses on locations outside the clock, the heat map is considered to have low visual explainability under this metric, since it provides no visual cues to substantiate the model’s prediction. Note that heat maps that yield higher IoU are not necessarily more interpretable to clinicians, since the highlighted pixels could be moved at random without decreasing the IoU as long as they remain inside the ROIs.

For the expert ROIs, the ground truth ROIs were constructed based on the assumption that a good heat map should capture the presence of paths that should not have been drawn and/or the absence of paths that should have been drawn. Thus, the ground truth ROI for each image contains the regions that include such unusual paths, as confirmed by an experienced clinician. This metric more closely resembles how clinicians interpret the results in everyday practice. While a heat map that highlights unusual paths in the image would achieve high IoU with respect to the expert ROIs, a model that generates a heat map focusing only on the usual paths (i.e., paths typically drawn by healthy aging controls) would achieve lower IoU. Therefore, IoUs with the expert ROIs should also be used and interpreted with caution.
