Review — G-BERT: Pre-training of Graph Augmented Transformers for Medication Recommendation

G-BERT, Using Ontology Information and BERT

Sik-Ho Tsang
Oct 26, 2023

Pre-training of Graph Augmented Transformers for Medication Recommendation
G-BERT, by IQVIA, IBM Research AI, Georgia Institute of Technology
2019 IJCAI, Over 190 Citations (Sik-Ho Tsang @ Medium)

Medical LLM
2020 [BioBERT] [BEHRT] 2021 [MedGPT] 2023 [Med-PaLM]

  • G-BERT is proposed, combining a GNN and BERT for medical code representation and medication recommendation.
  • Specifically, the GNN-enhanced code representations are fed into BERT for pre-training, and the model is then fine-tuned for downstream predictive tasks on longitudinal EHRs of patients with multiple visits.

Outline

  1. Problem Formulation
  2. G-BERT
  3. Results

1. Problem Formulation

Left: ICD-9 Ontology. Right: Notations.

1.1. Longitudinal EHR Data

  • In longitudinal Electronic Health Record (EHR) data, each patient can be represented as a sequence of multivariate observations:

$$X^{(n)} = \{X_1^{(n)}, X_2^{(n)}, \ldots, X_{T^{(n)}}^{(n)}\}$$

  • where $n = 1, \ldots, N$, N is the total number of patients, and $T^{(n)}$ is the number of visits of the n-th patient. Two main types of medical codes are chosen to represent each visit: $X_t = C_t^d \cup C_t^m$ is the union of the corresponding diagnosis codes $C_t^d \subset C^d$ and medication codes $C_t^m \subset C^m$.
  • For simplicity, $C_t^*$ denotes the unified definition for the different types of medical codes, and the superscript (n) is dropped for a single patient whenever it is unambiguous.
  • $C^*$ denotes the medical code set, $|C^*|$ the size of the code set, and $c^* \in C^*$ a medical code.
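To make the notation concrete, the sketch below holds one patient's longitudinal record in Python. The `Visit` class and the example ICD-9 / ATC codes are illustrative assumptions, not the paper's data format.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Visit:
    """One visit X_t: its diagnosis codes C_t^d and medication codes C_t^m (hypothetical layout)."""
    diagnoses: frozenset
    medications: frozenset

    def codes(self) -> frozenset:
        # X_t = C_t^d ∪ C_t^m
        return self.diagnoses | self.medications


# A patient X^(n) is an ordered list of visits X_1, ..., X_{T(n)} (toy example codes).
patient = [
    Visit(frozenset({"401.9", "428.0"}), frozenset({"C03C", "C07A"})),
    Visit(frozenset({"428.0"}), frozenset({"C03C"})),
]
T = len(patient)  # T(n), the number of visits of this patient
```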

1.2. Medical Codes

  • Medical codes are usually organized in a tree-structured classification system, such as the ICD-9 ontology for diagnoses and the ATC ontology for medications. $O^d$ and $O^m$ denote the diagnosis and medication ontologies, and $O^*$ is the unified notation for either type.
  • Two functions, pa(·) and ch(·), are defined: given a target medical code, pa(·) returns its set of ancestor codes and ch(·) returns its set of direct child codes.
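A minimal Python sketch of pa(·) and ch(·), assuming the ontology is stored as a child-to-parent dictionary. The toy ICD-9 fragment `PARENT` is hypothetical and illustrative only, and the whole chain up to the root is treated as the ancestor set.

```python
# Hypothetical toy fragment of an ICD-9-style tree, stored as child -> parent.
PARENT = {
    "401.9": "401",
    "401": "401-405",
    "401-405": "390-459",
    "428.0": "428",
    "428": "420-429",
    "420-429": "390-459",
}


def pa(code):
    """Return the set of all ancestor codes of `code`, walking up to the root."""
    ancestors = set()
    while code in PARENT:
        code = PARENT[code]
        ancestors.add(code)
    return ancestors


def ch(code):
    """Return the set of direct child codes of `code`."""
    return {child for child, parent in PARENT.items() if parent == code}
```

Leaf codes (the ones that actually appear in EHR data) have an empty `ch`, and the root has an empty `pa`.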

1.3. Problem Definition (Medication Recommendation)

  • Given the diagnosis codes $C_t^d$ of the visit at time t and the patient history

$$X_{1:t} = \{X_1, X_2, \ldots, X_{t-1}\},$$

we want to recommend multiple medications by generating a multi-label output $\hat{y}_t \in \{0, 1\}^{|C^m|}$.
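The target vector and the recommended medication set are related by multi-hot encoding. The sketch below assumes a hypothetical four-code medication vocabulary `MED_VOCAB` and a 0.5 decision threshold; neither comes from the paper.

```python
# Hypothetical medication vocabulary C^m (ATC third-level codes), |C^m| = 4.
MED_VOCAB = ["C03C", "C07A", "C09A", "N02B"]
INDEX = {code: i for i, code in enumerate(MED_VOCAB)}


def to_multi_hot(med_codes):
    """Encode a medication set C_t^m as a vector in {0, 1}^{|C^m|}."""
    y = [0] * len(MED_VOCAB)
    for code in med_codes:
        y[INDEX[code]] = 1
    return y


def decode(probs, threshold=0.5):
    """Turn per-medication probabilities into a recommended set (assumed cut-off)."""
    return {MED_VOCAB[i] for i, p in enumerate(probs) if p >= threshold}
```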

2. G-BERT

G-BERT

Conceptually, ontology embeddings enhanced by a GNN are fed into BERT for pre-training, and the pre-trained BERT is then fine-tuned for the downstream task.

2.1. Ontology Embedding

  • Ontology embedding is constructed from diagnosis ontology Od and medication ontology Om. Since the medical codes in raw EHR data can be considered as leaf nodes in these ontology trees, the medical code embedding can be enhanced using graph neural networks (GNNs) to integrate the ancestors’ information of these codes.
  • A two-stage procedure is performed with a specially designed GNN for ontology embedding.
  • To start, an initial embedding vector is assigned to every medical code $c^* \in O^*$ through a learnable embedding matrix $W_e \in \mathbb{R}^{|O^*| \times d}$, where d is the embedding dimension.

Stage 1: For each non-leaf node $c^*$, its enhanced medical embedding $h_{c^*} \in \mathbb{R}^d$ is obtained as:

$$h_{c^*} = g(c^*, ch(c^*), W_e)$$

  • where g(·, ·, ·) is an aggregation function which accepts the target medical code $c^*$, its direct child codes $ch(c^*)$ and the initial embedding matrix $W_e$.

Intuitively, the aggregation function passes and fuses information from the direct children into the target node, so that an ancestor code's embedding becomes more closely related to the embeddings of its child codes.

Stage 2: After obtaining the enhanced embeddings, the enhanced embedding matrix $H_e \in \mathbb{R}^{|O^*| \times d}$ is passed back down to obtain the ontology embedding $o_{c^*}$ of each leaf code:

$$o_{c^*} = g(c^*, pa(c^*), H_e)$$
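The two stages can be sketched in plain Python using a mean as g, one of the simple choices the paper allows (the paper itself uses an attention-based g). This is an illustration under stated assumptions: `ontology_embedding`, `TOY_PARENT` and the random initialization are hypothetical, each average includes the node itself, and leaf codes simply keep their initial embeddings in $H_e$.

```python
import random


def mean_vec(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def ontology_embedding(parent, leaves, d=4, seed=0):
    """Two-stage ontology embedding with a plain mean as g (simplifying assumption).

    parent: dict child -> parent describing the ontology tree.
    leaves: codes that appear in the EHR data (leaf nodes).
    """
    rng = random.Random(seed)
    nodes = sorted(set(parent) | set(parent.values()))
    # Initial embedding W_e: one random d-dimensional vector per ontology node.
    W_e = {c: [rng.gauss(0.0, 1.0) for _ in range(d)] for c in nodes}
    children = {c: [k for k, p in parent.items() if p == c] for c in nodes}

    # Stage 1: h_c = g(c, ch(c), W_e) for every non-leaf code, from initial embeddings.
    H_e = dict(W_e)  # leaves keep their initial embedding (assumption)
    for c in nodes:
        if children[c]:
            H_e[c] = mean_vec([W_e[c]] + [W_e[k] for k in children[c]])

    def ancestors(c):
        out = []
        while c in parent:
            c = parent[c]
            out.append(c)
        return out

    # Stage 2: o_c = g(c, pa(c), H_e) for every leaf code, from enhanced embeddings.
    return {c: mean_vec([H_e[c]] + [H_e[a] for a in ancestors(c)]) for c in leaves}


TOY_PARENT = {"401.9": "401", "401": "390-459", "428.0": "428", "428": "390-459"}
leaf_emb = ontology_embedding(TOY_PARENT, ["401.9", "428.0"], d=4)
```

Because both leaves share the ancestor "390-459", part of their final embeddings comes from the same vector, which is the intuition behind sharing ontology information between codes.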

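  • g(·, ·, ·) can be as simple as a sum or a mean. Here, g is defined as follows (taking Stage 2 as an example):

$$g(c^*, pa(c^*), H_e) = \Big\Vert_{k=1}^{K} \sigma\Big(\sum_{j \in \{c^*\} \cup pa(c^*)} \alpha_{i,j}^{k} W^{k} h_j\Big)$$

  • where ∥ denotes concatenation over the K attention heads (the multi-head attention mechanism), σ is the activation function, $W^k$ is the weight matrix of the k-th head, $h_j$ is the row of $H_e$ for code j, i indexes the target code $c^*$, and $\alpha_{i,j}^{k}$ are the corresponding k-th normalized attention coefficients, computed as:

$$\alpha_{i,j}^{k} = \frac{\exp\big(\mathrm{LeakyReLU}(a^\top [W^{k} h_i \,\Vert\, W^{k} h_j])\big)}{\sum_{l \in \{c^*\} \cup pa(c^*)} \exp\big(\mathrm{LeakyReLU}(a^\top [W^{k} h_i \,\Vert\, W^{k} h_l])\big)}$$

  where a is a learnable attention weight vector.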
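The attention-based g can be sketched in pure Python as below. This is a GAT-style reading of the two formulas above, not the authors' code: `gat_aggregate` and the toy parameters are hypothetical, each head carries its own attention vector `a_k` (an assumption), and tanh stands in for the unspecified activation σ.

```python
import math
import random


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def leaky_relu(z, slope=0.2):
    return z if z > 0 else slope * z


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gat_aggregate(target, neighbours, H, heads, act=math.tanh):
    """g(c, N(c), H): multi-head attention over {c} ∪ N(c), heads concatenated.

    heads: list of (W_k, a_k), W_k an m x d matrix, a_k a vector of length 2m.
    """
    nodes = [target] + [n for n in neighbours if n != target]
    out = []
    for W_k, a_k in heads:
        proj = {n: matvec(W_k, H[n]) for n in nodes}  # W^k h_j
        # e_ij = LeakyReLU(a^T [W^k h_i || W^k h_j]) for every j in {c} ∪ N(c)
        scores = [leaky_relu(sum(a * z for a, z in zip(a_k, proj[target] + proj[j])))
                  for j in nodes]
        alpha = softmax(scores)  # normalized coefficients alpha^k_ij
        for r in range(len(W_k)):
            # sigma(sum_j alpha_ij W^k h_j), appended so that heads are concatenated
            out.append(act(sum(alpha[i] * proj[j][r] for i, j in enumerate(nodes))))
    return out


rng = random.Random(0)
d, m, K = 4, 2, 2
H_e = {c: [rng.gauss(0.0, 1.0) for _ in range(d)] for c in ["401.9", "401", "390-459"]}
heads = [([[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(m)],
          [rng.gauss(0.0, 0.5) for _ in range(2 * m)]) for _ in range(K)]
o_leaf = gat_aggregate("401.9", ["401", "390-459"], H_e, heads)  # length K * m
```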

2.2. Pretraining G-BERT

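The model takes the above ontology embeddings as input and derives the visit embedding $v_t^*$ of a patient at the t-th visit:

$$v_t^* = \mathrm{Transformer}\big(\{[\mathrm{CLS}]\} \cup \{o_{c^*} \mid c^* \in C_t^*\}\big)[0]$$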

  • where [CLS] is a special token as in BERT. It is placed at the first position of each visit of type *, and the Transformer output at that position ([0]) is taken as the visit embedding.
  • One big difference between language sentences and EHR sequences is that the medical codes within the same visit do not generally have an order, so the position embedding is removed.
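To see why removing the position embedding matters, here is a deliberately simplified sketch: a single single-head self-attention layer with Q = K = V (no learned projections, no feed-forward sublayer), reading the output at the [CLS] position. The function names and toy vectors are hypothetical. Because no position information is added, shuffling the codes of a visit leaves the visit embedding unchanged, up to floating-point rounding.

```python
import math


def self_attention(X):
    """Single-head scaled dot-product self-attention with Q = K = V = X (simplification)."""
    d = len(X[0])
    out = []
    for q in X:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[r] for wi, v in zip(w, X)) / z for r in range(d)])
    return out


def visit_embedding(cls_vec, code_vecs):
    """v_t = Transformer({[CLS]} ∪ {o_c | c in C_t})[0], with no position embedding added."""
    tokens = [cls_vec] + list(code_vecs)
    return self_attention(tokens)[0]  # output at the [CLS] position


CLS = [0.1, 0.2, 0.3]
o = {"401.9": [1.0, 0.0, 0.5], "428.0": [0.0, 1.0, -0.5], "250.00": [0.3, 0.3, 0.3]}
v_a = visit_embedding(CLS, [o["401.9"], o["428.0"], o["250.00"]])
v_b = visit_embedding(CLS, [o["250.00"], o["401.9"], o["428.0"]])  # same codes, shuffled
```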

2.2.1. Self-Prediction Task

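  • This task recovers what the visit embedding $v_1^*$ is made of, i.e., the input medical codes $C_1^*$ of each visit:

$$L_{se}(v_1^*, C_1^*) = -\log p(C_1^* \mid v_1^*) = -\sum_{c \in C_1^*} \log p(c \mid v_1^*) - \sum_{c \in C^* \setminus C_1^*} \log\big(1 - p(c \mid v_1^*)\big)$$

  • This binary cross-entropy loss $L_{se}$ is minimized. In practice, the probabilities are computed as $\mathrm{Sigmoid}(f(v_1^*))$, where f(·) is a fully connected neural network with one hidden layer.
  • Similar to BERT, 15% of the codes in $C^*$ are randomly masked.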
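A minimal sketch of the self-prediction objective, assuming `probs` already holds Sigmoid(f(v)) for every code of a hypothetical vocabulary. Masked codes are replaced by a `[MASK]` token here; the exact masking recipe (for example BERT's 80/10/10 split) is not specified above, so this is a simplification. The loss target is the visit's full, unmasked code set.

```python
import math
import random

MASK = "[MASK]"


def mask_codes(codes, rate=0.15, rng=None):
    """Replace each input code with [MASK] with probability `rate` (simplified masking)."""
    rng = rng or random.Random()
    return [MASK if rng.random() < rate else c for c in codes]


def self_prediction_loss(probs, true_codes, vocab):
    """L_se: binary cross-entropy between p(c | v) and the visit's (unmasked) codes."""
    eps = 1e-12  # guards against log(0)
    loss = 0.0
    for code, p in zip(vocab, probs):
        if code in true_codes:
            loss -= math.log(p + eps)        # codes in the visit should get high p
        else:
            loss -= math.log(1.0 - p + eps)  # all other codes should get low p
    return loss


vocab = ["401.9", "428.0", "250.00"]
masked_input = mask_codes(["401.9", "428.0"], rng=random.Random(0))
loss = self_prediction_loss([0.8, 0.6, 0.1], {"401.9", "428.0"}, vocab)
```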

2.2.2. Dual-Prediction Task

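  • Note again that the ICD-9 ontology is used for diagnoses (d) and the ATC ontology for medications (m).
  • In medication recommendation, multiple medications are predicted given only the diagnosis codes. Inversely, unknown diagnoses can also be predicted given the medication codes. The dual-prediction task therefore predicts each code type from the visit embedding of the other type:

$$L_{du} = -\log p(C_1^d \mid v_1^m) - \log p(C_1^m \mid v_1^d)$$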
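The dual-prediction loss can be written as two multi-label binary cross-entropy terms, one per direction. In this sketch (function names hypothetical), `p_dx_from_med` stands for p(C^d | v^m) and `p_med_from_dx` for p(C^m | v^d), both given as per-code probabilities, with multi-hot targets.

```python
import math


def bce(probs, targets):
    """Multi-label binary cross-entropy, summed over labels."""
    eps = 1e-12
    return -sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
                for p, t in zip(probs, targets))


def dual_prediction_loss(p_dx_from_med, y_dx, p_med_from_dx, y_med):
    """L_du = -log p(C^d | v^m) - log p(C^m | v^d)."""
    return bce(p_dx_from_med, y_dx) + bce(p_med_from_dx, y_med)
```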

2.2.3. Overall Loss Function

  • Finally, the following loss is used to pre-train on the EHR data of all patients who have only one hospital visit:

$$L = L_{se}(v_1^d, C_1^d) + L_{se}(v_1^m, C_1^m) + L_{du}$$
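Putting the pieces together for one single-visit patient gives the sketch below. The dictionary keys such as `"d->m"` (input type -> predicted type) and the helper `single_visit_patients` are hypothetical conveniences, not names from the paper.

```python
import math


def bce(probs, targets):
    """Multi-label binary cross-entropy, summed over labels."""
    eps = 1e-12
    return -sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
                for p, t in zip(probs, targets))


def pretraining_loss(pred, y_d, y_m):
    """L = L_se(v^d, C^d) + L_se(v^m, C^m) + L_du for one single-visit patient."""
    return (bce(pred["d->d"], y_d)     # self-prediction of diagnoses
            + bce(pred["m->m"], y_m)   # self-prediction of medications
            + bce(pred["m->d"], y_d)   # dual: diagnoses from the medication visit embedding
            + bce(pred["d->m"], y_m))  # dual: medications from the diagnosis visit embedding


def single_visit_patients(patients):
    """Keep patients with exactly one visit, the pre-training population described above."""
    return [p for p in patients if len(p) == 1]
```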

2.3. Fine-Tuning G-BERT

  • The known diagnosis codes $C_t^d$ at the prediction time t are also represented using the same model, giving $v_t^d$.

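The mean of the previous diagnosis visit embeddings, the mean of the previous medication visit embeddings, and the diagnosis visit embedding of the current visit are concatenated, and an MLP-based prediction layer is built on top to predict the recommended medication codes:

$$\hat{y}_t = \mathrm{Sigmoid}\Big(W_1\Big[\Big(\frac{1}{t-1}\sum_{\tau=1}^{t-1} v_\tau^d\Big) \,\Big\Vert\, \Big(\frac{1}{t-1}\sum_{\tau=1}^{t-1} v_\tau^m\Big) \,\Big\Vert\, v_t^d\Big] + b\Big)$$

where $W_1$ and b are learnable parameters.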
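The prediction layer can be sketched as follows, assuming each visit embedding is a plain list of l floats and `W1` is a list of |C^m| rows of length 3l. It mirrors the formula above (means over visits 1 to t-1, concatenation, one linear map, Sigmoid); all names are hypothetical.

```python
import math


def mean_vectors(vs):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(col) / len(vs) for col in zip(*vs)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict_medications(prev_dx, prev_med, cur_dx, W1, b):
    """y_hat_t = Sigmoid(W1 [mean(v^d_<t) || mean(v^m_<t) || v^d_t] + b).

    prev_dx, prev_med: visit embeddings of visits 1..t-1 (non-empty lists).
    cur_dx: diagnosis visit embedding v_t^d at the prediction time t.
    """
    features = mean_vectors(prev_dx) + mean_vectors(prev_med) + list(cur_dx)  # concatenation
    return [sigmoid(sum(w * x for w, x in zip(row, features)) + bias)
            for row, bias in zip(W1, b)]


# Toy example with l = 1 and |C^m| = 2: features = [2.0, 0.0, 1.0].
probs = predict_medications([[1.0], [3.0]], [[0.0]], [1.0],
                            W1=[[1.0, 0.0, -2.0], [0.0, 0.0, 0.0]], b=[0.0, 1.0])
```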

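  • Given the true labels $y_t$ at each time step t, the loss function for the whole EHR sequence (i.e., one patient) is:

$$L = -\frac{1}{T-1}\sum_{t=2}^{T}\Big(y_t^\top \log \hat{y}_t + (\mathbf{1} - y_t)^\top \log(\mathbf{1} - \hat{y}_t)\Big)$$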
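A sketch of the per-patient fine-tuning loss, assuming one multi-hot label vector and one probability vector per visit (function name hypothetical). Visit 1 contributes nothing because the sum in the formula starts at t = 2.

```python
import math


def sequence_loss(y_true, y_pred):
    """L = -1/(T-1) * sum_{t=2..T} [y_t^T log y_hat_t + (1 - y_t)^T log(1 - y_hat_t)].

    y_true, y_pred: per-visit label / probability vectors for visits 1..T.
    """
    T = len(y_true)
    if T < 2:
        raise ValueError("need at least two visits")
    eps = 1e-12
    total = 0.0
    for t in range(1, T):  # 0-based indices 1..T-1 are visits 2..T, as in the formula
        total += sum(y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
                     for y, p in zip(y_true[t], y_pred[t]))
    return -total / (T - 1)
```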

3. Results

3.1. Dataset & Some Training Details

  • EHR data from MIMIC-III [Johnson et al., 2016] is used. Drug codes are mapped from NDC to the third level of ATC so that the ontology information can be used. The dataset is split into training, validation and test sets in a 0.6:0.2:0.2 ratio.
  • GNN: input embedding dimension is 75, number of attention heads is 4. BERT: hidden dimension is 300, dimension of the position-wise feed-forward networks is 300, 2 hidden layers with 4 attention heads each.
  • In particular, the authors alternate 5 epochs of pre-training with 5 epochs of fine-tuning, repeated 15 times, to stabilize training. (This training procedure looks quite unusual to me.)
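The reported hyperparameters and the alternating schedule can be collected in a small config sketch. The field names are hypothetical; the values are the ones listed above. Note that 4 GNN heads × 75 dimensions = 300, which would match the BERT hidden size if the heads are concatenated as in the aggregation formula, though this is an inference rather than something stated in the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GBertConfig:
    # Values as reported above; field names are hypothetical.
    gnn_embedding_dim: int = 75
    gnn_attention_heads: int = 4
    bert_hidden_dim: int = 300
    bert_ffn_dim: int = 300
    bert_layers: int = 2
    bert_attention_heads: int = 4
    pretrain_epochs_per_round: int = 5
    finetune_epochs_per_round: int = 5
    rounds: int = 15


def training_schedule(cfg):
    """List every epoch of the alternating pre-train / fine-tune procedure, in order."""
    phases = []
    for r in range(cfg.rounds):
        phases += [("pretrain", r)] * cfg.pretrain_epochs_per_round
        phases += [("finetune", r)] * cfg.finetune_epochs_per_round
    return phases


cfg = GBertConfig()
schedule = training_schedule(cfg)  # 15 rounds x (5 + 5) epochs = 150 epochs in total
```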

3.2. Results

Performance
  • G-: uses medical code embeddings without ontology information.
  • P-: no pre-training.
  • Comparing the last 4 rows shows that both ontology information and pre-training are important.

The final G-BERT model outperforms the attention-based model RETAIN and the recently published state-of-the-art model GAMENet. Notably, even though GAMENet uses extra information, namely drug-drug interaction (DDI) knowledge and procedure codes, it still performs worse than G-BERT.
