EEG decoding for effects of visual joint attention training on ASD patients with interpretable and lightweight convolutional neural network

For a better reading experience, Table 1 lists the abbreviations and parameter symbols used in this paper.

Table 1 Meanings and values of the abbreviations and parameter symbols

Participants

Participants included 15 adolescents and adults with ASD, ranging in age from 16 to 38 years (mean age 22.2 years; Full-Scale Intelligence Quotient (FSIQ): mean = 102.53, SD = 11.64). Inclusion criteria for ASD participants were: (1) a positive diagnosis of ASD assigned according to gold-standard instruments; (2) a parental or caregiver interview (Autism Diagnostic Interview); (3) a direct structured assessment of the subject (Autism Diagnostic Observation Schedule) (Lord et al. 1999); and (4) fulfilment of the current diagnostic criteria for ASD according to the Diagnostic and Statistical Manual of Mental Disorders, Fifth Edition (DSM-5) (Edition 2013).

Participants were excluded if they had an intellectual disability (FSIQ below 80) (Wechsler 2008), associated medical conditions such as epilepsy, neurocutaneous or other known genetic syndromes, or other comorbidities commonly found in ASD samples (Amaral et al. 2018).

Experiment setup

The virtual scene in the experiment contained a bedroom with objects (door, window, lamp, soccer ball, book, radio, printer, and laptop), furniture (bed, table, chair, shelf, and dresser), and one avatar. The BCI task was divided into three phases: the first two were offline and formed part of the BCI calibration process, and the last one was online.

In the first phase, participants were informed of the target object directly, to eliminate potential errors associated with the social attention deficits of ASD. In the second phase, participants were asked which object the avatar had selected after it performed an action, to ensure that participants could read the cue correctly and were able to use this information correctly. In the third phase, participants were asked to respond to a cue from the avatar's head; that is, after focusing their attention during a trial, they were asked to imagine the avatar blinking at the stimulus target. The first two phases had non-social characteristics, while the third phase had social characteristics.

In this experiment, all participants were required to complete the same training session 7 times over a period of 4 months: the first four sessions on a weekly basis and the last three on a monthly basis. Each session contained 20 blocks, each block consisted of 10 runs, and each run contained 8 trials in the virtual scene. During the experiment, a non-target or target stimulus was flashed once at a fixed interval of 200 ms, and each flash lasted 100 ms. The experimental flow of one session is shown in Fig. 1. The EEG data acquired in session 1 were taken as the data before joint attention training, and the EEG data acquired in session 7 were taken as the data after joint attention training.

Fig. 1 Diagram of the experimental composition of one session

Each block had one fixed target object (e.g., the printer on the shelf, the laptop on the table, the soccer ball on the floor, or the photo on the wall) and a total of 80 trials. The first 20 trials were calibration trials, which recorded EEG data including the P300 response generated when the target object flashed. Statistical classifiers were trained to identify the P300. These classifiers were used in the online phase to identify whether the participant was responding to the flash of the object in which the avatar was interested. If the participant completed the action correctly, the BCI gave positive feedback (the object of interest turned green at the end of the trial); if not, the target object turned red.

Data acquisition and preprocessing

Experiments were conducted using a P300-based virtual reality BCI paradigm. Immersive virtual scenes were presented to participants via the Oculus Rift Development Kit 2 headset, and participants were required to wear both the headset and a 16-electrode wireless EEG cap during the experiment. EEG data were recorded from 8 electrode positions (C3, Cz, C4, CPz, P3, Pz, P4, POz). The reference and ground electrodes were placed at the right ear and AFz, respectively, and the sampling rate was set to 250 Hz. To improve the signal-to-noise ratio, a 50 Hz notch filter was applied during acquisition of the EEG data. Then, the EEG data were band-pass filtered from 2 to 30 Hz, followed by segmentation of the data from 200 ms before stimulus onset to 1200 ms after stimulus onset.

Since the interval between two stimuli in the experimental design was only 300 ms, the data from the two trials immediately before and after the appearance of the target stimulus were excluded in this study, to avoid contamination from the P300 component elicited by the target stimulus. The labels in the data were defined as “1” for the target stimulus and “0” for the standard stimulus. The EEG data were also detrended, and epochs were extracted from 100 ms before stimulus onset to 700 ms after stimulus onset, so that the shape of the processed data was 8 × 201 (channels × time points).
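As a concrete illustration, the preprocessing described above can be sketched in Python with NumPy and SciPy. Only the 2–30 Hz passband, the 250 Hz sampling rate, the detrending, and the −100 ms to +700 ms window come from the text; the Butterworth filter and its order are assumptions.

```python
import numpy as np
from scipy.signal import butter, filtfilt, detrend

FS = 250  # sampling rate in Hz (from the acquisition setup)

def preprocess_trial(raw_eeg, stim_sample):
    """Band-pass filter (2-30 Hz), epoch -100..+700 ms around stimulus onset,
    and detrend. `raw_eeg` is an (n_channels, n_samples) array; the returned
    epoch has shape (8, 201) for 8 channels at 250 Hz."""
    # 4th-order Butterworth band-pass; the filter type and order are assumptions
    b, a = butter(4, [2.0, 30.0], btype="bandpass", fs=FS)
    filtered = filtfilt(b, a, raw_eeg, axis=-1)

    start = stim_sample - int(0.1 * FS)      # 25 samples before onset
    stop = stim_sample + int(0.7 * FS) + 1   # 175 samples after onset, inclusive
    epoch = filtered[:, start:stop]          # shape (8, 201)
    return detrend(epoch, axis=-1)           # remove linear trend per channel
```

Applying this to each flash yields the 8 × 201 inputs used by the network below.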

Model structure

To ensure the accuracy of decoding the P300 signal contained in the EEG of ASD patients, this study used EEGNet, which was validated in a previous study, as the classifier (Lawhern et al. 2018). EEGNet is a lightweight neural network with strong generalization capability for EEG data processing. Figure 2 shows the model architecture of EEGNet, and Table 2 describes the parameters of each layer of the network model applied in the current study.

Fig. 2 Architecture of EEGNet in the present study. White boxes display the information of the corresponding layers. The number before “@” denotes the number of filters in each layer, and the size after “@” denotes the input data size for each layer. Rectangles indicate convolutions, and circles indicate the result after flattening. Different sizes represent different input sizes of these layers, and different colors represent different convolution kernel parameter settings

Table 2 Detailed parameters for EEGNet in the present study

In applying EEGNet, it was necessary to arrange the EEG data as a two-dimensional array, where channels (8) and time points (201) were represented by the rows and columns, respectively, so that the shape of the input data was 8 × 201. EEGNet had two main blocks (refer to Table 2), with 8 temporal filters and 16 spatial filters. The first block used convolutional filters with a kernel size of 1 × 64 to generate temporal feature maps. Based on these, spatial features were learned using depthwise convolution (Zhang et al. 2019) with a kernel size of 8 × 1 and a depth multiplier of 2, so that two spatial filters were learned per temporal feature map and the total number of spatial filters (16) was twice the number of temporal filters (8). After applying the temporal and spatial filters, the output features were normalized, and the exponential linear unit (ELU) (Clevert et al. 2015) was used as the activation function. The ELU output distribution is zero-mean, which speeds up training, and it saturates on the negative side, which aids convergence. It is expressed as Eq. 1:

$$ f(x) = \begin{cases} x & \text{if } x \ge 0 \\ \alpha\left(e^{x} - 1\right) & \text{if } x < 0 \end{cases} $$

(1)

To reduce the dimensionality of the features, an average pooling layer with a stride of 4 and a kernel size of 1 × 4 was used. Also, to prevent overfitting, dropout (Srivastava et al. 2014) was used to deactivate neurons with a certain probability. After the above steps, the shape of the output features of the first block was 16 × 1 × 50.

In the second block, 16 filters with a separable convolution of kernel size 1 × 16 were used to learn depthwise temporal features of the EEG signal. Since the separable convolution had fewer parameters than an ordinary convolution, the model was less prone to overfitting. Again, the ELU activation function was used after batch normalization; the features were then downsampled using an average pooling layer with a stride of 8 and a kernel size of 1 × 8, and dropout was used to alleviate overfitting during training. Finally, a dense layer with a softmax activation function was used to classify the data. Softmax maps the outputs to the interval [0, 1] and constructs them as a probability distribution; the larger the softmax value for a class, the greater the estimated probability of that class. The formula for the softmax function is given below (Eq. 2):

$$ f(x_{i}) = \frac{e^{x_{i}}}{\sum_{j = 1}^{n} e^{x_{j}}} \quad \text{for } i = 1, \ldots, n $$

(2)
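For concreteness, the two blocks described above can be sketched in Keras as follows. The kernel sizes, filter counts, pooling sizes, and output shapes follow the text and Table 2; the dropout probability and the max-norm constraints are not stated here and are taken from Lawhern et al. (2018) as assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.constraints import max_norm

def build_eegnet(n_channels=8, n_samples=201, n_classes=2, dropout_rate=0.5):
    """EEGNet sketch: 8 temporal filters, 16 spatial filters (depth multiplier 2)."""
    inp = layers.Input(shape=(n_channels, n_samples, 1))

    # Block 1: temporal convolution, then depthwise spatial convolution
    x = layers.Conv2D(8, (1, 64), padding="same", use_bias=False)(inp)
    x = layers.BatchNormalization()(x)
    x = layers.DepthwiseConv2D((n_channels, 1), depth_multiplier=2,
                               use_bias=False,
                               depthwise_constraint=max_norm(1.0))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("elu")(x)
    x = layers.AveragePooling2D((1, 4))(x)     # -> (1, 50, 16), i.e. 16 x 1 x 50
    x = layers.Dropout(dropout_rate)(x)

    # Block 2: separable convolution
    x = layers.SeparableConv2D(16, (1, 16), padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("elu")(x)
    x = layers.AveragePooling2D((1, 8))(x)     # -> (1, 6, 16)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Flatten()(x)
    out = layers.Dense(n_classes, activation="softmax",
                       kernel_constraint=max_norm(0.25))(x)
    return models.Model(inp, out)
```

Note that Keras reports shapes in channels-last order, so the (1, 50, 16) tensor after the first block corresponds to the 16 × 1 × 50 output stated in the text.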

In this paper, EEGNet was implemented using the TensorFlow framework, and Adam was chosen as the optimizer. The learning rate was set to 1.25 × 10^−4, with β1 = 0.9 and β2 = 0.999. The batch size for mini-batch gradient descent was set to 128 examples, the number of iterations was set to 300, and the network was trained with an early-stopping strategy; the network parameters were learned by backpropagation. Due to the imbalanced numbers of positive and negative samples, we used focal loss as the loss function for training. The model was trained using the leave-one-subject-out method on an NVIDIA GTX 1080: each subject's data were used in turn for testing, while the remaining subjects' data formed the corresponding training set. With a total of 15 subjects in this study, 15 rounds of training and testing were required to obtain the classification accuracy on each subject's data. Finally, the mean and standard deviation of the classification accuracy were computed over the 15 subjects' test results.
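A minimal sketch of this training configuration follows, reusing the `build_eegnet` helper above and the `binary_focal_loss` helper sketched in the Loss function section below. The `load_subject_data` loaders are hypothetical, and the early-stopping patience and validation split are assumptions, since the text does not specify them.

```python
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

N_SUBJECTS = 15
accuracies = []

for test_subject in range(N_SUBJECTS):            # leave-one-subject-out
    # load_subject_data is a hypothetical helper returning (x, y) arrays
    x_test, y_test = load_subject_data(test_subject)
    x_train, y_train = load_subject_data(exclude=test_subject)

    model = build_eegnet()                        # re-initialize for each fold
    model.compile(
        optimizer=Adam(learning_rate=1.25e-4, beta_1=0.9, beta_2=0.999),
        loss=binary_focal_loss(gamma=2.0),        # sketched in "Loss function" below
        metrics=["accuracy"],
    )
    model.fit(
        x_train, y_train,
        batch_size=128, epochs=300,               # 300 iterations per the text
        validation_split=0.1,                     # assumption: split for early stopping
        callbacks=[EarlyStopping(patience=20, restore_best_weights=True)],  # patience assumed
        verbose=0,
    )
    accuracies.append(model.evaluate(x_test, y_test, verbose=0)[1])

print(f"mean accuracy: {np.mean(accuracies):.3f} +/- {np.std(accuracies):.3f}")
```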

Loss function

Since the experiment was designed using the oddball paradigm, the dataset obtained in this experiment had an imbalance of positive and negative samples [positive samples (target) : negative samples (non-target) = 1:7]. Because of this large imbalance, it was difficult for the neural network to learn features from the small number of positive samples, so positive samples were frequently misclassified and the classification accuracy was reduced. Focal loss (FL) was proposed to solve exactly this sample-imbalance problem (Lin et al. 2017).

Focal loss was used in the classification task of the present study. The purpose of focal loss is to solve the problem of data imbalance between positive and negative samples, which it does by modifying the cross-entropy (CE) loss. We introduce focal loss starting from the CE loss for binary classification (Eq. 3):

$$ \text{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise} \end{cases} $$

(3)

where \(y \in \{0, 1\}\) denoted negative and positive samples and \(p \in [0, 1]\) denoted the probability estimated by the model for the class with label y = 1. For notational convenience, we defined \(p_{t}\) as Eq. 4:

$$ p_{t} = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise,} \end{cases} $$

(4)

so that the CE loss could be rewritten compactly as Eq. 5:

$$ \text{CE}(p, y) = \text{CE}(p_{t}) = -\log(p_{t}) $$

(5)

In the process of model training, the cross-entropy function weighted every sample equally. However, the numbers of samples in the categories were not balanced, so the category with more samples occupied most of the loss and dominated the direction of model optimization. Eventually, the classification performance on the hard-to-classify samples was not optimized.

To solve the above problem, a modulation factor \((1 - p_{t})^{\gamma}\) was added in front of the CE loss, where γ was called the tunable focusing parameter, with γ ≥ 0. The focal loss was therefore defined as Eq. 6:

$$ \text{FL}(p_{t}) = -\left(1 - p_{t}\right)^{\gamma} \log\left(p_{t}\right) $$

(6)

Two properties of the focal loss should be noted: (1) When a sample was misclassified and \(p_{t}\) was close to 0, the modulation factor was close to 1 and the FL was unaffected; as \(p_{t}\) neared 1, the factor neared 0 and the loss for well-classified examples was down-weighted. (2) When γ = 0, FL was equivalent to CE, and as γ increased, the effect of the modulation factor likewise increased. Previous work found that γ = 2 worked best (Lin et al. 2017), so we chose γ = 2.
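A minimal TensorFlow implementation of Eq. 6 with γ = 2 might look as follows. It assumes binary labels (0/1) and predictions giving the probability of the positive class, and it adds only a numerical-stability clip that is not part of Eq. 6.

```python
import tensorflow as tf

def binary_focal_loss(gamma=2.0):
    """Focal loss FL(p_t) = -(1 - p_t)^gamma * log(p_t), per Eq. 6."""
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)  # numerical stability
        # p_t = p for positive samples, 1 - p otherwise (Eq. 4)
        p_t = tf.where(tf.equal(y_true, 1.0), y_pred, 1.0 - y_pred)
        return -tf.reduce_mean(tf.pow(1.0 - p_t, gamma) * tf.math.log(p_t))
    return loss
```

With γ = 0 this reduces to plain cross-entropy, matching property (2) above.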

Saliency map

Neural networks have the ability to extract features automatically, and visualization of a neural network can facilitate understanding of the EEG features learned by EEGNet, i.e., the features that drive the model's decisions. In this study, the saliency map was used as a network visualization technique (Arrieta et al. 2020), one that has been widely applied in various fields. The saliency map (Simonyan et al. 2013) represents the features learned by the model in terms of gradients of the output with respect to the input. After training the model and fixing its weights, the gradients with respect to the input were back-propagated to the first layer of the neural network.

In the convolutional neural network setting, given an input EEG sample \(X_{0}\), a class c, and a classification EEGNet with class score function \(h_{c}(X)\), we would like to rank the elements of \(X_{0}\) based on their influence on \(h_{c}(X_{0})\). We started with a straightforward example: consider a linear model for the class c, as in Eq. 7:

$$ h_{c}(X) = \omega_{c}^{T} X + b_{c} $$

(7)

where the EEG data X was represented in one-dimensional vector form, and \(\omega_{c}\) and \(b_{c}\) denoted the weight vector and the bias of the model, respectively. From Eq. 7, it is easy to see that the magnitudes of the elements of \(\omega_{c}\) define the importance of the corresponding elements of X for the class c.

In a deep neural network, \(h_{c}(X)\) is a highly non-linear function of X, so the reasoning in the previous paragraph does not apply directly. However, a first-order Taylor expansion gives a good local approximation for a deep neural network: given an EEG sample \(X_{0}\), we can approximate \(h_{c}(X)\) with a linear function in the neighbourhood of \(X_{0}\), with the expression as Eq. 8:

$$ h_{c}(X) \approx \omega^{T} X + b. $$

(8)

In Eq. 8, ω is the derivative of the function \(h_{c}\) with respect to X, evaluated at \(X_{0}\), as expressed in Eq. 9:

$$ \omega = \left. \frac{\partial h_{c}}{\partial X} \right|_{X_{0}} $$

(9)

Here the classification was based on the similarity of the input data for each category. The magnitude of the derivative ω indicated which elements of \(X_{0}\) needed to be changed the least to affect the class score the most.

In this paper, the derivative ω was found by backpropagation. After that, the saliency map was obtained by rearranging the elements of the vector ω. The number of elements in ω is equal to the number of elements in \(X_{0}\), so the map can be computed as \(M_{ij} = \omega_{ij}\), where \((i, j)\) is the index of the element of ω corresponding to the entry in the i-th row (channel) and j-th column (time point). We used the saliency map method for within-subject feature visualization. First, the within-subject saliency map was calculated for data from the same cognitive task. Then, the resulting saliency maps were normalized. Next, the normalized saliency maps were averaged by superimposing them. Finally, the average saliency map over all subjects was obtained.
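As a sketch of how such a map can be computed for the trained model above, the gradient of the class score with respect to the input can be obtained with `tf.GradientTape`. Taking the gradient magnitude and rescaling each trial's map to [0, 1] is one reasonable reading of the normalization step described in the text, not a detail the paper specifies.

```python
import numpy as np
import tensorflow as tf

def saliency_map(model, x, class_index):
    """Gradient of the class score w.r.t. the input (Simonyan et al. 2013).
    `x` is one preprocessed trial of shape (8, 201)."""
    x = tf.convert_to_tensor(x[np.newaxis, :, :, np.newaxis], tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x)
        score = model(x, training=False)[0, class_index]  # h_c(X_0)
    grad = tape.gradient(score, x)[0, :, :, 0]            # omega, shape (8, 201)
    m = tf.abs(grad).numpy()                              # magnitude of the derivative
    return m / (m.max() + 1e-12)                          # normalize to [0, 1]

# Average the normalized maps over all trials of one task, then over subjects:
# mean_map = np.mean([saliency_map(model, trial, 1) for trial in trials], axis=0)
```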
