PrescDRL: deep reinforcement learning for herbal prescription planning in treatment of chronic diseases

Clinical sequential data of diabetes

In this section, we present a benchmark dataset of clinical sequential diagnosis and treatment records for diabetes, which serves as an example for training the DDTS optimization method. (Ethics approval for this study was obtained from the Ethics Committee of the Institute of Clinical Basic Medicine of Traditional Chinese Medicine (No. 2016NO.11-01).) In this dataset, the patients' symptom observations are used as the states and the herbal prescriptions given by doctors as the actions of the reinforcement learning (RL) model.

Fig. 2

Distribution of sequential diagnosis and treatment data. A Distribution of the number of visits per patient. B Distribution of the number of symptoms per medical record. C Distribution of the number of herbs per prescription. D P-value distributions for different numbers of herbal prescription clusters. E P-value distribution for 30 prescription clusters. F P-value distribution for 100 prescription clusters

To construct a standard dataset for sequential decision-making in TCM, we first extracted 10,666 medical records of 2,895 diabetic patients from Guang'anmen Hospital. As shown in Fig. 2A, 49.6% of the patients had only one medical record, and each patient had an average of 3.68 medical records. From each medical record, we extracted the patient's symptoms and the herbal prescription (consisting of multiple herbs) used for treatment. Apart from 334 medical records with more than 40 symptoms, the number of symptoms per medical record was approximately normally distributed (Fig. 2B), with an average of 10.386 symptoms per record. Similarly, the number of herbs per prescription was approximately normally distributed (Fig. 2C), with an average of 10.059 herbs per prescription. We then selected the 1,459 patients with more than one visit, yielding 5,638 medical records, which were arranged into diagnosis and treatment sequences ordered by clinic time.
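The sequence construction described above (group records by patient, keep patients with more than one visit, order visits by clinic time) can be sketched as follows. The record layout `(patient_id, visit_date, symptoms, herbs)` and the toy records are assumptions for illustration; the text does not describe how the records are stored.

```python
from collections import defaultdict
from datetime import date

# Hypothetical record layout: (patient_id, visit_date, symptoms, herbs).
# The real field names and storage format of the dataset are not described in the text.
Record = tuple[str, date, frozenset[str], frozenset[str]]


def build_sequences(records: list[Record]) -> dict[str, list[Record]]:
    """Group records by patient, keep only patients with more than one visit,
    and order each patient's visits by clinic date."""
    by_patient: dict[str, list[Record]] = defaultdict(list)
    for rec in records:
        by_patient[rec[0]].append(rec)
    return {
        pid: sorted(visits, key=lambda r: r[1])
        for pid, visits in by_patient.items()
        if len(visits) > 1
    }


# Toy records (invented for illustration only).
records = [
    ("p1", date(2015, 3, 2), frozenset({"thirst", "fatigue"}), frozenset({"huangqi", "shanyao"})),
    ("p1", date(2015, 1, 5), frozenset({"thirst", "fatigue", "polyuria"}), frozenset({"huangqi"})),
    ("p2", date(2015, 2, 1), frozenset({"fatigue"}), frozenset({"dangshen"})),
]
sequences = build_sequences(records)
print(list(sequences))                        # ['p1']
print([r[1].month for r in sequences["p1"]])  # [1, 3]
```

Patient `p2` has a single visit and is dropped, mirroring the screening step that keeps only patients with more than one visit.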

A deep reinforcement learning framework to optimize herbal prescription treatment planning

The optimization of DDTS is essentially a Markov Decision Process (MDP) [46]. To tackle this problem, we propose an optimization model for herbal prescription treatment planning based on two high-performance deep RL models, namely DQN [47] and DRQN [48]. DQN combines Q-learning with a deep neural network (a convolutional network in its original form) that approximates the action-value function. DRQN, in contrast, first extracts features using two fully connected layers, followed by an LSTM layer, and then predicts the action values using a final fully connected layer (Fig. 1C).

In the RL framework, the agent acts as a virtual intelligent doctor, the patient's state serves as the environment, and prescribing herbal medication to the patient is the agent's action. The key components of the RL models are defined as follows: 1) the state space is denoted by S, where a state \(s \in S\) represents the observation of a patient's symptoms; 2) the action space is denoted by A, where an action \(a \in A\) represents the herbal medication prescribed to the patient; 3) the reward function is denoted by R(s, a), which returns a reward after the agent takes action \(a\) in state \(s\); 4) the virtual environment is denoted by E, an offline virtual environment built from sequential clinical data; 5) the state transition is denoted by T, where each transition is obtained using a prediction strategy.

The state observations of patients

In the DDTS optimization problem, the patient's state is a key component of the reinforcement learning model. In TCM clinics, doctors obtain descriptions of a patient's symptoms through the four diagnostic methods (inspection, listening and smelling, inquiry, and pulse-taking), summarize the syndrome type following the principle of syndrome differentiation and treatment (辨证论治), and prescribe appropriate treatments. However, since the true state of the patient is not directly observable, even experienced doctors cannot fully determine a patient's specific condition. Therefore, the symptoms observed by doctors are used to approximate the patient's state.

In the diabetes dataset, the distribution of patient symptoms (Fig. 2B) shows that the number of symptoms varies among patients (about 10 symptoms per medical record on average). The core symptoms typically differ from disease to disease, and different symptoms may differ in importance. However, a precise symptom grading for diabetes is difficult to obtain, so all symptoms are treated as having equal weight. As a result, a patient's state is represented as a binary symptom vector in which symptoms present in the patient are marked 1 and absent symptoms are marked 0.
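A minimal sketch of this binary state encoding, assuming a fixed, ordered symptom vocabulary; the five symptom names below are hypothetical, since the real vocabulary comes from the dataset.

```python
# Hypothetical symptom vocabulary; the real vocabulary is derived from the medical records.
SYMPTOMS = ["thirst", "fatigue", "polyuria", "numbness", "insomnia"]
INDEX = {name: i for i, name in enumerate(SYMPTOMS)}


def encode_state(observed: set[str]) -> list[int]:
    """Binary state vector: 1 if the symptom is present, 0 otherwise (all symptoms weighted equally)."""
    vec = [0] * len(SYMPTOMS)
    for name in observed:
        vec[INDEX[name]] = 1  # raises KeyError for a symptom outside the vocabulary
    return vec


state = encode_state({"fatigue", "insomnia"})
print(state)       # [0, 1, 0, 0, 1]
print(sum(state))  # 2
```

With equal weights, the sum of this vector is the patient's symptom count, which is the symptom score used in the reward design.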

The action space of the virtual doctor

In TCM diagnosis and treatment, doctors prescribe herbal prescriptions based on the patient's symptoms. Across all medical records, we obtained 9,695 distinct herbal prescriptions. Treating each of these prescriptions as a separate action would make the RL algorithms much harder to train and slower to converge because of the very large action space. It is therefore necessary to reduce the number of actions by mapping the prescriptions onto a smaller discrete action space, which reduces model complexity and speeds up convergence.

To reduce the number of prescriptions, we applied the K-means clustering algorithm [46] to cluster them, using the herb composition of each prescription as its feature. The number of herbal prescription clusters (HPC) is treated as a hyperparameter, and we ran a tuning experiment to choose it, testing values from 30 to 150 in steps of 10. A good clustering should produce clusters whose herb compositions differ significantly. To check this, we used the chi-square test [49] to assess the statistical difference between every pair of clusters based on the herbs prescribed in each cluster. The resulting HPCs are used as the actions of the RL models.
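A minimal sketch of the clustering and the pairwise test, assuming each prescription is encoded as a multi-hot herb vector and that the chi-square test is run on a herbs × 2 contingency table of herb occurrence counts in the two clusters (the exact table used is not specified in the text). Plain Lloyd's K-means is written out to keep the example dependency-free; in practice a library implementation such as scikit-learn's `KMeans` would typically be used, and a p-value can be obtained from the statistic and degrees of freedom with `scipy.stats.chi2.sf(stat, dof)`. The herbs and prescriptions are toy data.

```python
import random


def kmeans(points: list[list[float]], k: int, iters: int = 50, seed: int = 0) -> list[int]:
    """Plain Lloyd's K-means (squared Euclidean distance); returns one cluster label per point."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each prescription goes to its nearest centroid.
        labels = [
            min(range(k), key=lambda c: sum((x - y) ** 2 for x, y in zip(p, centroids[c])))
            for p in points
        ]
        # Update step: centroid = mean herb vector of its members (kept if the cluster is empty).
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels


def chi_square_two_clusters(counts_a: list[int], counts_b: list[int]) -> tuple[float, int]:
    """Pearson chi-square statistic and degrees of freedom for a (herbs x 2) table of
    herb occurrence counts in two clusters; herbs absent from both clusters are dropped."""
    rows = [(a, b) for a, b in zip(counts_a, counts_b) if a + b > 0]
    total_a = sum(a for a, _ in rows)
    total_b = sum(b for _, b in rows)
    grand = total_a + total_b
    stat = 0.0
    for a, b in rows:
        for observed, col_total in ((a, total_a), (b, total_b)):
            expected = (a + b) * col_total / grand
            stat += (observed - expected) ** 2 / expected
    return stat, len(rows) - 1


# Toy multi-hot prescriptions over four herbs (invented data).
HERBS = ["huangqi", "shanyao", "dangshen", "maidong"]
prescriptions = [
    [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0],
    [0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1],
]
labels = kmeans(prescriptions, k=2)


def herb_counts(cluster: int) -> list[int]:
    """Number of prescriptions in `cluster` that contain each herb."""
    members = [p for p, lab in zip(prescriptions, labels) if lab == cluster]
    return [sum(col) for col in zip(*members)]


stat, dof = chi_square_two_clusters(herb_counts(labels[0]), herb_counts(labels[3]))
print(labels[0] != labels[3], stat, dof)  # True 10.0 3
```

In the tuning experiment described above, this pairwise test would be repeated for every pair of clusters at each candidate cluster count (30, 40, ..., 150).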

Design of reward function

The aim of RL-based DDTS optimization is to use a vast number of medical records to predict the optimal sequence of herbal prescriptions for patients. The objective is not only to maximize the treatment effect but also to ensure that the predicted prescriptions are reasonable. This means that the efficacy of the predicted prescriptions should be within a reasonable range, and they should not have side effects on patients or contradict drug indications.

Because the diabetes dataset contains no curative-effect evaluation data, we used the change in symptom score between two consecutive visits (before and after treatment) as the immediate reward for the current action. The symptom score measures the severity of the patient's disease state and would ideally be a weighted sum of all the patient's symptoms, with each weight reflecting a symptom's importance. Because these weights are difficult to define, we set the weight of every symptom to 1, so the symptom score is simply the number of symptoms the patient has; for example, a patient with 5 symptoms has a symptom score of 5. In addition, we calculated the Jaccard coefficient to measure the similarity between the predicted action and the actual prescription given by the doctor, assigning higher rewards to actions more similar to the doctor's prescription. The reward function is therefore formulated as follows:

$$\mathcal{R}(s, a)=\gamma \sum_{i} \alpha_{i}\left(s_{i}-s_{i}^{\prime}\right)+\beta\, \mathrm{Jac}\left(a, a^{*}\right) \tag{1}$$

$$\mathrm{Jac}\left(a, a^{*}\right)=\frac{\left|a \cap a^{*}\right|}{\left|a \cup a^{*}\right|} \tag{2}$$

where \(\alpha_i\) is the weight of the patient's \(i\)-th symptom, \(s_i\) is the \(i\)-th symptom (1 if present, 0 otherwise) at the current visit, and \(s_i^{\prime}\) is the \(i\)-th symptom at the next visit. \(\gamma\) is the weight of the patient's therapeutic effect, \(\beta\) is the weight of risk, \(a\) is the prescription given by the doctor, and \(a^{*}\) is the prescription predicted by the model.
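A minimal sketch of Eqs. (1) and (2), assuming binary symptom vectors and treating both prescriptions as sets of herb names. The RL action itself is an HPC, and the text does not say how a cluster is turned into a herb set before computing the Jaccard term, so representing the predicted action as a herb set is an assumption. The values of \(\gamma\) and \(\beta\) are not reported; the defaults of 1.0 are placeholders.

```python
from typing import Optional


def jaccard(x: set[str], y: set[str]) -> float:
    """Eq. (2): |x ∩ y| / |x ∪ y|; returns 0.0 when both sets are empty (an assumption)."""
    union = x | y
    return len(x & y) / len(union) if union else 0.0


def reward(s_now: list[int], s_next: list[int], a_doctor: set[str], a_pred: set[str],
           gamma: float = 1.0, beta: float = 1.0,
           alpha: Optional[list[float]] = None) -> float:
    """Eq. (1): gamma * sum_i alpha_i * (s_i - s'_i) + beta * Jac(a, a*).
    gamma = beta = 1.0 are placeholder values; the text does not report them."""
    if alpha is None:
        alpha = [1.0] * len(s_now)  # every symptom weighted 1, as in the text
    effect = sum(w * (x - y) for w, x, y in zip(alpha, s_now, s_next))
    return gamma * effect + beta * jaccard(a_doctor, a_pred)


s_now = [1, 1, 1, 0, 1]   # symptom score 4 at the current visit
s_next = [1, 0, 0, 0, 1]  # symptom score 2 at the next visit
doctor = {"huangqi", "shanyao", "maidong"}
predicted = {"huangqi", "shanyao", "dangshen"}
print(reward(s_now, s_next, doctor, predicted))  # 2.5
```

Here the symptom score drops from 4 to 2 (effect term 2.0) and the two prescriptions share 2 of 4 distinct herbs (Jaccard 0.5), giving 2.5.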

Virtual environment construction

Because the proposed DDTS optimization model cannot be trained within the real diagnosis and treatment process, we developed an offline virtual environment based on the available patient medical records. From the symptom observations and prescriptions of each patient at the current and the next visit, we constructed a tetrad \((s_1,a,r,s_2)\), where \(s_1\) is the patient's current symptom observation, \(a\) is the action taken based on \(s_1\), \(r\) is the reward received after performing \(a\), and \(s_2\) is the patient's new symptom observation after \(a\). We obtained 4,179 tetrads from the medical records, which served as the virtual environment for training the deep RL models.
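Assuming each tetrad comes from a pair of consecutive visits of the same patient, a patient with \(n\) visits contributes \(n-1\) tetrads, which is consistent with the figures reported above: 5,638 records from 1,459 patients give 5,638 − 1,459 = 4,179 tetrads. The sketch below builds tetrads for one patient; the HPC labels are hypothetical, and the reward function uses only the symptom-score term of Eq. (1) to keep the example short.

```python
from typing import Callable, NamedTuple


class Tetrad(NamedTuple):
    s1: frozenset  # symptoms observed at the current visit
    a: int         # action: HPC label of the prescription given at that visit
    r: float       # immediate reward
    s2: frozenset  # symptoms observed at the next visit


def build_tetrads(visits: list[tuple[frozenset, int]],
                  reward_fn: Callable[[frozenset, frozenset], float]) -> list[Tetrad]:
    """visits: one patient's (symptoms, hpc_label) pairs, already ordered by clinic time.
    Each pair of consecutive visits yields one tetrad (s1, a, r, s2)."""
    return [
        Tetrad(s1, a, reward_fn(s1, s2), s2)
        for (s1, a), (s2, _) in zip(visits, visits[1:])
    ]


def score_drop(s1: frozenset, s2: frozenset) -> float:
    """Symptom-score reduction only (all weights 1); the Jaccard term is omitted here."""
    return float(len(s1) - len(s2))


visits = [
    (frozenset({"thirst", "fatigue", "polyuria"}), 4),
    (frozenset({"thirst", "fatigue"}), 4),
    (frozenset({"fatigue"}), 11),
]
tetrads = build_tetrads(visits, score_drop)
print(len(tetrads), [t.r for t in tetrads])  # 2 [1.0, 1.0]
```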

State transition prediction and termination

A main challenge in optimizing DDTS with a deep RL model is obtaining the next symptom observation after an action is taken in the current state, because the tetrads constructed from the records do not cover every state-action pair the agent encounters during training. To address this, we used a state transition network over states and actions (Fig. 1D) to predict the patient's symptoms after treatment. Specifically, the prediction strategy first selects all training tetrads \((s_1,a,r,s_2)\) whose action equals the predicted action, then computes the Jaccard similarity between each tetrad's symptom observation \(s_1\) and the current symptoms, and picks the tetrad whose \(s_1\) is most similar to the current symptoms. The \(s_2\) of that tetrad is taken as the patient's state after treatment.
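The prediction strategy amounts to a nearest-neighbour lookup restricted to tetrads with the same action. A minimal sketch, assuming states are stored as sets of symptom names; the tie-breaking rule (first tetrad wins) and the behaviour when no tetrad has the requested action (return `None`) are assumptions the text does not specify.

```python
from typing import Optional, Sequence

# A tetrad is (s1, a, r, s2): current symptoms, HPC action, reward, next symptoms.
Tetrad = tuple[frozenset, int, float, frozenset]


def jaccard(x: frozenset, y: frozenset) -> float:
    union = x | y
    return len(x & y) / len(union) if union else 0.0


def predict_next_state(current: frozenset, action: int,
                       tetrads: Sequence[Tetrad]) -> Optional[frozenset]:
    """Return s2 of the training tetrad with the same action whose s1 is most
    Jaccard-similar to `current` (first one wins on ties). Returns None when no
    tetrad uses this action; the text does not say how that case is handled."""
    candidates = [t for t in tetrads if t[1] == action]
    if not candidates:
        return None
    best = max(candidates, key=lambda t: jaccard(t[0], current))
    return best[3]


training = [
    (frozenset({"thirst", "fatigue"}), 3, 1.0, frozenset({"fatigue"})),
    (frozenset({"insomnia", "numbness"}), 3, 0.0, frozenset({"insomnia", "numbness"})),
    (frozenset({"thirst", "fatigue"}), 7, 2.0, frozenset()),
]
now = frozenset({"thirst", "fatigue", "polyuria"})
print(predict_next_state(now, 3, training))  # frozenset({'fatigue'})
print(predict_next_state(now, 5, training))  # None
```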

The distribution of symptoms (Fig. 2B) indicates that 94.7% of patients have between 1 and 20 symptoms. Based on TCM expert recommendations and this distribution, we define the first sequence termination condition (STC) as a symptom score \(\le 3\). According to the evaluation criteria for diabetes treatment effect, a 30% reduction in symptom score is considered effective and a 70% reduction markedly effective; accordingly, the second STC is defined as a 60% reduction in the patient's symptom score. The distribution of consultations (Fig. 2A) shows that 93% of patients have between 1 and 10 consultations (3.7 on average). The third STC is reached when the number of iterations exceeds 15.
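The three termination conditions can be combined into a single check, sketched below. Measuring the 60% reduction against the first-visit symptom score and treating it as "at least 60%" are assumptions; the text gives only the thresholds.

```python
def should_stop(initial_score: int, current_score: int, iteration: int,
                min_score: int = 3, min_reduction: float = 0.6,
                max_iterations: int = 15) -> bool:
    """True if any sequence termination condition (STC) holds.
    Measuring the 60% reduction against the first-visit score is an assumption."""
    # STC 1: symptom score at or below 3.
    if current_score <= min_score:
        return True
    # STC 2: symptom score reduced by at least 60% relative to the initial score.
    if initial_score > 0 and (initial_score - current_score) / initial_score >= min_reduction:
        return True
    # STC 3: more than 15 iterations.
    return iteration > max_iterations


print(should_stop(12, 3, 2))    # True  (STC 1)
print(should_stop(20, 7, 2))    # True  (STC 2: 65% reduction)
print(should_stop(20, 10, 5))   # False (50% reduction, 5 iterations)
print(should_stop(20, 10, 16))  # True  (STC 3)
```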

A multi-layer neural network for herbal prescription recommendation

In TCM clinical practice, the ultimate goal of intelligent diagnosis and treatment decision-making is to recommend effective herbal prescriptions to patients. The trained deep RL models give a sequence of HPCs for each patient. To turn these into concrete prescriptions, we model prescription recommendation as a multi-label prediction task and construct a multi-layer neural network (i.e., a multi-layer perceptron) that takes the patient's symptoms and the HPC predicted by the RL models as input features and outputs the predicted herbal prescription (Fig. 1B).
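The data flow of this multi-label network can be sketched in plain Python. Several details are assumptions not given in the text: the HPC is fed in as a one-hot vector concatenated with the symptom vector, there is one ReLU hidden layer of 16 units, each herb gets an independent sigmoid output, and herbs with probability ≥ 0.5 form the prescription. The weights below are random and untrained, so the predicted herbs are arbitrary.

```python
import math
import random

HERBS = ["huangqi", "shanyao", "dangshen", "maidong"]  # toy herb vocabulary
N_SYMPTOMS, N_HPC, N_HIDDEN = 5, 30, 16               # illustrative sizes


def build_input(symptom_vec: list[int], hpc: int) -> list[int]:
    """Concatenate the binary symptom vector with a one-hot encoding of the predicted HPC."""
    one_hot = [0] * N_HPC
    one_hot[hpc] = 1
    return symptom_vec + one_hot


def mlp_forward(x, w1, b1, w2, b2) -> list[float]:
    """One ReLU hidden layer; one sigmoid output per herb (independent probabilities)."""
    h = [max(0.0, sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(w1, b1)]
    z = [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(w2, b2)]
    return [1.0 / (1.0 + math.exp(-v)) for v in z]


# Untrained random weights, only to show the data flow; in practice they would be
# learned, e.g. by minimizing binary cross-entropy against the doctors' prescriptions.
rng = random.Random(0)
w1 = [[rng.gauss(0.0, 0.1) for _ in range(N_SYMPTOMS + N_HPC)] for _ in range(N_HIDDEN)]
b1 = [0.0] * N_HIDDEN
w2 = [[rng.gauss(0.0, 0.1) for _ in range(N_HIDDEN)] for _ in range(len(HERBS))]
b2 = [0.0] * len(HERBS)

x = build_input([0, 1, 0, 0, 1], hpc=7)
probs = mlp_forward(x, w1, b1, w2, b2)
prescription = [herb for herb, p in zip(HERBS, probs) if p >= 0.5]  # 0.5 threshold assumed
```

Using independent sigmoid outputs rather than a softmax lets any number of herbs be selected at once, which is what distinguishes multi-label prediction from single-label classification.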

Experimental design

Parameter setting

In the DDTS optimization experiment, we used a total of 1,495 patient samples, of which 80% (1,203 samples) were used for training and the remaining 20% (292 samples) for testing. The prescription recommendation experiment likewise uses 80% of the samples for training and 20% for testing.

In the proposed PrescDRL, the DQN network consists of three fully connected (FC) layers with 400, 300, and 30 neurons, respectively. In the DRQN network, the first two layers are FC layers with 300 and 512 neurons, the middle layer is an LSTM layer with 512 neurons, and the final layer is an FC layer with 30 neurons. The output layers of both DQN and DRQN have 30 neurons because the 30 tuned HPCs correspond to the 30 actions of the RL models. Both models are trained with a learning rate of 0.01, a reward discount factor of 0.9, a random exploration probability of 0.1, and a batch size of 32, and the parameters are copied to the Q-target network every 100 training batches.
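The hyperparameters above plug into the standard DQN training loop. The sketch shows only the framework-independent pieces: ε-greedy action selection, minibatch sampling, the temporal-difference target, and the target-network sync schedule. The networks and optimizer are omitted, and uniform sampling from the stored tetrads is an assumption. Note that the discount factor here is a different quantity from the \(\gamma\) weight in Eq. (1).

```python
import random

GAMMA = 0.9        # reward discount factor (not the gamma of Eq. (1))
EPSILON = 0.1      # random exploration probability
BATCH_SIZE = 32
SYNC_EVERY = 100   # copy online parameters to the Q-target network every 100 batches
N_ACTIONS = 30     # one action per HPC


def epsilon_greedy(q_values: list[float], rng: random.Random) -> int:
    """Random HPC with probability EPSILON, otherwise the HPC with the highest Q-value."""
    if rng.random() < EPSILON:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=q_values.__getitem__)


def sample_batch(tetrads: list, rng: random.Random) -> list:
    """Uniform minibatch of transitions (experience-replay-style sampling is assumed)."""
    return rng.sample(tetrads, min(BATCH_SIZE, len(tetrads)))


def td_targets(rewards: list[float], next_q_target: list[list[float]],
               done: list[bool]) -> list[float]:
    """Standard DQN target: y = r + GAMMA * max_a' Q_target(s', a'), or y = r at termination."""
    return [r if d else r + GAMMA * max(q) for r, q, d in zip(rewards, next_q_target, done)]


def should_sync(batch_index: int) -> bool:
    return batch_index > 0 and batch_index % SYNC_EVERY == 0


rng = random.Random(0)
q = [0.0] * N_ACTIONS
q[12] = 1.0
greedy_share = sum(epsilon_greedy(q, rng) == 12 for _ in range(1000)) / 1000
print(td_targets([1.0, 2.0], [[0.0, 5.0], [3.0, 1.0]], [False, True]))  # [5.5, 2.0]
```

With ε = 0.1 and 30 actions, the greedy HPC is chosen about 90% of the time (plus the 1-in-30 chance that a random pick lands on it).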

Fig. 3

Experimental results of PrescDRL. A Performance comparison of single-step return. B Performance comparison of single-step cumulative return. C Performance comparison of multi-step cumulative return. D Comparison of iterations, i.e., the sequence length of the diagnosis and treatment plan. E Reward distribution of \(PrescDRL_\) with different HPCs. F Reward distribution of \(PrescDRL_\) with different HPCs. G Precision comparison of prescription recommendation. H Recall comparison of prescription recommendation. I F1-score comparison of prescription recommendation. J IoU comparison of prescription recommendation. K Distribution of the number of predicted herbs given by PrescDRL when considering both symptoms and scheme. L Distribution of the number of predicted herbs given by PrescDRL when considering symptoms only

Evaluation metrics

In clinical practice, evaluating the effectiveness of TCM treatment for chronic diseases such as diabetes is challenging because treatment lasts a long time and clinical mortality, the usual evaluation metric in Western medicine, is unsuitable [50]. In this study, we evaluated the DDTS optimization results by the improvement in symptom score, expressed as the return of the RL models. To assess the optimization models, we used three metrics: single-step return (SSR), single-step cumulative return (SCR), and multi-step cumulative return (MCR). For SSR, the optimization models are applied to the symptom observations at each patient's first visit, and the difference between the symptom observations before and after the model assigns an HPC is defined as the SSR. SCR instead considers all visits of each patient and averages the returns. MCR is the most comprehensive metric: the models predict an HPC from each patient's first visit and then use the state transition function to generate follow-up symptom observations until a stopping condition is reached.
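The MCR rollout can be sketched as a loop that alternates policy, state transition, and reward until a termination condition fires. The policy, transition, reward, and stop functions passed in below are toy stand-ins (the real ones are the trained Q-network, the tetrad-based transition, and the STCs); ending the rollout when no transition is available is an assumption. The return after the first step alone corresponds to the SSR.

```python
from typing import Callable, Optional

State = frozenset  # set of observed symptoms


def multi_step_return(initial: State,
                      policy: Callable[[State], int],
                      transition: Callable[[State, int], Optional[State]],
                      reward_fn: Callable[[State, State], float],
                      stop: Callable[[int, int, int], bool]) -> tuple[float, int]:
    """Roll out from the first-visit state: choose an HPC, predict the next state,
    add the reward, and repeat until a stopping condition holds.
    Returns (cumulative return, number of iterations)."""
    state, total, steps = initial, 0.0, 0
    while True:
        steps += 1
        next_state = transition(state, policy(state))
        if next_state is None:  # no matching tetrad: ending here is an assumption
            break
        total += reward_fn(state, next_state)
        state = next_state
        if stop(len(initial), len(state), steps):
            break
    return total, steps


# Toy components (all hypothetical): a fixed policy, a transition that removes one
# symptom per step, a score-drop reward, and the three STCs from the text.
policy = lambda s: 0
transition = lambda s, a: s - {min(s)} if s else None
reward_fn = lambda s1, s2: float(len(s1) - len(s2))
stop = lambda init, cur, step: cur <= 3 or (init - cur) / init >= 0.6 or step > 15

print(multi_step_return(frozenset("abcdef"), policy, transition, reward_fn, stop))  # (3.0, 3)
```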

In addition, prescription prediction is treated as a multi-label classification problem and evaluated with precision, recall, F1 score, and intersection over union (IoU).
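For a single patient, these metrics compare the predicted herb set with the doctor's prescription. A minimal per-sample sketch follows; averaging the per-sample values over the test set is an assumption, since the text does not state how the metrics are aggregated.

```python
def set_metrics(pred: set[str], true: set[str]) -> dict[str, float]:
    """Per-sample precision, recall, F1, and IoU between predicted and prescribed herbs."""
    tp = len(pred & true)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(true) if true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    union = pred | true
    iou = tp / len(union) if union else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "iou": iou}


m = set_metrics({"huangqi", "shanyao", "dangshen"},
                {"huangqi", "shanyao", "maidong", "gegen"})
print(m)  # precision 2/3, recall 1/2, F1 4/7, IoU 2/5
```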
