Generative models improve fairness of medical classifiers under distribution shifts

Our research complies with all relevant ethical regulations. We only repurposed existing assets and datasets and did not collect new assets for the purposes of our study, beyond annotations by dermatology experts for the generated images. The non-accessible data used in the study can be used for research purposes without further scrutiny or collection of consent from the source individuals.

Datasets

In this section, we describe the datasets we used to train the downstream classifiers and diffusion models across the different modalities and medical contexts. Three different datasets were used, all of which are de-identified; informed consent was obtained from the participants in the original studies that collected these data.

Histopathology

We used data from the CAMELYON17 challenge21 that include labeled and unlabeled data from three different hospitals for training, as well as one in-distribution and one OOD validation hospital. Data from the different hospitals differ because of the staining procedure used. The task was to estimate the presence of breast cancer metastases in the images, which are patches of whole-slide images of histological lymph node sections. The number of samples per hospital is given in Extended Data Table 1a; all subsets were approximately evenly split into those containing tumors and those that did not. We used the training data (302,436 examples) and the unlabeled data (1.8 million examples) to train the diffusion model. We performed patch-based instead of whole-slide classification to align with the WILDS challenge22 and follow-up works that evaluated methods on the same setup.

In terms of label distribution, the training set contained 151,046 patches of healthy tissue and 151,390 patches of cancerous tissue. For the ID (validation) dataset, these counts were 16,952 and 16,608, respectively, while the OOD (validation) and OOD (test) splits contained 17,452 and 42,527 patches per class, respectively (that is, both OOD datasets were perfectly balanced).

Chest radiology

We trained the cascaded diffusion and downstream discriminative model on a total of 201,055 samples from the CheXpert database23, with 119,352 individuals annotated as male and 81,703 as female (the dataset only contained binary gender labels). We show the age and original label distribution in Extended Data Fig. 3a,b. The original CheXpert training set contained positive, negative, uncertain and unmentioned labels. The uncertain samples were not considered when learning the diagnostic model, but they were used to train the diffusion model. The unmentioned label was considered a negative (that is, the condition was not present), which yielded a highly imbalanced dataset. The evaluation National Institutes of Health dataset24 denoted as OOD consisted of 17,723 individuals, out of which 10,228 were male and 7,495 were female.

Extended Data Fig. 3c,d illustrates how often different conditions co-occurred in the training and evaluation samples. Capturing the characteristics of a single condition can be challenging because they frequently coexist with other conditions in a single case. One characteristic example is pleural effusion, which was included in the diagnosis of atelectasis, consolidation and edema in approximately 50% of cases. However, the scenario is slightly different for the OOD ChestX-ray14 dataset, where for most pairs of conditions the corresponding ratio was much lower.

Dermatology

The imaging samples in the dermatology dataset were often accompanied by metadata that included attributes like biological sex, age and skin tone. Skin tone was labeled according to the Fitzpatrick scale, giving rise to six categories (plus unknown). The ground truth labels for the condition were the result of aggregation of clinical assessments by multiple experts, who provided a list of top-3 conditions along with a confidence score (between 1 and 5). A weighted aggregate of these labels gave rise to soft labels that we used for training the generative and diagnostic models. The dermatology datasets were characterized by complex shifts with respect to each other as the label distribution, demographic distribution and capture process may all vary across them. To demonstrate the severity of the prevalence shift across locations, we visualized the distribution of conditions in the evaluation datasets in Extended Data Fig. 4.

To disentangle the effect of each of those shifts, we artificially skewed the source dataset along three sensitive attribute axes: sex, skin tone and age. Skewing the dataset allowed us to understand which methods performed better as the distribution shifts became more severe. For example, if our original dataset was skewed toward younger age groups, conditioning the generative model on age and (over)sampling from older ages could potentially help close the performance gap between younger and older populations. To study this aspect, we could not rebalance our datasets because we had too few samples from the long tail of our distribution with regard to the label or sensitive attribute. Instead, we skewed the labeled training dataset to make it progressively more biased (by removing instances from the least represented subgroups) and investigated how performance suffered because of skewing. For each sensitive attribute, we created new versions of the in-distribution dataset that were progressively more skewed toward the high-data regions. We show how the resulting training dataset was skewed with respect to each of the sensitive attributes in Extended Data Table 1b–d. We also report similar demographic statistics for the three evaluation datasets in Extended Data Table 1e–g. The cascaded diffusion model was always trained on the union of the labeled training data and the total of unlabeled data across the three available domains. The discriminative model was always evaluated on the same three evaluation datasets (one in-distribution held-out dataset and two OOD datasets) for consistency.
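As a rough illustration of the skewing procedure, the sketch below removes a share of the least represented subgroup from a toy dataset. The function name `skew_dataset`, the `keep_fraction` parameter and the toy records are assumptions made for illustration; the removal schedule actually used in the study is the one reported in Extended Data Table 1b–d.

```python
import random
from collections import Counter

def skew_dataset(records, attribute, keep_fraction, seed=0):
    """Keep only `keep_fraction` of the least represented subgroup.

    records: list of dicts, each carrying the sensitive attribute.
    keep_fraction: hypothetical knob; lower values give a more skewed dataset.
    """
    rng = random.Random(seed)
    counts = Counter(r[attribute] for r in records)
    minority = min(counts, key=counts.get)  # least represented subgroup
    minority_rows = [r for r in records if r[attribute] == minority]
    majority_rows = [r for r in records if r[attribute] != minority]
    n_keep = int(round(keep_fraction * len(minority_rows)))
    return majority_rows + rng.sample(minority_rows, n_keep)

# Toy dataset with made-up counts: 80 records tagged "F", 20 tagged "M".
toy = [{"sex": "F"}] * 80 + [{"sex": "M"}] * 20
for frac in (1.0, 0.5, 0.1):
    skewed = skew_dataset(toy, "sex", frac)
    print(frac, dict(Counter(r["sex"] for r in skewed)))
```

Lowering `keep_fraction` step by step yields a family of increasingly biased training sets on which the same method can be retrained and compared.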

Related work

Learning augmentations with generative models in health

Generative models, especially generative adversarial networks (GANs)29, have been used by several studies to improve performance in different medical imaging tasks30,31,32,33,34 and, in particular, for underrepresented conditions35. Data obtained by exploring different latent image attributes through a generative model have also been shown to improve adversarial robustness of image classifiers36. In the clinical setting, GANs have been used for tasks such as disease diagnosis in scenarios where few labeled samples were available. Such models have been used to augment medical images for liver lesion classification30, classification of diabetic retinopathy from fundus images31 and breast mass diagnosis in mammography32. In dermoscopic imaging33, a progressive generative model was introduced to produce realistic high-resolution synthetic images, while ref. 34 focused on improving balanced multiclass accuracy and, in particular, sensitivity for high-risk underrepresented diagnostic labels like melanoma37. A similar approach was applied to chest X-rays, combining real and synthetic images generated with GANs to improve classifier accuracy for rare diseases35. Another study used conditional image generation in scenarios where the conditioning vector was not always available to disentangle image content and image style information; the method was applied to dermoscopic images (HAM10000 dataset) corresponding to seven types of skin lesions and to lung computed tomography scans from the Lung Image Database Consortium-Image Database Resource Initiative.

Apart from whole-image downstream tasks, GAN-based augmentation techniques have been used to improve performance on pixel-wise classification tasks, for example, vessel contour segmentation on fundus images38 and brain lesion segmentation39. Given that pixel-wise downstream tasks were not within the scope of our study, we refer the reader to a more thorough review of GAN-based methods in medical image augmentation by Chen et al.40; Bissoto et al.41, in turn, provide an overview of GAN-based augmentation techniques with a main focus on skin lesion augmentation and anonymization.

Despite the wide variety of health applications that have adopted GAN-based generative models to produce learned augmentations of images, these are often characterized by limited diversity and quality42. More recently, DDPMs19,20,43,44,45 have shown outstanding performance in image generation tasks and have been probed for medical knowledge by Kather et al.46 in different medical domains. Other works extended diffusion models to three-dimensional magnetic resonance and computed tomography images47 and demonstrated that they can be conditioned on text prompts for chest X-ray generation48. Given the ethical questions around the use of synthetic images in medicine and healthcare46,49, it is important to make a distinction between using generative models to augment the original training dataset and replacing real images with synthetic ones, especially in the absence of privacy guarantees. None of these works claimed that the latter would be preferable; rather, they turned to synthetic data when obtaining more abundant real data was either expensive or not feasible (for example, in the case of rare conditions), even if this solution is not a panacea42. While some studies view generative models as a means of replacing real data with ‘anonymized’ synthetic data, we abstain from such claims because greater care needs to be taken to ensure that generative models are trained with privacy guarantees, as shown by Carlini et al.50 and Somepalli et al.51.

Exploring fairness in health

Many scholars recently scrutinized ML systems and surfaced different types of biases that emerge through the ML pipeline, including problems due to data acquisition protocols, flawed human decision-making, missing features and label scarcity52. They identified and characterized various biases that can emerge during model development and are exacerbated during model deployment, and in clinical interactions, while they argued that ensuring fairness in those contexts is essential to advance health equity. The relevant literature discussed below was inspired by the realization that, if we break down performance of automated systems that rely on ML algorithms (for example, computer vision, judicial systems) based on certain demographic or socioeconomic traits, there can be vast discrepancies in predictive accuracy across these subgroups. This is alarming for applications influencing human life and it is particularly concerning in the context of computer-aided diagnosis and clinical decision-making.

One of the first studies to examine the effect of training data composition on model performance across the sexes when using chest X-rays to diagnose thoracic diseases was led by Larrazabal et al.12. They found that the prevalence of a particular sex in the training set is directly linked to the predictive accuracy of the model for the same group at test time. In other words, a model trained on a set highly skewed toward female patients would demonstrate higher accuracy for female patients at test time compared to a counterpart trained on a male-dominated set of images. Even though this finding might not come as a surprise, one would expect an ML model used in clinical practice across geographical locations to be robust to demographic shifts of this kind. In a similar vein, Seyyed-Kalantari et al.13 further explored how differences in age, race or ethnicity, and insurance type (as a proxy of socioeconomic status) are manifested in the performance of a classifier operating on chest radiographs. A crucial finding was that the algorithm would exhibit a higher false positive rate, that is, underdiagnose ethnic minorities. These effects were compounded for intersectional identities (that is, the false positive rate was higher for Black female patients compared to Black male patients). Similar findings were reported by Puyol-Antón et al.53 in a cardiac segmentation task with respect to sex and racial biases, and by Gianfrancesco et al.54 in a different modality (electronic health records) for patients with low socioeconomic status.

Overview of methodology

The method is illustrated in Fig. 1b and leverages diffusion models to learn augmentations of the data. The approach consists of three main steps: (1) we trained a generative model given the available labeled and unlabeled data; (2) we sampled from the generative model according to a sampling strategy; (3) we enriched our original training dataset from the source (also called in-distribution) domain with the synthetic images sampled from the generative model and trained a diagnostic model (potentially for multiple labels, if more than one condition can be present at once). We treated the mixing ratio between real and synthetic as a hyperparameter in all three settings and we selected the best value based on model performance on the validation set. We provide more specific details about the experimental setting for each modality in the following section and the pseudocode for our method in Fig. 1a.

Algorithm 1: pseudocode of proposed method

Input: modality

 if Modality == "histopathology" then

   Num_labels ← 2

   \(A \in \{\text{hospital ID}\}\)

 else if Modality == "radiology" then

   Num_labels ← 5

   \(A \in \{\ldots\}\)

 else if Modality == "dermatology" then

   Num_labels ← 27

   \(A \in \{\text{sex}, \text{skin tone}, \text{age}\}\)

 end if

Input: images \(X\) and labels \(Y \in \mathbb{R}^{\mathrm{Num\_labels}}\)

 Train diffusion model \(\hat{p} \sim p(X, Y, A)\)

 if Modality \(\in \{\text{radiology}, \text{dermatology}\}\) then

   Train upsampler diffusion model \(\hat{p}_{\mathrm{up}} \sim p(X, Y, A)\)

 end if

 Sample \(X'\) from \(\hat{p}\), \(\hat{p}_{\mathrm{up}}\) according to a fair distribution \(\hat{p}(Y, A)\)

 We assume: \(\hat{p}(A) \sim \mathrm{Uniform}\), \(\hat{p}(Y) = p(Y)\)

Output: synthetic samples \(X'\) with labels \(Y' \in \mathbb{R}^{\mathrm{Num\_labels}}\)

 Sample random number \(r \in [0, 1]\)

 Train diagnostic model \(d(x)\) using \((x_i, y_i)\) and mixing ratio \(a\)

 if \(r < a\) then

   \((x_i, y_i) \in \{X, Y\}\)

 else

   \((x_i, y_i) \in \{X', Y'\}\)

 end if
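The mixing step at the end of Algorithm 1 can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the function names, the toy `real` and `synthetic` lists and the batch construction are assumptions. The only behavior taken from the pseudocode is that a training example is drawn from the real data when a uniform random number falls below the mixing ratio and from the synthetic data otherwise.

```python
import random

def sample_training_example(real_data, synthetic_data, a, rng):
    """Return one (x, y) pair: real with probability a, synthetic otherwise."""
    r = rng.random()  # uniform random number in [0, 1)
    pool = real_data if r < a else synthetic_data
    return rng.choice(pool)

def build_batch(real_data, synthetic_data, a, batch_size, seed=0):
    """Assemble a batch by repeating the per-example mixing rule."""
    rng = random.Random(seed)
    return [sample_training_example(real_data, synthetic_data, a, rng)
            for _ in range(batch_size)]

# Toy stand-ins for (image, label) pairs; names and labels are made up.
real = [("real_img_%d" % i, 0) for i in range(10)]
synthetic = [("synth_img_%d" % i, 1) for i in range(10)]

# a = 0.25 corresponds to roughly 25% real and 75% synthetic examples.
batch = build_batch(real, synthetic, a=0.25, batch_size=8)
print(batch)
```

Setting `a=1.0` recovers training on real data only, while `a=0.0` trains purely on synthetic samples.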

Experimental setting for each modality

Histopathology

For histopathology, we trained a diffusion model to generate images at 96 × 96 resolution, which is the smallest resolution among the imaging modalities considered. The data used to train the diffusion model consisted of labeled and unlabeled data only from the in-distribution hospitals. To condition the diffusion model, we considered either the diagnostic label (that is, cancer or no cancer) or the diagnostic label and hospital ID together. For the unlabeled data, which did not contain the diagnostic label, we padded the corresponding conditioning vector with zeros. We then sampled from the diffusion model assuming a uniform distribution across hospital IDs and preserving the diagnostic label distribution. The synthetic-to-real data ratio used in histopathology was 50:50, meaning that 50% of the total training samples corresponded to real patches and 50% to synthetic samples from the diffusion model. For the diagnostic model, we focused on a patch-based classification setup instead of whole-slide image classification. Both experimental design decisions, that is, the image resolution and the classification setup, were made to align with the WILDS challenge22 and the wealth of literature that evaluates ML methods on in-the-wild distribution shifts using the same setting55. We evaluated on the held-out in-distribution and OOD hospitals (results shown in Fig. 2).

Chest radiology

For chest radiology, we trained two diffusion models (one generating images at 64 × 64 resolution and one upsampling those generated images to 224 × 224 resolution) on labeled images from the in-distribution dataset. Therefore, in this scenario, we did not have access to any unlabeled data or data from the OOD dataset. This holds for both the diffusion models and the diagnostic model, that is, the OOD dataset was only used for evaluation. We conditioned both generative models on the diagnostic label only. While treating the synthetic-to-real data ratio as a hyperparameter, we found that training the downstream diagnostic model purely on synthetic data led to the best accuracy and fairness trade-off. We did not alter the diagnostic label distribution, that is, we used the labels of the real data to condition the diffusion models and yield a synthetic sample. In this setting, the model backbone was shared across all conditions, while a separate (binary classification) head was trained for each condition, given that multiple conditions can be present at once.

Dermatology

For dermatology, we trained two diffusion models (one generating images at 64 × 64 resolution and one upsampling those generated images to 256 × 256 resolution) on labeled images from the in-distribution dataset and unlabeled images from the in-distribution and OOD datasets. At no stage of training did we have access to labeled samples from the OOD datasets. We conditioned both generative models on the diagnostic label (padded with zeros for the unlabeled samples) or the diagnostic label and a demographic attribute. While treating the ratio of synthetic-to-real data as a hyperparameter, we found that training the downstream diagnostic model on 75% synthetic images and 25% real images yielded the best results. When we artificially skewed the dataset against certain demographic subgroups, we ensured that both the generative models and the diagnostic model had access to the same labeled examples (that is, we trained a different diffusion model for each skewed setting). When we sampled from the diffusion model, we preserved the diagnostic label distribution and assumed a uniform demographic attribute distribution.
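The sampling rule described for histopathology and dermatology (preserve the diagnostic label distribution, sample the attribute uniformly) can be sketched as follows. The function name, the toy condition names and the use of hard labels are assumptions made for illustration; the study used soft labels in dermatology and conditions a diffusion model on the sampled pairs, which is not shown here.

```python
import random
from collections import Counter

def sample_conditioning(labels, attribute_values, n, seed=0):
    """Draw n (label, attribute) pairs for conditioning a generative model.

    Labels are drawn from the empirical distribution of `labels`;
    attributes are drawn uniformly from `attribute_values`.
    """
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        y = rng.choice(labels)            # preserves the label distribution
        a = rng.choice(attribute_values)  # uniform over the attribute
        pairs.append((y, a))
    return pairs

# Toy label list with a 70/30 split (made-up condition names and counts).
real_labels = ["condition_a"] * 70 + ["condition_b"] * 30
skin_tones = [1, 2, 3, 4, 5, 6]  # six Fitzpatrick categories

pairs = sample_conditioning(real_labels, skin_tones, n=6000)
print(Counter(a for _, a in pairs))
print(Counter(y for y, _ in pairs))
```

Each attribute value appears in roughly one sixth of the sampled pairs regardless of how rare it is in the real data, while the label frequencies stay close to those of the real training set.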

Theoretical motivation

We motivated the use of generated data and demonstrated its utility in several toy settings, which simulate the problem of having only a small number of samples from the underlying distribution or from parts of it. We wished to achieve high performance despite this lack of data. We demonstrated that even in these toy settings, synthetic data were useful.

We assumed we had a dataset \(\mathcal{D}_{\mathrm{train}}=\{(x_i,y_i,a_i)\}_{i=1}^{N}\) where \((x_i,y_i)\) is an image and label pair, \(a_i\) is a list of attributes about the datapoint and \(N\) is the number of training samples. The attributes may include sex, skin type and age, or the hospital ID (in the case of histopathology). We had an additional dataset \(\mathcal{D}_{u}=\{\tilde{x}_i\}_{i=1}^{M}\) of unlabeled images, \(M\) being the number of samples, that could be used as desired. We had a generative model \(\hat{p}\) trained with \(\mathcal{D}_{\mathrm{train}}\) and \(\mathcal{D}_{u}\) (we make \(\tilde{x}\) implicit in the following). We dropped the subscripts in the following for simplicity where obvious.

To achieve fairness, we assumed we had a ‘fair’ dataset \(\mathcal{D}_{\mathrm{fair}}=\{(x_i,y_i,a_i)\}_{i=1}^{K}\) with \(K\) datapoints that consisted of samples from the ‘fair’ distribution \(p_{\mathrm{fair}}\) over which we aimed to minimize the expectation of the loss. \(f_{\theta}(x)\) was the classifier and \(L\) the loss function (for example, binary cross-entropy). We aimed to optimize the following objective:

$$\mathop{\min}\limits_{\theta}\mathop{\mathbb{E}}\limits_{(x,a,y)\sim p_{\mathrm{fair}}}\left(L\left(f_{\theta}(x),y,a\right)\right)$$

(1)

We can decompose the data generating process into \(p_{\mathrm{fair}}(x|a,y)\,p_{\mathrm{fair}}(a)\,p_{\mathrm{fair}}(y)\). For example, we may have created \(p_{\mathrm{fair}}\) by sampling uniformly over an attribute (such as sex) and labels. We assumed that the training dataset \(\mathcal{D}_{\mathrm{train}}\) was sampled from a distribution \(p_{\mathrm{train}}\) where \(p_{\mathrm{train}}(x|a,y)=p_{\mathrm{fair}}(x|a,y)\). When \(p_{\mathrm{train}}(y,a)\ne p_{\mathrm{fair}}(y,a)\), then we have a distribution shift between the training and fair distribution (for example, the training distribution is more likely to generate images of a particular attribute or combinations of label and attribute than the fair distribution).

We aimed to combine the training dataset \(\mathcal{D}_{\mathrm{train}}\) and synthetic data sampled from the generative model \(\hat{p}\) to mimic most closely the fair distribution and improve fairness. We constructed a new dataset by sampling from a mixture of these two distributions, controlled by a probability parameter \(\alpha\):

$$\left(x,a,y\right)\sim \left\{\begin{array}{ll}\left(x,a,y\right)\sim p_{\mathrm{train}} & :\alpha \\ x\sim \hat{p}\left(x|y,a\right),\ \left(a,y\right)\sim \hat{p}(a,y) & :(1-\alpha )\end{array}\right.$$

(2)

So instead of minimizing equation (1), we minimized the following sum of expectations:

$$\mathop{\min}\limits_{\theta}\ \alpha \mathop{\mathbb{E}}\limits_{(x,a,y)\sim p_{\mathrm{train}}}\left(L\left(f_{\theta}(x),y,a\right)\right)+(1-\alpha )\mathop{\mathbb{E}}\limits_{(x,a,y)\sim \hat{p}}\left(L\left(f_{\theta}(x),y,a\right)\right)$$

(3)

The question is then how to choose \(\alpha\) and \(\hat{p}(a,y)\). For all settings in the main article, we maintained the label distribution \(\hat{p}(y)=p(y)\) but sampled uniformly over the attribute \(a\). We validated this choice on dermatology in the Supplementary Information. We treated \(\alpha\) as a hyperparameter in all settings.

Models

Upsampler preprocessing

Whenever we required an upsampler (that is, in radiology and dermatology), we trained it by preprocessing the original images using the following steps: (1) upsampled images from the 64 × 64 input resolution to the desired output resolution with bilinear interpolation and used an anti-alias with 0.5 probability; (2) added random Gaussian noise with 0.2 probability and σ = 4.0 (in the (0–255) range); (3) applied random Gaussian blurring with a 7 × 7 kernel and σmean = 0, σs.d. = 0.2; (4) quantized the image to 256 bins; and (5) normalized the image to the (−1 to 1) range.
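A simplified sketch of part of this pipeline is shown below, covering only steps (2), (4) and (5) on a flat list of pixel values. Bilinear upsampling (step 1) and Gaussian blurring (step 3) are omitted to keep the example dependency-free, and the function name and the rounding convention used for quantization are assumptions.

```python
import random

def preprocess_pixels(pixels, rng, noise_prob=0.2, noise_sigma=4.0):
    """Apply steps (2), (4) and (5) to pixel values in the 0-255 range."""
    # Step (2): with probability 0.2, add Gaussian noise with sigma = 4.0.
    if rng.random() < noise_prob:
        pixels = [p + rng.gauss(0.0, noise_sigma) for p in pixels]
    # Step (4): quantize to 256 bins (integer levels 0..255, clipped).
    quantized = [min(255, max(0, int(round(p)))) for p in pixels]
    # Step (5): normalize to the [-1, 1] range.
    return [q / 127.5 - 1.0 for q in quantized]

rng = random.Random(0)
print(preprocess_pixels([0.0, 64.0, 128.0, 255.0], rng))
```

In a full pipeline the same steps would be applied to the upsampled image array before it is passed to the upsampler model.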

Dealing with missing labels

For both the generative model and the upsampler, we filled the conditioning vectors with zeros (indicating an invalid vector) for the unlabeled data. This allowed us to use classifier-free guidance20 to make images more ‘canonical’ with respect to a given label or property.
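To make the role of the zero vector concrete, the sketch below builds a conditioning vector and combines conditional and unconditional noise predictions using the commonly cited classifier-free guidance rule, \((1+w)\,\epsilon_{\mathrm{cond}} - w\,\epsilon_{\mathrm{uncond}}\). The function names and the guidance weight `w` are assumptions; the text does not state which parameterization or weight the study used.

```python
def make_conditioning(label_vector, num_labels):
    """Return the label vector, or an all-zero vector for unlabeled samples."""
    if label_vector is None:
        return [0.0] * num_labels  # zeros mark a missing (invalid) label
    return list(label_vector)

def guided_prediction(eps_cond, eps_uncond, w):
    """Classifier-free guidance: push the prediction toward the condition.

    eps_cond: noise prediction given the label vector.
    eps_uncond: noise prediction given the all-zero vector.
    w: guidance weight (hypothetical value chosen by the user); w = 0
       returns the purely conditional prediction.
    """
    return [(1.0 + w) * c - w * u for c, u in zip(eps_cond, eps_uncond)]

print(make_conditioning(None, 5))
print(guided_prediction([1.0, 2.0], [0.5, 1.0], w=1.0))
```

Because the model has seen zero vectors during training, the same network can supply both predictions at sampling time.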

In this section, we describe the exact model architecture used for the trained diffusion models and classifiers, as well as the hyperparameters used for the presented results. Hyperparameters were selected based on the baseline model performance on the respective in-distribution validation sets and held constant for the remaining methods. This meant that we did not finetune hyperparameters for each method (other than the baseline) separately. We used the DDPM as presented by refs. 19,20,43 for the generation and the upsampler (only the radiology and dermatology datasets required higher-resolution images). The backbone model was always a UNet architecture. The hyperparameters used for the cascaded diffusion models were based on the standard values mentioned in the literature with minimal modifications. We present all hyperparameters in Extended Data Table 2.

Standard augmentations

Histopathology

For this modality, augmentations included brightness, contrast, saturation and hue jitter. Hue and saturation were sufficient to achieve the high-quality results described by Tellez et al.56.

Chest radiology

The heuristic augmentations considered for this modality included: random horizontal flipping; random cropping to 202 × 202 resolution; resizing to 224 × 224 with bilinear interpolation and anti-alias; random rotation by 15 degrees; shifting luminance by a value sampled uniformly from the (−0.1 to 0.1) range; and shifting contrast using a value uniformly sampled from the (0.8 to 1.2) range (that is, pixel values were multiplied by the shift value and clipped to remain within the (0 to 1) range).

Dermatology

For this modality, we used the following heuristic augmentations: random horizontal and vertical flipping; adjusting image brightness by a random factor (maximum \(\delta =0.1\)); adjusting image saturation by a random factor (within the (0.8 to 1.2) range); adjusting the hue by a random factor (maximum \(\delta =0.02\)); adjusting image contrast by a random factor (within the (0.8 to 1.2) range); random rotation within the (−150 to 150) range; and random Gaussian blurring with standard deviation uniformly sampled from the following values: .

Baselines

In all contexts, we considered the strongest heuristic augmentations as a baseline. These augmentations (heuristic or learned) can be combined with any alternative learning algorithm that aims to improve model generalization. For the sake of our experiments, we used empirical risk minimization57 because there is no single method that consistently outperforms it under distribution shifts55. Even though our experiments and analysis focus on DDPMs for generation, any conditional generative model that produces high-quality and diverse samples can be used. In general, the risk, that is, how well the algorithm will fit the data, cannot be computed on the true data distribution \(P(x,y)\) because it is unknown to the learning algorithm. However, we could compute an approximation, called empirical risk, by averaging the loss function on the training set samples.

Histopathology

For this modality, all models used the same ResNet-152 backbone. We compared (1) a baseline using no augmentation (Baseline) and (2) one using standard color augmentations (Color augm.) as applied in standard ImageNet training. This augmentation included brightness, contrast, saturation and hue jitter. Hue and saturation were sufficient augmentations to achieve the highest-quality results by Tellez et al.56; hence, we did not evaluate other heuristic augmentations. Our baseline did not use pretraining because it previously did not yield any benefits on this particular dataset as reported by Wiles et al.55. We also compared the models to those applying heuristic color augmentations on top of the synthetic data.

Chest radiology

All models used the same BiT-ResNet-152 backbone58. We considered baselines that use (1) different pretraining, (2) different heuristic augmentations and combinations thereof, and (3) focal loss. We investigated using JFT59 and ImageNet-21K60 for pretraining to explore how much different pretraining datasets impacted the final results. We investigated using RandAugment61, ImageNet Augmentations as described above, and RandAugment + ImageNet Augmentations to determine how much performance we could gain by using heuristic augmentations. Finally, we considered using focal loss62, which was developed to improve performance on imbalanced datasets.

Dermatology

All models used the same BiT-ResNet backbone58. We considered baselines that (1) used different pretraining, (2) used different heuristic augmentations, (3) resampled the dataset and (4) used the focal loss. We investigated using JFT59 and ImageNet-21K60 for pretraining. We investigated using RandAugment61, ImageNet Augmentations and RandAugment + ImageNet Augmentations. We then resampled the dataset so that the distribution over attributes was even (we upsampled samples from low-data regions so that they occurred more frequently in the dataset). Finally, we considered using focal loss62, which was developed to improve performance on imbalanced datasets.
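For reference, the binary focal loss used as a baseline can be written in a few lines. This is a generic sketch of the standard formulation, \(-\alpha_t (1-p_t)^{\gamma}\log p_t\); the default values of `gamma` and `alpha` below are common choices from the focal loss literature and are assumptions, not the settings used in this study.

```python
import math

def binary_focal_loss(p, y, gamma=2.0, alpha=0.25):
    """Focal loss for one example.

    p: predicted probability of the positive class, strictly between 0 and 1.
    y: ground-truth label, 0 or 1.
    gamma: focusing parameter; larger values down-weight easy examples.
    alpha: weight of the positive class (1 - alpha for the negative class).
    """
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# A confidently correct prediction contributes far less than a wrong one.
print(binary_focal_loss(0.9, 1), binary_focal_loss(0.1, 1))
```

With `gamma=0` the expression reduces to class-weighted binary cross-entropy, which makes the focusing term easy to ablate.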

Evaluation details

Experimental setup

To account for potential variations with respect to model initialization, we evaluated all versions of our model and baselines with five different initialization seeds and report the average and standard deviation across those runs for all metrics. We ran all experiments on tensor processing units.

Fairness metrics

Different definitions of fairness have been proposed in the literature, which are often at odds with each other63. In this section we discuss our choice of fairness metrics for each modality. In histopathology, we used the gap between the best and worst performance among the in-distribution hospitals. For radiology, we considered AUC parity, namely the parity of the area under the receiver operating characteristic (ROC) curve for different demographic subgroups identified by the sensitive attribute \(A\), which can be seen as the analog of equality of accuracy64. Therefore, for this modality, we report the AUC gap between males and females in Fig. 3a. We considered this most relevant given that the ratio of positive to negative samples was very imbalanced across all conditions.
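The AUC gap between two subgroups can be computed as in the sketch below, which uses the rank interpretation of AUC (the probability that a randomly chosen positive is scored above a randomly chosen negative). The function names and the toy scores are assumptions, and the quadratic-time AUC is chosen for clarity only.

```python
def auc(scores, labels):
    """AUC as the fraction of positive-negative pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def auc_gap(scores, labels, groups, group_a, group_b):
    """Absolute AUC difference between two demographic subgroups."""
    def subgroup_auc(g):
        idx = [i for i, v in enumerate(groups) if v == g]
        return auc([scores[i] for i in idx], [labels[i] for i in idx])
    return abs(subgroup_auc(group_a) - subgroup_auc(group_b))

# Toy predictions for one condition (made-up values).
scores = [0.9, 0.8, 0.3, 0.2, 0.6, 0.4, 0.7, 0.1]
labels = [1, 1, 0, 0, 1, 0, 0, 1]
groups = ["M", "M", "M", "M", "F", "F", "F", "F"]
gap = auc_gap(scores, labels, groups, "M", "F")
print(gap)
```

A gap of zero corresponds to AUC parity; in the multi-label radiology setting the same computation would be repeated per condition.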

In dermatology, we report the gap between the best and worst subgroup performance, where subgroups are defined based on the sensitive attribute axis under consideration in Fig. 4. We also report the central best estimate for the a posteriori estimate of performance (that is, top-3) difference between a group and its outgroup. The steps to obtain the values plotted in Supplementary Fig. 7 are the following: (1) we defined a group (and its matching outgroup) as the set of instances characterized by a particular value of a sensitive attribute \(A=\alpha\), that is, group = \(\{i : A_i=\alpha\}\) and outgroup = \(\{i : A_i\ne \alpha\}\), where \(A\) denotes the sensitive attribute under consideration; (2) we assumed a uniform Beta distribution Beta(1,1) as a prior for the performance difference between \(\mathrm{top3}_{\mathrm{group}}\) and \(\mathrm{top3}_{\mathrm{outgroup}}\) and fitted this to the observed data; (3) we sampled n = 100,000 samples from the estimated posterior differences between \(\widehat{\mathrm{top3}}_{\mathrm{group}}\) and \(\widehat{\mathrm{top3}}_{\mathrm{outgroup}}\) and report the spread, that is, the standard deviation of the maximum a posteriori estimates, which can be interpreted as the central best estimate for fairness.
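One way to realize steps (2) and (3) is sketched below. It is an interpretation under stated assumptions rather than the authors' procedure: the sketch places an independent Beta(1,1) prior on the top-3 accuracy of the group and of the outgroup, updates each with the observed counts of correct predictions, and samples the difference of the two posteriors. The function name and the toy counts are hypothetical.

```python
import random
import statistics

def posterior_difference(k_group, n_group, k_out, n_out,
                         n_samples=100_000, seed=0):
    """Sample the posterior of (group accuracy - outgroup accuracy).

    k_*: number of instances with a correct top-3 prediction.
    n_*: total number of instances.
    With a Beta(1, 1) prior, the posterior is Beta(1 + k, 1 + n - k).
    """
    rng = random.Random(seed)
    diffs = []
    for _ in range(n_samples):
        p_group = rng.betavariate(1 + k_group, 1 + n_group - k_group)
        p_out = rng.betavariate(1 + k_out, 1 + n_out - k_out)
        diffs.append(p_group - p_out)
    return statistics.mean(diffs), statistics.stdev(diffs)

# Toy counts: 70/100 correct in the group, 800/1000 in the outgroup.
mean_diff, spread = posterior_difference(70, 100, 800, 1000)
print(mean_diff, spread)
```

The spread is wider for small groups, which reflects the greater uncertainty about their performance.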

Setup for distribution shift estimation

We computed domain mismatches considering the space where decisions are performed, that is, the output of the penultimate layer of each model. Thus, we projected each data point from the input space to the representation space of the penultimate layer and then computed the maximum mean discrepancy (MMD) between two distributions (that is, datasets). Given two distributions \(U\) and \(Z\), their respective samples
