Deep Convolutional Neural Networks on Multiclass Classification of Three-Dimensional Brain Images for Parkinson’s Disease Stage Prediction

Preprocessing

Normalization

Because the ranges of the pixel intensity values differed across patients, potentially affecting the efficiency of machine learning algorithms, we scaled each patient's values individually by using min–max normalization:

$$\mathbf{X}_{i\left(\text{norm}\right)}=\frac{\mathbf{X}_{i}-\min \left(\mathbf{X}_{i}\right)}{\max \left(\mathbf{X}_{i}\right)-\min \left(\mathbf{X}_{i}\right)}$$

where \(\mathbf{X}_{i}\) represents the pixel values in all slices of the \(i\)th patient, and \(\min \left(\mathbf{X}_{i}\right)\) and \(\max \left(\mathbf{X}_{i}\right)\) are the smallest and largest pixel values of \(\mathbf{X}_{i}\), respectively. After rescaling, all values ranged from 0 to 1, and the relative magnitude among patients could also be maintained.
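The rescaling can be sketched as follows, assuming each patient's volume is stored as a NumPy array of shape (slices, height, width); the function name and array layout are illustrative rather than a description of the exact implementation:

```python
import numpy as np

def minmax_normalize(volume: np.ndarray) -> np.ndarray:
    """Rescale one patient's SPECT volume (slices, H, W) to the range [0, 1]."""
    vmin, vmax = volume.min(), volume.max()
    return (volume - vmin) / (vmax - vmin)
```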

Slice Selection

For the Chang Gung data set, physicians had selected, during diagnosis, the slices in which the striatum could be recognized most clearly. The data set contained the indices of the selected slices provided by the physicians, with nine slices selected for each patient. For the E-Da data set, because physicians did not indicate which slices they had considered, we used the slices selected in our previous work [20], in which the single slice containing the clearest striatum shape was selected for each patient.

We developed an automatic process to filter out all incomplete slices, that is, slices in which parts of the image were masked or the pixel values were all very close to zero. These slices provided little or no information and might have disturbed the training process. Although the slice size in both data sets was \(128\times 128\) pixels, the proportion of each image occupied by the brain differed: the proportions in the Chang Gung data set were larger than those in the E-Da data set (Fig. 3). We thus removed slices in which fewer than 800 (Chang Gung) or 400 (E-Da) pixels had a normalized intensity greater than 0.1.
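A sketch of this filter is shown below; the helper name and array layout are illustrative:

```python
import numpy as np

def keep_informative_slices(volume: np.ndarray, min_bright_pixels: int) -> np.ndarray:
    """Drop slices with too few bright pixels.

    volume: normalized array of shape (slices, 128, 128); min_bright_pixels
    is 800 for the Chang Gung data set and 400 for the E-Da data set.
    """
    bright_counts = (volume > 0.1).reshape(volume.shape[0], -1).sum(axis=1)
    return volume[bright_counts >= min_bright_pixels]
```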

Other Information

Epidemiological studies have indicated the existence of sex differences in PD: the incidence and prevalence of PD in men are higher than those in women [27]. PD is also related to age: the incidence rates rise rapidly after the age of 60 years [28]. Moreover, the male:female incidence rate ratio increases with age in general. For the aforementioned reasons, age and sex were considered during the training process.

Because ages in this data set were positive numbers in years, they needed to be normalized to the range \(\left[0, 1\right]\) so that their scale matched that of the pixel values of the slices. The normalized age is expressed as

$$a_{i,\text{norm}}=a_{i}/100,\quad a_{i}\in \left[0, 100\right]$$

where \(a_{i}\) represents the age in years of the \(i\)th patient.

Sex only took two values, male or female; thus, it was represented by a dummy variable defined as

$$G=\left\{\begin{array}{ll}1& \text{if male}\\ 0& \text{if female}\end{array}\right.$$

Augmentation

Because the two data sets were both imbalanced and too small for deep learning models, overfitting was a concern; that is, the model may fit the training data excessively, consequently leading to poor performance when predicting new data. To avoid this problem, the most common approach for image classification tasks is to perform data augmentation, which provides new data by making minor alterations to the existing dataset. The augmented data were generated using video transform and depth transform, as described in the following section. We used online augmentation; that is, image augmentation was conducted in mini-batches before feeding data to the model. The model with online augmentation was presented with different images at each epoch; this aided the generalization ability of the model, and because the model did not need to save the augmented images on the disk, the computational burden was reduced.

Video Transform

Because our data were 3D SPECT images, which can be regarded as a series of slices (2D images), we adopted video transforms to ensure that the same random parameters (e.g., crop size, rotation angle, flip or not) would be applied to all the slices of each patient (a code sketch follows the list):

- Random rotation (degrees = 5): rotate the image by a random angle between \(-5^{\circ }\) and \(5^{\circ }\)

- Center crop (size = \(\left(72, 72\right)\)): crop the given image at the center, with the desired output size being \(72\times 72\)

- Resize (size = \(\left(72, 72\right)\)): resize images to \(72\times 72\)
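The shared-parameter behavior can be sketched with torchvision's functional API; the exact transform implementation is not specified above, so the following is illustrative:

```python
import random
import torch
import torchvision.transforms.functional as TF

def video_transform(slices: torch.Tensor) -> torch.Tensor:
    """Apply identical random parameters to all slices of one patient.

    slices: tensor of shape (n_slices, C, H, W). Sampling one angle and
    transforming the whole stack guarantees the same parameters per patient.
    """
    angle = random.uniform(-5.0, 5.0)          # shared rotation angle
    x = TF.rotate(slices, angle)
    x = TF.center_crop(x, [72, 72])
    x = TF.resize(x, [72, 72])
    return x
```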

Depth Transform

Video transform yielded images with a size of \(72\times 72\); however, the number of slices for each patient was different. We applied the trilinear interpolation approach [29] to construct new slices to unify the depth dimension of all patients’ images.

We set the target slice number to 32 because most of the samples in our data sets had fewer than 32 slices, and thus we would not lose too much information. After we applied trilinear interpolation, the data size of each patient was unified as \(32\times 72\times 72\).
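This depth unification can be sketched with PyTorch's trilinear interpolation; the tensor layout and function name are illustrative:

```python
import torch
import torch.nn.functional as F

def unify_depth(volume: torch.Tensor, target_depth: int = 32) -> torch.Tensor:
    """Resample a (C, D, H, W) volume to (C, 32, 72, 72) with trilinear
    interpolation so every patient has the same number of slices."""
    x = volume.unsqueeze(0)                          # (1, C, D, H, W)
    x = F.interpolate(x, size=(target_depth, 72, 72),
                      mode="trilinear", align_corners=False)
    return x.squeeze(0)
```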

Multiclass Classification

The main objective of this study was to develop a valid model that yields accurate predictions of PD illness stages. Because each sample can belong to only one of \(C\) PD stages (including healthy) (\(C=4\) for the Chang Gung data set and \(C=6\) for the E-Da data set), this was a multiclass classification task. We constructed a deep learning model that mapped the \(i\)th image \(\mathbf{X}_{i}\) to a \(C\)-dimensional label vector \(\mathbf{y}_{i}=\left(y_{i1},y_{i2},\cdots ,y_{iC}\right)\) with \(y_{ic}\in \left\{0, 1\right\}\) for all \(c\) and \(y_{i1}+\cdots +y_{iC}=1\). Therefore, the categorical cross-entropy can be used as the loss function:

$$\mathcal{L}_{\text{CE}}=\sum_{i}\left\{\sum_{c=1}^{C}\left[-y_{ic}\log \left(s\left(f_{c}\left(\mathbf{X}_{i}\right)\right)\right)\right]\right\}$$

where \(f_{c}\left(\mathbf{X}_{i}\right)\) is the \(c\)th logit produced for \(\mathbf{X}_{i}\) by the final fully connected layer and \(s\left(z_{c}\right)=e^{z_{c}}/\sum_{j=1}^{C}e^{z_{j}}\) is the softmax function.

Our data sets were highly imbalanced. Most machine learning algorithms for classification problems operate under the assumption of an equal number of samples in each class. In an imbalanced data set, learning can be biased toward the majority classes and fail to capture the patterns of the minority classes. A popular method to address imbalanced data is to set class weights in the loss function. We defined our class weights as

$$w_{c}=\frac{r_{c}}{\sum_{j=1}^{C}r_{j}}\quad \text{with}\quad r_{c}=\frac{n}{n_{c}},\quad c=1,\cdots ,C$$

where \(n\) represents the total number of samples, and \(n_{c}\) is the number of samples in the \(c\)th class (PD stage). The loss function with class weights thus becomes

$$\mathcal{L}_{\text{WCE}}=\sum_{i}\left\{\sum_{c=1}^{C}\left[-w_{c}y_{ic}\log \left(s\left(f_{c}\left(\mathbf{X}_{i}\right)\right)\right)\right]\right\}$$
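The class weighting and the weighted loss can be sketched in PyTorch as follows; the class counts shown are hypothetical, and reduction="sum" is chosen to match the summed form above:

```python
import torch
import torch.nn as nn

# Hypothetical class counts for a C = 4 data set (n_c = samples in class c).
class_counts = torch.tensor([160.0, 40.0, 25.0, 15.0])
n = class_counts.sum()
r = n / class_counts                       # r_c = n / n_c
weights = r / r.sum()                      # w_c = r_c / sum_j r_j

# CrossEntropyLoss applies the softmax internally; the per-sample term is
# -w_c * log(softmax(logits)_c) for the true class c.
criterion = nn.CrossEntropyLoss(weight=weights, reduction="sum")

logits = torch.randn(8, 4)                 # f(X_i) for a mini-batch of 8 samples
targets = torch.randint(0, 4, (8,))        # integer-coded PD stages
loss = criterion(logits, targets)
```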

Deep CNN Models

Image classification using CNN has demonstrated outstanding performances compared with traditional machine learning approaches. CNNs are designed to automatically and adaptively learn spatial hierarchies of features through back-propagation over multiple stacked building blocks, such as convolution layers, non-linear layers (activations), pooling layers, and fully connected layers. CNN is a deep learning method designed to process structured arrays, and it has become dominant in various computer vision tasks. In addition to their application to classification, segmentation, and recognition problems related to image or video, CNNs have also been applied to natural language processing.

This study focuses on the multiclass classification of 3D SPECT images. Traditionally, 3D images are analyzed using either 2D CNNs on individual slices (2D models) or 3D CNNs on the entire volume (3D models). We introduce an architecture that uses an attention mechanism to consider the relationships among slices, allowing each slice to contribute differently to the prediction (slice-relation-based models). In addition to training on the Chang Gung and E-Da data sets separately with various model architectures, we propose cotraining on the two data sets simultaneously by using weight sharing to enhance model effectiveness and robustness.

Transfer learning is the process of creating new models by fine-tuning previously trained networks, where a model trained for one task is utilized as the initial value of the model for another related task. Transfer learning has been widely applied to diverse tasks. One of its strengths is particularly helpful for our task: transfer learning reduces the need and effort to recollect a large amount of training data, thus addressing the challenge of training a full-scale model from scratch or with little data. In our transfer learning training process, we used 2D and 3D model architectures trained on ImageNet [30] and Kinetics-400 [31], respectively, as the pretrained model, froze the weights of the first few layers, set the remaining layers as trainable, and then replaced the final few layers with customized layers. Given that ImageNet and Kinetics-400 were trained on natural RGB images, some studies demonstrated that, in different medical imaging applications, CNNs pretrained on these large-scale natural image data sets performed better and were more robust to the size of the target training data than the CNNs trained from scratch [32, 33].

2D Models

3D SPECT data can be regarded as a series of 2D images and thus can be analyzed by 2D model architectures. We selected the VGG-16 architecture pretrained on ImageNet as our pretrained model, and we used its convolution layers connected with some customized layers to form our training model.

Because the images in ImageNet were RGB images with three channels, we copied each slice three times to generate the three-channel input from our gray-scale medical images, where each input was formed by a set of three repeated slices. All slices passed through shared VGG-16 convolutional layers to obtain image representations; these outputs were then “summarized” through some customized layers before proceeding to the final fully connected (FC) layers. We discarded the last max-pooling layer in the convolution layers of the original VGG-16 architecture to ensure that the output shape would not be too small. The different customized layers we used are described in the following sections.

VGG plus linear (Linear): The first model was the simplest one and also the baseline model we used for comparison. We added an adaptive average pooling layer (with target output size \(1\times 1\)) to features with size \((512, 4, 4)\) and obtained outputs of size \((512, 1, 1)\). After the outputs from all slices were averaged, a multilayer perceptron (MLP) with rectified linear unit (ReLU) activation was used to reduce the output dimension from 512 to 16. Age and sex were concatenated with the output features if necessary and then fed into the FC layers. In this manner, the effects of age and sex would be appropriately considered. The model architecture is presented in Fig. 4a.
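A sketch of this baseline is shown below, assuming torchvision ≥ 0.13; the layer sizes and class name are illustrative and may differ in detail from the customized layers described above:

```python
import torch
import torch.nn as nn
from torchvision.models import vgg16

class VGGLinear(nn.Module):
    """Sketch of the 'Linear' baseline: shared VGG-16 features per slice,
    averaged over slices, then an MLP with age/sex appended before the FC."""
    def __init__(self, num_classes: int, use_age_sex: bool = True):
        super().__init__()
        backbone = vgg16(weights="IMAGENET1K_V1").features
        self.features = backbone[:-1]             # drop the last max-pooling layer
        self.pool = nn.AdaptiveAvgPool2d(1)       # (512, 4, 4) -> (512, 1, 1)
        self.mlp = nn.Sequential(nn.Linear(512, 16), nn.ReLU())
        extra = 2 if use_age_sex else 0
        self.fc = nn.Linear(16 + extra, num_classes)

    def forward(self, slices, age_sex=None):
        # slices: (batch, n_slices, 3, 72, 72); gray slices repeated to 3 channels
        b, m = slices.shape[:2]
        x = self.features(slices.flatten(0, 1))      # (b*m, 512, 4, 4)
        x = self.pool(x).flatten(1).view(b, m, 512)  # per-slice features
        x = self.mlp(x.mean(dim=1))                  # average over slices
        if age_sex is not None:
            x = torch.cat([x, age_sex], dim=1)       # append normalized age and sex
        return self.fc(x)
```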

Fig. 4

2D model architecture. The yellow block represents the modified VGG-16 architecture we used. In c, “Slices” represents the collection of all slice sets from the same person. a Linear. b Conv2D. c ACS

VGG plus Conv2D (Conv2D): For more advanced processing, we replaced the average pooling layer in the aforementioned Linear model architecture with a 2D convolutional layer. This enabled us to extract more realistic features. This model architecture is presented in Fig. 4b.

Axial–coronal–sagittal convolutions (ACS): Because any of the three views (axial, coronal, and sagittal) of our 3D data can be regarded as 2D images, ACS convolutions [25] were utilized in this model. In ACS convolutions, 2D kernels are split by channel into three parts and convolved separately over the axial, coronal, and sagittal views of the 3D input; as a result, weights pretrained on large 2D data sets can still be used [34]. The output features of the three views were concatenated after an average pooling layer to form the overall output feature, which finally went through the FC layers (Fig. 4c).

3D Models

Because we were working with 3D medical images, the application of 3D models was also appropriate. Here, we employed the 3D models for action recognition and video classification, as in Tran et al. [23], based on the ResNet-18 architecture. These models were all pretrained on the Kinetics-400 dataset [31], containing 306,245 short-trimmed, realistic action clips from 400 action categories.

Notably, the input shape for these 3D models becomes \((C, T, H, W)\), where \(C\) is the channel and \(T\) is the number of video frames in a clip. All of our slices were gray-scale with one channel, whereas each frame of a clip in Kinetics-400 was an RGB image with three channels. For data set compatibility, each of our slices was repeated three times, and then video transform and depth transform were applied to ensure that all the inputs had the same shape.

The overall architectures of the following 3D models were similar: a 3D pretrained model was followed by an average pooling layer and some MLPs. However, different pretrained models were used for each 3D model.
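This shared structure can be sketched as follows, assuming torchvision ≥ 0.13; the backbone r3d_18 can be swapped for mc3_18 or r2plus1d_18 (also in torchvision.models.video) to obtain the MC3 and R(2 + 1)D variants, and the freezing depth and head sizes are illustrative:

```python
import torch
import torch.nn as nn
from torchvision.models.video import r3d_18   # also mc3_18, r2plus1d_18

class R3DClassifier(nn.Module):
    """Sketch of a 3D model: Kinetics-400 pretrained backbone plus a new head.
    Inputs are assumed to have shape (batch, 3, 32, 72, 72)."""
    def __init__(self, num_classes: int):
        super().__init__()
        backbone = r3d_18(weights="KINETICS400_V1")
        self.features = nn.Sequential(*list(backbone.children())[:-1])  # drop fc
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(512, num_classes))
        # Freeze the first few layers; only the later layers stay trainable.
        for p in self.features[:2].parameters():
            p.requires_grad = False

    def forward(self, x):                      # x: (B, C=3, T=32, H=72, W=72)
        return self.head(self.features(x))
```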

3D ResNet (R3D): Unlike 2D CNNs, 3D CNNs can preserve and capture temporal (or, in our case, positional) information because the filters are convolved over both the time and space dimensions. In this model, the 2D convolutions were simply replaced with 3D convolutions, which have an extra temporal extent on the filter. The overall model architecture is illustrated in Fig. 5a.

Fig. 5

3D model architecture. a R3D. b MC3. c R(2 + 1)D

Mixed convolutions ResNet (MC3): On the basis of the hypothesis that temporal modeling with 3D convolutions would be useful only in the early layers and would be unnecessary in the late layers because the higher-level features are abstract [23], the mixed 2D–3D convolutions were used. Considering that the R3D model had five groups of convolutions, we applied the MC3 model where the 3D convolutions in groups 3, 4, and 5 were replaced with 2D convolutions. The architecture is presented in Fig. 5b.

(2 + 1)D ResNet (R(2 + 1)D): A 3D convolution can be approximated by a spatial 2D convolution followed by a temporal one-dimensional (1D) convolution, so that the model captures spatial and temporal information in two separate steps. This decomposition facilitates optimization and doubles the number of nonlinear transformations (because of the additional activations between the 2D and 1D convolutions) without increasing the number of parameters. In this model, (2 + 1)D convolutions are substituted for the 3D convolutions; the architecture is illustrated in Fig. 5c.

Slice-Relation-Based Models

We initially assumed that all the slices from one patient contributed equally to the training of the 2D models because we merely used the average of their convolved features. However, during diagnosis, physicians focus only on slices whose striata are clear enough to be recognized; thus, treating all the slices of one patient equally is not ideal. We therefore also considered the relations among slices to help the models learn the differences in slice contributions.

Index embedding (IdxEmb): First, the slices of one patient were numbered sequentially according to the order of their entry into the model; this number is called the slice index. The slice index was regarded as a categorical variable \(x\) and then mapped to a vector \(\mathbf{e}\) with a predefined dimension through entity embedding [35]. Entity embedding operates a linear layer on the one-hot encoding of \(x\):

$$\mathbf{e}=\mathbf{W}^{\top }\mathbf{u}_{x}$$

where \(m\) is the number of possible values of the categorical variable \(x\), \(\mathbf{u}_{x}\) is a vector of length \(m\) whose \(i\)th element is 1 if \(i=x\) and 0 otherwise, \(i=1,\cdots ,m\), \(\mathbf{W}=\left\{w_{ij}\right\}\in \mathbb{R}^{m\times k}\) is the weight matrix that connects the one-hot encoding with the entity embedding, and \(k\) is the dimension of the entity embedding. \(\mathbf{W}\) can be regarded as the weights of the linear layer, which are trainable through standard backpropagation. The embedded index \(\mathbf{e}\) was reshaped, added to the convolved features of the corresponding slice, and passed together with them through a 2D convolution. The output features of these slices were then gathered, averaged, and finally fed into a classifier. The whole model architecture is presented in Fig. 6a.
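Because entity embedding is equivalent to a trainable linear map applied to a one-hot encoding, it can be sketched with a standard embedding layer; the dimensions shown are illustrative:

```python
import torch
import torch.nn as nn

m, k = 32, 16                       # number of slice positions, embedding dimension
embed = nn.Embedding(num_embeddings=m, embedding_dim=k)  # trainable W of shape (m, k)

slice_index = torch.arange(m)       # indices 0..m-1 for one patient's slices
e = embed(slice_index)              # (m, k) embedded indices
# e would then be reshaped and added to the corresponding slice features.
```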

Fig. 6

Slice-relation-based model architecture. a IdxEmb. b Attn. c MH-Attn

Attention (Attn): Adding information about slice position through index embedding seemed appropriate, but the image registration problem proved challenging: we would need the absolute position of each slice. Unfortunately, our data were functional images whose positional information was inaccurate; for example, slice index 1 might correspond to different brain areas for different patients. To avoid the image registration problem, we considered only the relative importance of the slices within one patient, an approach motivated by the attention mechanism [36, 37]. Because our attention mechanism considered only the input images of a patient, with no extra output information, and aimed to capture the internal correlation of slices within a patient, we were in fact implementing self-attention [38, 39].

Overall, we adopted the 2D model but, instead of taking the average, we used a weighted sum (Fig. 6b). The weighted sum was computed as follows. For one patient, each VGG-16 convolution output feature \(\mathbf{F}_{j}\) went through a convolution layer with ReLU activation and yielded a scalar output \(o_{j}\). We applied the softmax function to the outputs \(o_{1},\cdots ,o_{m}\) to form the weight for each output feature \(\mathbf{F}_{j}\): \(w_{j}=e^{o_{j}}/\sum_{l=1}^{m}e^{o_{l}}\), where \(m\) is the number of slice sets. Subsequently, the input of the following MLP (with ReLU activation) can be represented as \(\mathbf{F}=\sum_{j=1}^{m}w_{j}\mathbf{F}_{j}\). In this manner, the model allowed each patient to have a different weighting scheme (i.e., different \(w_{1},\cdots ,w_{m}\)) with its own important slices.
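A sketch of this slice-weighting step is given below; the scoring layer and shapes are illustrative, and only the softmax-weighted sum is essential:

```python
import torch
import torch.nn as nn

class SliceAttention(nn.Module):
    """Score each slice feature, softmax over slices, return the weighted sum."""
    def __init__(self, channels: int = 512):
        super().__init__()
        # 1x1 convolution + ReLU producing one score map per slice feature
        self.score = nn.Sequential(nn.Conv2d(channels, 1, kernel_size=1), nn.ReLU())

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (m, 512, h, w) -- convolutional features of one patient's slices
        o = self.score(feats).mean(dim=(1, 2, 3))        # scalar score o_j per slice
        w = torch.softmax(o, dim=0)                      # weights w_1..w_m
        return (w.view(-1, 1, 1, 1) * feats).sum(dim=0)  # weighted sum F
```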

Multihead attention (MH-Attn): The aforementioned attention mechanism focuses on a single aspect reflected by the image slices of a patient. However, multiple aspects in these image slices together form the overall structure of the 3D image. Multihead attention allows the model to jointly attend to information from different aspect subspaces for different slices. Here, we adopted the multihead attention formulation of Vaswani et al. [39].

After passing through the VGG-16 convolutional layers and an average pooling layer, each output feature \(\mathbf{F}_{j}\) (with dimension 512) was multiplied separately by three (randomly initialized) weight matrices: the query weights \(\mathbf{W}^{Q}\), the key weights \(\mathbf{W}^{K}\), and the value weights \(\mathbf{W}^{V}\). These multiplications yielded \(\mathbf{q}_{j}\) (query), \(\mathbf{k}_{j}\) (key), and \(\mathbf{v}_{j}\) (value). Thus,

$$\mathbf{Q}=\mathbf{F}\mathbf{W}^{Q},\quad \mathbf{K}=\mathbf{F}\mathbf{W}^{K},\quad \mathbf{V}=\mathbf{F}\mathbf{W}^{V}$$

where \(\mathbf{W}^{Q},\mathbf{W}^{K},\mathbf{W}^{V}\in \mathbb{R}^{512\times d_{k}}\) with \(d_{k}\) being the per-head dimension, \(\mathbf{F}={\left(\mathbf{F}_{1},\cdots ,\mathbf{F}_{m}\right)}^{\top }\), \(\mathbf{Q}={\left(\mathbf{q}_{1},\cdots ,\mathbf{q}_{m}\right)}^{\top }\), \(\mathbf{K}={\left(\mathbf{k}_{1},\cdots ,\mathbf{k}_{m}\right)}^{\top }\), and \(\mathbf{V}={\left(\mathbf{v}_{1},\cdots ,\mathbf{v}_{m}\right)}^{\top }\). Each query \(\mathbf{q}_{j}\) was multiplied by the key matrix \(\mathbf{K}\) to generate its attention scores. These attention scores went through a softmax function to create the weights for the weighted sum over \(\mathbf{v}_{1},\cdots ,\mathbf{v}_{m}\), and this weighted sum was the attention vector \(\mathbf{z}_{j}\). The aforementioned process can be expressed as

$$\text{Attention}\left(\mathbf{Q},\mathbf{K},\mathbf{V}\right)=\text{softmax}\left(\mathbf{Q}\mathbf{K}^{\top }\right)\mathbf{V}=\mathbf{Z}$$

where \(\mathbf{Z}={\left(\mathbf{z}_{1},\cdots ,\mathbf{z}_{m}\right)}^{\top }\). One set of \(\left(\mathbf{W}^{Q},\mathbf{W}^{K},\mathbf{W}^{V}\right)\) matrices is called an attention head. Multiple attention heads enable the model to learn different types of relevance among slices. We used four heads (based on our empirical results), and the independent outputs of these heads were concatenated and transformed back to the desired dimension of 512 through matrix multiplication. The multihead attention mechanism can be expressed as

$$\text{MultiHead}\left(\mathbf{Q},\mathbf{K},\mathbf{V}\right)=\text{Concat}\left(\mathbf{head}_{1},\cdots ,\mathbf{head}_{4}\right)\mathbf{W}^{O}$$

where

$$\mathbf{head}_{i}=\text{Attention}\left(\mathbf{F}\mathbf{W}_{i}^{Q},\mathbf{F}\mathbf{W}_{i}^{K},\mathbf{F}\mathbf{W}_{i}^{V}\right),\quad i=1,\cdots ,4$$

and \(\mathbf{W}^{O}\in \mathbb{R}^{4d_{k}\times 512}\). The output shape became \(\left(m, 512\right)\) after the application of this multihead attention mechanism. The sum of these \(m\) row vectors then served as the input to the following MLP (with ReLU activation). The overall model architecture is illustrated in Fig. 6c.
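A sketch using PyTorch's built-in multihead attention is shown below; the built-in module additionally scales the attention scores by \(1/\sqrt{d_{k}}\), so it is a close but not exact match to the formulation above, and the layer sizes are illustrative:

```python
import torch
import torch.nn as nn

class SliceMultiHeadAttention(nn.Module):
    """Multihead self-attention over pooled slice features, summed and classified."""
    def __init__(self, dim: int = 512, heads: int = 4, num_classes: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads,
                                          batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 16), nn.ReLU(),
                                 nn.Linear(16, num_classes))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (1, m, 512) -- pooled VGG features of one patient's m slices
        z, _ = self.attn(feats, feats, feats)   # self-attention: Q = K = V = F
        return self.mlp(z.sum(dim=1))           # sum of row vectors -> MLP
```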

Cotraining Models

The two data sets used in this study differ considerably in how they define PD illness stages: the Chang Gung data set includes four stages (including healthy cases), whereas the E-Da data set uses six stages. This discrepancy makes it challenging to train directly on one data set and validate on the other without additional alignment or adjustment.

We initially trained on the two data sets (Chang Gung and E-Da) separately with different model architectures. Although we obtained some favorable results, we did not yet leverage the high similarity between them: both consist of SPECT brain images and should therefore share some characteristics. Thus, we subsequently trained on the two data sets jointly by using the models introduced in the preceding sections. The schematic is illustrated in Fig. 7. The 2D and 3D pretrained models were still used, but the inputs in each batch were composed evenly of the two data sets. The shared model weights (i.e., the yellow block in Fig. 7) were first cotrained using inputs evenly distributed from both data sets, ensuring that the model could learn features applicable across data sets. The number of cotraining layers was a hyperparameter to be tuned. After this shared training phase, the remaining components of the model were trained separately for each data set, optimizing their respective loss functions to account for the differences in stage definitions. The loss \(\mathcal{L}\) can be expressed as

$$\mathcal{L}=\mathcal{L}_{\text{CG}}+\mathcal{L}_{\text{ED}}$$

where \(\mathcal{L}_{\text{CG}}\) and \(\mathcal{L}_{\text{ED}}\) refer to the weighted categorical cross-entropy losses of the Chang Gung and E-Da data sets, respectively. In this manner, the two data sets shared the weights of the cotraining part, which increased the robustness of the model because of the increased variation and number of training samples. The model could also retain the differentiation between the two data sets through the parts trained separately.
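The weight-sharing scheme can be sketched as a shared feature extractor with one classification head per data set; the module and variable names are illustrative:

```python
import torch
import torch.nn as nn

class CoTrainModel(nn.Module):
    """Shared (cotrained) feature extractor with a separate head per data set."""
    def __init__(self, shared: nn.Module, feat_dim: int = 512):
        super().__init__()
        self.shared = shared                      # weights shared by both data sets
        self.head_cg = nn.Linear(feat_dim, 4)     # Chang Gung: 4 stages
        self.head_ed = nn.Linear(feat_dim, 6)     # E-Da: 6 stages

    def forward(self, x, dataset: str):
        f = self.shared(x)
        return self.head_cg(f) if dataset == "cg" else self.head_ed(f)

# In each batch, half of the samples come from each data set and the two
# weighted cross-entropy losses are summed: L = L_CG + L_ED, e.g.,
# loss = criterion_cg(model(x_cg, "cg"), y_cg) + criterion_ed(model(x_ed, "ed"), y_ed)
```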

Fig. 7

Schematic of cotraining models

Evaluation

K-Fold Cross-Validation

To assess the generalization ability of the models and to avoid overfitting and selection bias, we adopted stratified cross-validation. The number of folds was set to be five, which meant the original samples were randomly partitioned into five groups, with each group containing approximately the same proportions of class labels. In each cross-validation round, one group was for testing (called the test set), and the other four groups were for training (called the training set); 20% of the training set served as the validation set.
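A sketch of this stratified five-fold split with scikit-learn is shown below; the arrays are placeholders:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

# Hypothetical labels: integer-coded PD stages for 100 patients.
y = np.random.randint(0, 4, size=100)
X = np.arange(100)                      # stand-in for patient volumes (indices)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in skf.split(X, y):
    # Hold out 20% of each training fold as the validation set, stratified by label.
    tr_idx, val_idx = train_test_split(train_idx, test_size=0.2,
                                       stratify=y[train_idx], random_state=0)
    # train on tr_idx, select hyperparameters on val_idx, evaluate on test_idx
```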

Metrics

We used two metrics, namely, accuracy and F1 score, to evaluate our model performance. Accuracy is defined as the proportion of samples that are correctly classified. This evaluation metric is commonly used in classification tasks. However, a problem with imbalanced data sets is that model accuracy can be judged as high even if it only correctly predicts the majority class. Because our data sets were imbalanced, we employed another metric to evaluate model performance: F1 score. For each class, we are interested in the fraction of relevant samples among the retrieved samples (i.e., precision) and the fraction of retrieved samples among the relevant samples (i.e., recall). The F1 score is the harmonic mean of the precision and recall, with its optimal and poorest values being 1 and 0, respectively. In our task, we first calculated the F1 score for each class and then took the average of these F1 scores, called macro F1 score, as our alternative evaluation metric.
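Both metrics can be computed directly with scikit-learn; the labels shown are hypothetical:

```python
from sklearn.metrics import accuracy_score, f1_score

# Hypothetical true and predicted stages for a 4-class task.
y_true = [0, 0, 1, 2, 3, 3, 2, 1]
y_pred = [0, 1, 1, 2, 3, 2, 2, 1]

accuracy = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average="macro")  # per-class F1, then averaged
print(f"accuracy={accuracy:.3f}, macro F1={macro_f1:.3f}")
```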
