Factuality Classifier for Medical RAG (PubMedBERT - Source Based)
Model Description
This model is a fine-tuned version of PubMedBERT (microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext), designed to classify the factuality of medical documents. It was developed as a factuality estimation component for FRAG (Factuality-aware Retrieval-Augmented Generation), acting as a safeguard against medical misinformation in LLM-based question answering.
Unlike the Claim-Based variant, this Source-Based classifier was trained to identify the latent linguistic and structural patterns associated with the reputation of the publishing source. It predicts whether a given medical text likely originates from a verified, reliable publisher or an unreliable one.
- Model type: Text Classification (Binary: 1 = Factual, 0 = Non-Factual)
- Language(s): English
- License: MIT
- Base Model: microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
Intended Uses & Limitations
Intended Use: This model is intended to be used as a filtering or re-ranking component within a Retrieval-Augmented Generation (RAG) pipeline in the healthcare domain. It assesses the reliability of a retrieved document based on the stylistic and semantic markers typical of trustworthy medical sources.
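The filtering role described above can be sketched in a few lines. The helper below is a hypothetical illustration (the document names no filtering API): it assumes each retrieved document has already been scored with the classifier's P(Factual), and keeps only documents above a threshold before they reach the generator.

```python
# Hypothetical RAG filtering step: drop retrieved documents whose
# P(Factual) score (from this classifier) falls below a threshold.
def filter_documents(docs_with_scores, threshold=0.5):
    """Keep only documents whose factuality probability meets the threshold."""
    return [doc for doc, p_factual in docs_with_scores if p_factual >= threshold]

# Illustrative (made-up) scores for two retrieved passages
retrieved = [
    ("Vaccines are safe and effective.", 0.97),
    ("Miracle cure doctors don't want you to know about.", 0.08),
]
kept = filter_documents(retrieved)
print(kept)  # only the first passage survives the filter
```

The threshold of 0.5 is an assumption; in a re-ranking setup one would instead sort by the score rather than hard-filter.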
Limitations:
- The model relies on source reputation as a proxy for factuality. While highly effective, it is a heuristic approach: it may occasionally misclassify a factually correct article if it is written in a style typical of unreliable sources, or vice versa.
- This is a research prototype and should not be used for autonomous medical diagnosis or clinical decision-making without expert supervision.
Training Data
The model was fine-tuned on a custom "Source-Based Dataset" derived from the Monant Medical Misinformation Dataset. Specifically, the data relies on Monant's Type 1 annotations, which explicitly label the validity of the publishing source (reliable vs. unreliable) based on assessments from expertly-curated platforms and human fact-checking organizations (e.g., Media Bias/Fact Check).
To prevent predictive bias toward the majority class, a balanced subset of the data was sampled. The final dataset consists of 19,786 articles, split evenly between the two classes:
- Training Set: 12,662 articles (6,331 Reliable, 6,331 Unreliable)
- Validation Set: 3,166 articles (1,583 Reliable, 1,583 Unreliable)
- Test Set: 3,958 articles (1,979 Reliable, 1,979 Unreliable)
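The balancing strategy above (equal counts per class in every split) amounts to downsampling the majority class. A minimal sketch, assuming two in-memory lists of articles (the actual sampling code is not part of this card):

```python
import random

def balanced_subset(reliable, unreliable, seed=42):
    """Downsample the larger class so both classes are equally represented."""
    rng = random.Random(seed)
    n = min(len(reliable), len(unreliable))
    return rng.sample(reliable, n), rng.sample(unreliable, n)

# Hypothetical corpus with a 3:1 class imbalance
rel, unrel = balanced_subset(list(range(300)), list(range(100)))
print(len(rel), len(unrel))  # 100 100
```

Fixing the seed makes the sampled subset reproducible across runs.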
Training Procedure
The model was trained using TensorFlow/Keras with the following hyperparameters:
- Max Sequence Length: 512 tokens
- Batch Size: 4
- Learning Rate: 2e-5
- Optimizer: Adam
- Loss Function: Sparse Categorical Crossentropy
- Epochs: 20 (with Early Stopping monitoring val_loss, patience=5, restoring best weights)
Note: The model converged very quickly. Validation loss reached its minimum during the first epoch and then rose sharply, indicating a rapid onset of overfitting. The Early Stopping mechanism intervened at the end of the second epoch and restored the best-performing weights.
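The early-stopping behaviour (monitor val_loss, stop after `patience` epochs without improvement, keep the best epoch) can be mimicked in pure Python. The loss values below are hypothetical, chosen only to illustrate the mechanism:

```python
def early_stop_epoch(val_losses, patience=5):
    """Keras-style early stopping: return (stop_epoch, best_epoch), 1-based.

    Training halts once `patience` consecutive epochs pass without the
    validation loss improving; the best epoch's weights would be restored.
    """
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# Hypothetical run where epoch 1 is best and the loss then rises
print(early_stop_epoch([0.12, 0.30, 0.35, 0.40, 0.42, 0.45]))  # (6, 1)
```

With restore_best_weights=True, Keras reports the final model as the best-epoch model even though training ran several epochs past it.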
Evaluation Results
Evaluated on the held-out test set of 3,958 articles, the model separates reliable from unreliable medical texts with high accuracy. The near-symmetric error distribution (66 false positives vs. 61 false negatives) indicates that the model is well balanced and does not favor either class.
- Test Accuracy: 96.79% (3,831 / 3,958 correct)
Confusion Matrix Breakdown:
- True Positives (Predicted Factual, Actual Factual): 1918
- True Negatives (Predicted Non-Factual, Actual Non-Factual): 1913
- False Positives (Predicted Factual, Actual Non-Factual): 66
- False Negatives (Predicted Non-Factual, Actual Factual): 61
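From the confusion matrix above, the standard metrics follow directly (treating "Factual" as the positive class); this is plain arithmetic on the reported counts, not additional results:

```python
# Reported confusion-matrix counts (positive class = Factual)
tp, tn, fp, fn = 1918, 1913, 66, 61

accuracy  = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall    = tp / (tp + fn)
f1        = 2 * precision * recall / (precision + recall)

print(f"accuracy={accuracy:.4f} precision={precision:.4f} "
      f"recall={recall:.4f} f1={f1:.4f}")
# accuracy=0.9679 precision=0.9667 recall=0.9692 f1=0.9680
```

Precision and recall are nearly identical, which is what the near-symmetric error counts imply.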
How to Get Started with the Model
You can easily load and use this model via the Hugging Face transformers library. The model takes the raw HTML-cleaned text of the medical article as input.
Using TensorFlow (as trained):
```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("tommibazzo01/factuality-classifier-pubmedbert-SourceBased")
model = TFAutoModelForSequenceClassification.from_pretrained("tommibazzo01/factuality-classifier-pubmedbert-SourceBased")

# Example medical article text
article_text = "Extensive scientific research has proven that there is no link between vaccines and autism."

# Tokenize
inputs = tokenizer(
    article_text,
    return_tensors="tf",
    truncation=True,
    padding=True,
    max_length=512,
)

# Predict
outputs = model(inputs)
predicted_class = int(tf.argmax(outputs.logits, axis=-1).numpy()[0])

labels = {0: "Non-Factual", 1: "Factual"}
print(f"Prediction: {labels[predicted_class]}")
```
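For the filtering or re-ranking use case, a probability is more useful than a hard label. Applying a softmax to the model's logits yields P(Factual); the logit values below are hypothetical stand-ins for `outputs.logits`:

```python
import math

def softmax(logits):
    """Convert a pair of raw logits into class probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for [Non-Factual, Factual]
p_non_factual, p_factual = softmax([-1.2, 2.3])
print(f"P(Factual) = {p_factual:.3f}")  # P(Factual) = 0.971
```

In TensorFlow the same score is available as `tf.nn.softmax(outputs.logits, axis=-1)[0, 1]`.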