DWFed: A statistical-heterogeneity-based dynamic weighted model aggregation algorithm for federated learning

1. Introduction

As the functions of mobile devices, wearable devices, and IoT devices have become more diverse and complex than ever, a tremendous amount of valuable data is generated locally all the time, and much potential information can be mined from it with a well-trained statistical model. However, traditional centralized model training requires collecting data on a central node to extract features, which consumes a large amount of time for data transmission and model training because of the tremendous volume of data across the devices. It can also cause privacy leakage of sensitive data during transmission. Therefore, federated learning (Konečnỳ et al., 2015; McMahan and Daniel Ramage, 2017; McMahan et al., 2017), a distributed machine learning framework that involves a central server and multiple remote devices, was proposed to address the challenges that centralized methods are confronted with. It enables remote devices to train statistical models locally and share only the model parameters with a central server for aggregation into a federated model, thus providing faster construction of the federated model and preserving data privacy. Owing to these advantages, federated learning has been continuously improved and applied in many fields, including smart healthcare (Shamshirband et al., 2021; Rahman et al., 2022; Samuel et al., 2022), the industrial internet of things (Sun et al., 2020; Yang et al., 2022), etc. However, federated learning is still confronted with the challenges of model transmission cost and statistical heterogeneity. Specifically, as the parameters of statistical models are usually high-dimensional, frequent parameter uploading can consume a great deal of transmission time, leading to low efficiency of federated model training. Besides, statistical heterogeneity, which results from the non-IID data generated by different devices with various feature or label probability distributions, has been proven to have a negative impact on model convergence and accuracy compared with IID data.

To address these challenges, researchers have proposed several optimization algorithms based on federated learning. Federated averaging (FedAvg) (McMahan et al., 2017) is a typical algorithm, which deploys several rounds of local stochastic gradient descent (SGD) on each device and then uploads the model parameters to a central server for model averaging. Experiments on public benchmark image classification data sets (MNIST, LeCun et al., 1998; CIFAR-10, Krizhevsky, 2009) and a language data set (Shakespeare, 2007) have demonstrated the robustness of FedAvg for training convolutional neural networks (CNN) and long short-term memory (LSTM) networks. However, recent research has found that the statistical heterogeneity caused by non-IID data increases the model divergence, i.e., the difference between the federated and centralized models, leading to significant accuracy reduction and unstable convergence of the federated model.

Research on federated learning with non-IID data mainly focuses on the non-IID label distribution of the data across clients. To improve the performance of the federated model confronted with non-IID data, Zhao et al. (2018) proposed a data-sharing-based method, which significantly improves the performance of federated averaging on non-IID data by sharing a small amount of data. In addition, their research established a relation between statistical heterogeneity and the earth mover's distance (EMD), which indicates that EMD could be an ideal index of statistical heterogeneity. This discovery motivated us to propose DWFed, a dynamic weighted model aggregation algorithm based on the federated averaging algorithm, which quantifies an index of statistical heterogeneity based on EMD and dynamically computes the weights of model averaging from this index to minimize the model divergence during federated model training. The most significant difference between FedAvg and DWFed, which is also the main contribution of this paper, is the weights given to the models uploaded by each device. In FedAvg, the weights are simply calculated as the ratio of the data on each device to the total amount of data. The averaged model can represent the global optimization objective in IID settings. However, the performance of FedAvg degrades dramatically as data becomes non-IID, because non-IID data makes the weighted sum of local optimization objectives no longer an unbiased estimate of the global optimization objective. To overcome this drawback of FedAvg, DWFed calculates weights based on indexes of statistical heterogeneity, called ISH, which we quantitatively define through derivation for the first time and which are calculated from the EMD between the local label distribution and the global label distribution. DWFed can well resist the negative impact of non-IID data, and it brings little computational burden to each device as the calculation of the weights is simple. However, as each client needs the globally shared label distribution information to calculate its own EMD, DWFed performs better in scenarios where the label information of the data is not sensitive, such as hospitals, public driving locations, and so on. A detailed introduction of DWFed is given in Section 3. In addition, experiments on multiple benchmark data sets reveal the improvement in performance and robustness of federated models trained with non-IID data compared with FedAvg. The main contributions of our work are summarized as follows:

1. We quantitatively study the impact of statistical heterogeneity on federated learning through derivation for the first time.

2. We propose an index of statistical heterogeneity, called ISH, which decreases as statistical heterogeneity increases.

3. We design a method to dynamically compute model averaging weights using the index of statistical heterogeneity, which can effectively constrain the model divergence during federated model training.

The rest of this paper is organized as follows. Section 2 reviews the background and related work on federated learning and the corresponding optimization methods. The principle of DWFed and its derivation are presented in Section 3. Experiments and evaluations are described in Section 4. Finally, Section 5 concludes our work.

2. Related work

The notion of federated learning was first introduced in McMahan and Daniel Ramage (2017), and its baseline algorithm is federated stochastic gradient descent (FedSGD), which enables each device to execute one round of SGD locally and upload the model to a central server for weighted model averaging. The central server then distributes the aggregated model to each device for the next round of local SGD, and the whole procedure repeats until certain termination conditions are met. Although FedSGD addresses the challenges of data transmission and privacy leakage of sensitive data (Bharati and Podder, 2022; Bharati et al., 2022), frequent model uploading and distribution greatly constrain the performance of federated learning, causing slow convergence and low accuracy, and result in an efficiency problem.

To address these challenges, a lot of constructive work has been done. In terms of the efficiency of federated learning, Wang et al. (2019) introduced adaptive federated learning, which dynamically computes the communication step with the central server (the number of rounds of local SGD) in resource-constrained edge computing systems; faster convergence can be achieved compared with methods where the communication step is fixed. Also starting from the communication cost, Konečnỳ et al. (2016) greatly reduce the communication cost by utilizing model compression, which decreases the size of the uploaded model. Similarly, Sattler et al. (2019) proposed a compression framework called sparse ternary compression (STC), which extends existing compression techniques with downstream compression as well as ternarization and optimal Golomb encoding of the weight updates. Additionally, Asad et al. (2020) introduce an algorithm combining model compression and parameter encryption, which effectively reduces communication overhead while protecting model security. Besides directly reducing communication costs, the efficiency of federated learning can also be improved by resource optimization. For example, Nishio et al. (2013), Sardellitti et al. (2015), and Yu et al. (2016) minimize the computation time and resource consumption based on the joint optimization of heterogeneous data, computation, and communication resources. In contrast, Nishio and Yonetani (2019) maximize the efficiency of federated model training through client selection based on resources, network conditions, and computation capability, and experiments have demonstrated the resulting efficiency gains.

In terms of robustness to non-IID data, plenty of solutions have been proposed, and we summarize current federated learning schemes for data heterogeneity in Table 1. For example, Konečnỳ et al. (2015) proposed an optimization algorithm called DSVRG to improve the performance of federated learning in non-IID scenarios, in which the distributed optimization algorithm DANE (Shamir et al., 2014) is modified by utilizing SVRG (Johnson and Zhang, 2013) as a local solver to produce an approximate solution for the subproblem of DANE. In addition, some important modifications are made to improve robustness in federated scenarios, such as a flexible local update stepsize and applying a diagonal matrix to adjust the stochastic gradient update of the model. The experiments revealed that DSVRG not only accelerates convergence but also decreases the test error of federated learning. In 2017, an improved algorithm based on FedSGD called FedAvg (McMahan et al., 2017) was proposed. FedAvg allows devices to synchronously execute several epochs of SGD before uploading the model to a central server for model aggregation, and the convergence of FedAvg is theoretically proven in Li et al. (2019). Experiments on public benchmark data sets also demonstrate that FedAvg has good convergence speed and robustness for training different deep learning models. However, Zhao et al. (2018) found that the performance of FedAvg gradually degrades as statistical heterogeneity increases. Through mathematical analysis, they discovered the relation between the earth mover's distance of each device and the model divergence caused by heterogeneity. Therefore, a strategy that eases model divergence by sharing a small portion of data from the central server with each client was proposed, and experiments showed that the more data the central server shares, the lower the EMD becomes and the higher the accuracy that can be obtained. However, the specific mathematical relation between EMD and statistical heterogeneity was not further studied. Chen et al. (2022) proposed an adaptive client selection algorithm, ACSFed, based on EMD, which dynamically calculates the probability of clients being selected according to the local statistical heterogeneity and previous training performance. Similar to Zhao et al. (2018), an adaptive enhancement method based on data sharing is also proposed in Huang et al. (2018), which improves the efficiency of federated learning. However, data sharing increases the communication burden and raises the risk of privacy leakage. It also breaks the core principle of federated learning that data should be stored locally instead of being shared. Therefore, recent research has begun to study approaches that can obtain better performance than FedAvg while keeping data local. For example, Yeganeh et al. (2020) proposed a novel adaptive weighting approach for clients based on meta-information, and the comparison with the baseline FedAvg algorithm proves the effectiveness of the scheme. Li et al. (2018) proposed a framework called FedProx, which changes the optimization objective by adding the model divergence to the loss function. Experiments prove it can effectively stabilize the training convergence of the federated model because it constrains the difference between the central and local models.
Moreover, a creative approach called federated augmentation, which makes the data distribution IID on each device by enabling devices to jointly train generative models to augment their data, is proposed in Jeong et al. (2018), and it obtains 95−98% accuracy on MNIST. Xu et al. (2022) proposed a federated learning framework, FedLA, which reduces the aggregation frequency to improve robustness in heterogeneous scenarios. Furthermore, cross-device momentum (CDM) is implemented to improve the upper-limit performance of the global model. There is also the idea of dealing with non-IID data by combining reinforcement learning with federated learning. For example, Wang et al. (2020) propose Favor, an experience-driven control framework that intelligently chooses the client devices participating in each round of federated learning to counterbalance the bias introduced by non-IID data and to speed up convergence. Similarly, Pang et al. (2020) proposed an RL-based intelligent central server with the capability of recognizing heterogeneity, which can help steer training toward better performance for most clients. In 2019, knowledge distillation was applied to federated learning in Li and Wang (2019), which enables each device to train a local model with two parts of data, private data and public shared data. The outputs on the public data are utilized as a consensus to adjust each local model, and experiments have shown that the performance of FedAvg can be improved by implementing knowledge distillation. Additionally, there are methods that utilize multi-task learning in federated learning, called federated multi-task learning. In the federated multi-task learning framework, the learning problem of each client on its local data set is regarded as a separate task rather than a shard of a partitioned data set. MOCHA (Smith et al., 2017) is a typical multi-task federated learning algorithm, which directly addresses the challenges of communication efficiency, stragglers, and fault tolerance. On the basis of MOCHA, Li et al. (2021) proposed a lightweight framework called Ditto, which simplifies the solver of the local subtask by restraining the divergence between the local model and the global model. Although Ditto's idea of restraining the divergence between the local and global models is similar to FedProx, it is essentially different, as it learns not only a global model but also local, personalized models, while FedProx only learns a global model. Experiments on public benchmark data sets reveal that Ditto achieves higher accuracy and stronger robustness relative to state-of-the-art federated learning methods. However, as multi-task learning enables each node to train a personalized model locally, stateful nodes are required, which makes this type of technology more challenging to apply in cross-device scenarios. To sum up, current research suffers from problems of higher computation and communication burden, privacy leakage, and difficulty in practical application. Therefore, an improved federated learning method that can suppress or solve the above problems while retaining performance needs to be studied.


Table 1. Federated learning for data heterogeneity.

3. Method

To improve the performance of federated learning methods in statistical heterogeneity scenarios, we propose a dynamic weighted model aggregation algorithm for federated learning called DWFed. The core idea of DWFed is to dynamically calculate the weights of model averaging by using the index of statistical heterogeneity, ISH. In this section, we first introduce the core of DWFed in detail, which is the derivation of the index of statistical heterogeneity, and then present the overall procedure of DWFed.

3.1. Derivation of model divergence

During federated model training, K devices out of N (K ≪ N) are randomly selected, and a certain number of epochs of local stochastic gradient descent (SGD) are executed before uploading the model to the central server for model aggregation. Specifically, the optimization objective is to minimize:

$$\min_{\omega} f(\omega)=\sum_{k=1}^{K}\frac{n_k}{n}F_k(\omega),\quad \text{where } F_k(\omega)=\frac{1}{n_k}\sum_{s\in S_k}f_s(\omega) \qquad (1)$$

where $S_k$ is the set of indices of the data points on client k, $n_k = |S_k|$ is the number of data points available on device k, $n=\sum_k n_k$ is the total number of data points across the network, and $f_s(\omega)$ is the loss of data point s under the model $\omega$. The procedure of a typical federated learning method with K selected devices, batch size b, and learning rate η lets device k iterate the local update $\omega_{k,t} \leftarrow \omega_{k,t}-\eta g_k$ several times, where $g_k = \nabla F_k(\omega_{k,t})$ is the gradient computed under the current model $\omega_{k,t}$ on device k, and $\omega_{k,t} = \omega_t$ when the local update begins. After the K devices finish their local updates and upload the models $\omega_{k,t+1}$ to the central server, the model aggregation $\omega_{t+1}=\sum_{k=1}^{K}\frac{n_k}{n}\omega_{k,t+1}$ is executed on the central server, which can also be rewritten as $\omega_{t+1}=\omega_t-\eta\sum_{k=1}^{K}\frac{n_k}{n}\nabla F_k(\omega_{k,t})$.
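To make the procedure concrete, the following is a minimal sketch of the local update and the $n_k/n$-weighted aggregation described above, assuming model parameters are flat NumPy arrays and that grad_fn is a hypothetical function returning the mini-batch gradient of $F_k$:

```python
import numpy as np

def local_sgd(weights, x, y, grad_fn, lr=0.01, epochs=1, batch_size=32):
    """Run a few epochs of local mini-batch SGD on one device: w <- w - eta * g_k."""
    w = weights.copy()
    n = len(x)
    for _ in range(epochs):
        perm = np.random.permutation(n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            w = w - lr * grad_fn(w, x[idx], y[idx])
    return w

def fedavg_aggregate(client_weights, client_sizes):
    """Server-side FedAvg aggregation: w_{t+1} = sum_k (n_k / n) * w_{k,t+1}."""
    n = float(sum(client_sizes))
    return sum((n_k / n) * w_k for w_k, n_k in zip(client_weights, client_sizes))
```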

In IID settings, where the training data is uniformly and randomly distributed to each device, the expectation of $F_k(\omega)$ is equal to $f(\omega)$, which can be denoted as $\mathbb{E}(F_k(\omega)) = f(\omega)$, and thus $\mathbb{E}(g_k) = \nabla f(\omega)$. Therefore, the optimal solution can be obtained by updating the model along the descent direction of the gradient, and the federated model generated by averaging local models is nearly equal to the centralized model. However, $F_k(\omega)$ can be an arbitrarily poor approximation to $f(\omega)$ in non-IID settings, leading to a deviation between the federated model and the centrally trained model, which is called model divergence and can be represented as:

$$\frac{\left\|\omega^{f}-\omega^{c}\right\|}{\left\|\omega^{c}\right\|} \qquad (2)$$

where $\omega^{f}$ is the model trained in the distributed setting using the federated learning method, and $\omega^{c}$ is the centrally trained model. The more significant the statistical heterogeneity, the larger the model divergence, and the more severely the performance of FedAvg degrades. Therefore, a numerical index of statistical heterogeneity is urgently needed to precisely reflect its influence on the performance of federated learning methods.
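As a small illustration, the divergence metric in Equation (2) can be computed from the flattened parameter vectors of the two models (a sketch assuming NumPy arrays):

```python
import numpy as np

def model_divergence(w_fed, w_central):
    """Relative divergence ||w_f - w_c|| / ||w_c|| between federated and centralized models."""
    w_fed, w_central = np.ravel(w_fed), np.ravel(w_central)
    return np.linalg.norm(w_fed - w_central) / np.linalg.norm(w_central)
```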

3.2. Derivation of statistical heterogeneity influence

From the above derivation and analysis, it can be concluded that the model divergence caused by non-IID data is the main reason for the decreasing performance of federated learning methods in statistical heterogeneity scenarios. Therefore, we propose a dynamic weighted federated averaging algorithm (DWFed) based on FedAvg, which quantitatively defines the index of statistical heterogeneity for the first time and dynamically computes the corresponding weights of model averaging to constrain the model divergence. The core idea of DWFed is to calculate comprehensive weights, based on the statistical heterogeneity of each selected device and hyperparameters such as the learning rate, batch size, and the number of selected devices, that keep the federated model close to the centralized model and thus constrain the model divergence. Specifically, the centralized model update using SGD can be written as:

$$\omega_{t+1}^{c}=\omega_{t}^{c}-\eta\sum_{i=1}^{C}P(y=i)\,\nabla F(\omega_{t}^{c},x_{y=i}) \qquad (3)$$

In the above equation, $\omega_{t+1}^{c}$ and $\omega_{t}^{c}$ are the weights after the (t+1)-th and t-th updates, respectively, η is the learning rate, P is the data distribution, which is also the population distribution, and C denotes the total number of classes. In addition, $\nabla F(\omega_{t}^{c},x_{y=i})$ denotes the gradient on the data of class i under the current model $\omega_{t}^{c}$. Similarly, we can write the federated model update using FedSGD:

$$\omega_{t+1}^{f}=\omega_{t}^{f}-\eta\sum_{k=1}^{K}\sum_{i=1}^{C}p_k(y=i)\,\nabla F_k(\omega_{k,t}^{f},x_{y=i}) \qquad (4)$$

where $p_k$ denotes the data distribution on device k, and $\nabla F_k(\omega_{k,t}^{f},x_{y=i})$ is the gradient on the data of class i under the current local model of device k. The superscript of the weights ω denotes the setting: c denotes the centralized setting and f the federated learning setting. To compare the model updates in the two settings more intuitively, we replace the centralized scenario with one of multiple devices, each having the same data distribution as the population distribution, where the number of devices equals the number of selected devices in the distributed scenario. The model update in such a scenario is the same as that in the centralized scenario because each device has the population distribution, and it can be expressed as:

$$\omega_{t+1}^{c}=\omega_{t}^{c}-\eta\sum_{k=1}^{K}\sum_{i=1}^{C}P(y=i)\,\nabla F_k(\omega_{k,t}^{c},x_{y=i}) \qquad (5)$$

Therefore, the difference between the federated model and the centralized model, which is the numerator of the model divergence, can be written as:

$$\omega_{t+1}^{f}-\omega_{t+1}^{c}=\Delta\omega_{t}+\eta\sum_{k=1}^{K}\sum_{i=1}^{C}\Big[P(y=i)\,\nabla F_k(\omega_{k,t}^{c},x_{y=i})-p_k(y=i)\,\nabla F_k(\omega_{k,t}^{f},x_{y=i})\Big] \qquad (6)$$
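For completeness, Equation (6) follows directly by subtracting the centralized update (Equation 5) from the federated update (Equation 4):

$$\begin{aligned}
\omega_{t+1}^{f}-\omega_{t+1}^{c}
&=\Big(\omega_{t}^{f}-\eta\sum_{k=1}^{K}\sum_{i=1}^{C}p_k(y=i)\,\nabla F_k(\omega_{k,t}^{f},x_{y=i})\Big)
-\Big(\omega_{t}^{c}-\eta\sum_{k=1}^{K}\sum_{i=1}^{C}P(y=i)\,\nabla F_k(\omega_{k,t}^{c},x_{y=i})\Big)\\
&=\big(\omega_{t}^{f}-\omega_{t}^{c}\big)
+\eta\sum_{k=1}^{K}\sum_{i=1}^{C}\Big[P(y=i)\,\nabla F_k(\omega_{k,t}^{c},x_{y=i})
-p_k(y=i)\,\nabla F_k(\omega_{k,t}^{f},x_{y=i})\Big]
\end{aligned}$$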

where $\Delta\omega_{t}=\omega_{t}^{f}-\omega_{t}^{c}$. Equation (6) illustrates the unstable convergence and low performance of federated learning methods when statistical heterogeneity leads to varying distributions across the devices and the model divergence thus increases. To evaluate the model divergence caused by statistical heterogeneity across the devices, EMD can be applied. EMD measures the divergence between two distributions by computing the distance between them, and Zhao et al. (2018) found that the model divergence caused by non-IID data can be evaluated with the EMD between the data distribution on each device and the population distribution, but a specific quantitative relation was not given. As EMD denotes the distance between two probability distributions, it can be expressed as the following equation:

$$D_k=\mathrm{EMD}(p_k,P)=\sum_{i=1}^{C}\left\|p_k(y=i)-P(y=i)\right\| \qquad (7)$$

A potential problem of the EMD metric is that it is not invariant under relabeling: if two distributions are compared with different numbers of labels, or with the labels in a different order, the EMD will differ. In our method, we quantify the weight divergence by the EMD between the distribution over classes on each device and the population distribution, and the data labels on each device are a subset of the global data labels. Thus, the EMD between the data distribution on each device and the population distribution is invariant once the labels are aligned. Even if we need to compute the EMD between the data distributions of different clients, we can predefine a label order on the central server to obtain an invariant EMD metric. With this simple method, the EMD between each client's data distribution and the population distribution is a constant, so we do not have to consider penalizing for invariance across different environments. With Equation (7), we can obtain the index of statistical heterogeneity by introducing the EMD into the next stage of the derivation. Furthermore, we also propose a dynamic weight aggregation algorithm that computes the corresponding weights of model averaging to constrain the model divergence.
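A minimal sketch of this computation, assuming each client holds its label counts as a Python dictionary and the server publishes a fixed global label order as discussed above (function and variable names are illustrative):

```python
import numpy as np

def emd_to_population(local_counts, global_counts, label_order):
    """D_k = sum_i ||p_k(y=i) - P(y=i)||, with labels aligned to the shared global order."""
    p_k = np.array([local_counts.get(c, 0) for c in label_order], dtype=float)
    p = np.array([global_counts.get(c, 0) for c in label_order], dtype=float)
    p_k /= p_k.sum()
    p /= p.sum()
    return float(np.abs(p_k - p).sum())

# Example: a client holding only two of ten MNIST-like classes vs. a uniform population.
label_order = list(range(10))
D_k = emd_to_population({0: 500, 1: 100}, {c: 600 for c in label_order}, label_order)
```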

3.3. ISH and weighted averaging

To counteract the influence of statistical heterogeneity, we multiply the local gradient of each device by an index called ISH, which reflects its local statistical heterogeneity, and the model update in the distributed setting can be rewritten as:

$$\omega_{t+1}^{f}=\omega_{t}^{f}-\eta\sum_{k=1}^{K}\mathrm{ISH}_k\,\nabla F_k(\omega_{k,t}^{f}) \qquad (8)$$

Since $\omega_{t+1}^{c}$ is determined by SGD given that the population distribution is known, the optimization objective for minimizing the model divergence can be expressed as:

$$\min\left\|\eta\sum_{k=1}^{K}\left(\nabla F_k(\omega_{k,t}^{c})-\mathrm{ISH}_k\,\nabla F_k(\omega_{k,t}^{f})+\frac{1}{K\eta}\Delta\omega_{t}\right)\right\| \qquad (9)$$

Based on the idea of a greedy algorithm, we can optimize Equation (9) by minimizing each term of it, that is:

$$\left\|\nabla F_k(\omega_{k,t}^{c})-\mathrm{ISH}_k\,\nabla F_k(\omega_{k,t}^{f})+\frac{1}{K\eta}\Delta\omega_{t}\right\|\rightarrow 0,\quad k=1,2,\ldots,K \qquad (10)$$

Therefore, the index of statistical heterogeneity of device k can be calculated by the following formula:

$$\mathrm{ISH}_k=\frac{\left\|\nabla F_k(\omega_{k,t}^{c})+\frac{1}{K\eta}\Delta\omega_{t}\right\|}{\left\|\nabla F_k(\omega_{k,t}^{f})\right\|} \qquad (11)$$

Based on formulas (6) and (7), the index $\mathrm{ISH}_k$ can be further calculated as:

$$\mathrm{ISH}_k=1-\frac{1}{K}\cdot\frac{D_k}{1+D_k} \qquad (12)$$

After the ISH of each selected device k is obtained, it is transmitted to the central server along with the local model of that device. The weights of the local models are then calculated by normalizing the indexes so that the weights sum to 1:

$$\alpha_k \leftarrow \frac{\mathrm{ISH}_k}{\sum_{k=1}^{K}\mathrm{ISH}_k} \qquad (13)$$
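Putting Equations (12) and (13) together, the client-side index and the server-side normalization could look like the sketch below, where D_k is the EMD from Equation (7) and K is the number of selected clients:

```python
def ish(d_k, num_clients):
    """ISH_k = 1 - (1/K) * D_k / (1 + D_k); it shrinks as local heterogeneity D_k grows."""
    return 1.0 - (d_k / (1.0 + d_k)) / num_clients

def aggregation_weights(ish_values):
    """alpha_k = ISH_k / sum_j ISH_j, so that the aggregation weights sum to 1."""
    total = sum(ish_values)
    return [v / total for v in ish_values]
```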

Finally, the central server executes the weighted model aggregation following formula (8) and returns the aggregated model to each selected device for a new round of local federated model training, which marks the end of a communication round.

3.4. Algorithm implementation

After deriving the statistical heterogeneity index ISH, we now describe the DWFed algorithm in detail.

The DWFed algorithm is conducted through multiple rounds of communication between the central server and the clients. A complete communication round includes local training, aggregation weight calculation, model and weight transmission, model aggregation, and model distribution. The complete pseudo-code of DWFed is given in Algorithm 1. At the beginning of DWFed, the central server initializes the global model weights and distributes them to a randomly selected set of clients. After receiving the weights, each client first calculates its ISH according to formula (12), then executes one round of SGD locally, and finally transmits the updated weights and ISHk to the central server. The central server calculates the aggregation weights αk based on the parameters uploaded by the clients and completes the model aggregation. This is the whole process of one communication round, and the algorithm keeps repeating it until the prescribed number of communication rounds is reached.

Algorithm 1. The DWFed algorithm.
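Since Algorithm 1 appears only as a figure in the original layout, the sketch below restates one DWFed communication round in code form, reusing the illustrative helpers from the earlier snippets (local_sgd, emd_to_population, ish, and aggregation_weights) and assuming each client is a dictionary holding its data and label counts:

```python
def dwfed_round(global_weights, clients, global_counts, label_order, grad_fn, lr=0.01):
    """One DWFed communication round: local training, ISH computation, weighted aggregation."""
    local_models, ish_values = [], []
    k_selected = len(clients)
    for client in clients:
        # Client side: compute ISH_k from the EMD to the population distribution (formula 12),
        # then run local SGD starting from the distributed global model.
        d_k = emd_to_population(client["label_counts"], global_counts, label_order)
        ish_values.append(ish(d_k, k_selected))
        local_models.append(local_sgd(global_weights, client["x"], client["y"], grad_fn, lr=lr))
    # Server side: normalize the indexes (formula 13) and aggregate the local models.
    alphas = aggregation_weights(ish_values)
    return sum(a * w for a, w in zip(alphas, local_models))
```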
