PRIMER: Pretrained RadImageNet for Mammography Embedding Representations
PRIMER is a specialized deep learning model for mammography analysis, finetuned from RadImageNet using contrastive learning on the CMMD (Chinese Mammography Mass Database) dataset. The model generates discriminative embedding vectors specifically optimized for mammogram images.
Model Overview
- Base Model: RadImageNet ResNet-50
- Training Method: SimCLR contrastive learning (NT-Xent loss)
- Architecture: ResNet-50 encoder + 2-layer MLP projection head
- Input: 224×224 RGB images (converted from DICOM grayscale)
- Output: 2048-dimensional embedding vectors
- Training Dataset: CMMD mammography DICOM files
- Framework: PyTorch 2.1+
Key Features
- Finetuned specifically for mammography imaging
- Self-supervised contrastive learning (no labels required)
- Produces embeddings with better clustering and separation than baseline RadImageNet
- Handles DICOM preprocessing pipeline end-to-end
- Supports multiple backbone architectures (ResNet-50, DenseNet-121, Inception-V3); see the backbone sketch below
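A minimal sketch of how the supported backbones could be instantiated with timm. The timm model names and feature dimensions are standard for these architectures, but this is an illustration, not the released training code; the published checkpoint uses ResNet-50.

import timm

# Hypothetical backbone selection; feature dimensions differ per architecture.
backbones = {
    "resnet50": 2048,       # used by this checkpoint
    "densenet121": 1024,
    "inception_v3": 2048,   # typically trained at 299x299 rather than 224x224
}
for name, expected_dim in backbones.items():
    encoder = timm.create_model(name, pretrained=False, num_classes=0)
    assert encoder.num_features == expected_dim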
DICOM Preprocessing Pipeline
The model expects mammography DICOM images preprocessed through the following pipeline. This preprocessing is critical for proper model performance:
Step 1: DICOM Loading
- Read DICOM file using pydicom
- Extract pixel array as float32
Step 2: Photometric Interpretation Correction
- Check PhotometricInterpretation attribute
- If MONOCHROME1: Invert pixel values (max_value - pixel_value)
- MONOCHROME1: Higher values = darker (inverted scale)
- MONOCHROME2: Higher values = brighter (standard scale)
Step 3: Intensity Normalization
- Percentile-based clipping to remove outliers:
- Compute 2nd percentile (p2) and 98th percentile (p98)
- Clip all values: pixel_value = clip(pixel_value, p2, p98)
- Min-max normalization to [0, 255]:
- normalized = ((pixel_value - min) / (max - min + 1e-8)) × 255
- Convert to uint8
Step 4: CLAHE Enhancement
- Apply Contrast Limited Adaptive Histogram Equalization (CLAHE)
- Clip limit: 2.0
- Tile grid size: 8×8
- Improves local contrast and enhances subtle features
Step 5: Grayscale to RGB Conversion
- Duplicate grayscale channel 3 times: RGB = [gray, gray, gray]
- Required because RadImageNet expects 3-channel input
Step 6: Resizing
- Resize to 224×224 using bilinear interpolation
Step 7: Data Augmentation (Training Only)
Training augmentations (a possible albumentations composition is sketched after this list):
- Horizontal flip (p=0.5)
- Vertical flip (p=0.3)
- Rotation (±15 degrees, p=0.5)
- Random brightness/contrast (±0.2, p=0.5)
- Shift/scale/rotate (shift=0.1, scale=0.1, rotate=15°, p=0.5)
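A sketch of how these training-time augmentations could be composed with albumentations (listed in the requirements). The parameters below mirror the list above, but the exact composition used in training is an assumption:

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Assumed training-time pipeline mirroring the augmentation list above.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),  # HWC image -> CHW float tensor (Step 8)
])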
Step 8: Normalization
- ImageNet normalization (required for RadImageNet compatibility):
- Mean: [0.485, 0.456, 0.406]
- Std: [0.229, 0.224, 0.225]
- Convert to tensor (C×H×W format)
Complete Preprocessing Code
import cv2
import numpy as np
import pydicom

class DICOMProcessor:
    def __init__(self):
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def preprocess(self, dicom_path):
        # 1. Load DICOM
        dicom = pydicom.dcmread(dicom_path)
        image = dicom.pixel_array.astype(np.float32)

        # 2. Handle photometric interpretation
        if hasattr(dicom, 'PhotometricInterpretation'):
            if dicom.PhotometricInterpretation == "MONOCHROME1":
                image = np.max(image) - image

        # 3. Intensity normalization
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        image = ((image - image.min()) / (image.max() - image.min() + 1e-8) * 255)
        image = image.astype(np.uint8)

        # 4. CLAHE enhancement
        image = self.clahe.apply(image)

        # 5. Grayscale to RGB
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # 6. Resize
        image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)

        # 7. ImageNet normalization (augmentation, Step 7 above, is training-only
        #    and therefore omitted in this inference path)
        image = image.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image = (image - mean) / std

        # 8. Convert to tensor layout (C, H, W); keep float32 to match the model
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return image
Model Architecture
Overall Structure
Input DICOM (H×W grayscale)
↓
[DICOM Preprocessing Pipeline]
↓
224×224×3 RGB Tensor
↓
[RadImageNet ResNet-50 Encoder]
↓
2048-dim Embeddings
↓
[Projection Head] (training only)
↓
128-dim Projections
Components
1. Encoder (RadImageNet ResNet-50)
- Pretrained on RadImageNet dataset
- Modified final layer: removed classification head
- Output: 2048-dimensional feature vectors
- Finetuned on mammography data during contrastive learning
2. Projection Head (used during training, discarded for inference)
- 2-layer MLP: 2048 → 512 → 128
- Batch normalization + ReLU activation
- Used only for contrastive learning
- Discarded during embedding extraction
3. Loss Function: NT-Xent (Normalized Temperature-scaled Cross Entropy)
- Contrastive loss from SimCLR framework
- Temperature parameter: τ = 0.07
- Cosine similarity with L2 normalization
- Positive pairs: Two augmented views of same image
- Negative pairs: All other images in batch
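A minimal sketch of the projection head described in component 2 above (the layer sizes follow the list; the exact ordering of batch norm and activation is an assumption):

import torch.nn as nn

# Assumed 2-layer MLP projection head: 2048 -> 512 -> 128 (training only).
projection_head = nn.Sequential(
    nn.Linear(2048, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 128),
)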
Training Details
Contrastive Learning Framework (SimCLR)
For each mammogram:
1. Create two different augmented views (image1, image2)
2. Pass both through encoder → projection head
3. Compute NT-Xent loss between the two projections
4. Maximize agreement between views of same image
5. Minimize similarity with other images in batch
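A compact sketch of the NT-Xent objective described above. This is the standard SimCLR formulation (L2-normalized cosine similarities; the positive for each row sits at offset n in the doubled batch), not necessarily the exact implementation used to train PRIMER:

import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.07):
    """z1[i] and z2[i] are projections of two augmented views of image i."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # L2-normalize: dot product = cosine
    sim = z @ z.t() / temperature                       # (2n, 2n) similarity logits
    sim.fill_diagonal_(float('-inf'))                   # exclude self-similarity
    # The positive for row i is its other view: i + n (first half) or i - n (second half).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)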
Hyperparameters
- Batch size: 128
- Epochs: 50
- Learning rate: 1e-4 (AdamW optimizer)
- Weight decay: 1e-5
- Temperature: 0.07
- LR scheduler: Cosine annealing with 10-epoch warmup
- Mixed precision training: Enabled (AMP)
- Gradient clipping: 1.0
- Early stopping patience: 15 epochs
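Under the hyperparameters above, the optimization setup might be wired as follows. This is a sketch under assumptions (warmup via SequentialLR, AMP via torch.cuda.amp), not the released training script; it reuses the nt_xent_loss sketch from earlier.

import torch

# `model` (encoder + projection head) is assumed to exist; see the sketches above.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# 10-epoch linear warmup, then cosine annealing over the remaining 40 epochs.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=10),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40),
    ],
    milestones=[10],
)

scaler = torch.cuda.amp.GradScaler()  # mixed precision (AMP)

def train_step(view1, view2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = nt_xent_loss(model(view1), model(view2), temperature=0.07)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                                # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping at 1.0
    scaler.step(optimizer)
    scaler.update()
    return loss.item()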
Training Data
- Dataset: CMMD (Chinese Mammography Mass Database)
- Training split: 70%
- Validation split: 15%
- Test split: 15%
- Total training images: ~13,000 mammograms
Model Specifications
| Property | Value |
|---|---|
| Model Type | Feature Extraction / Embedding Model |
| Architecture | ResNet-50 (RadImageNet pretrained) |
| Input Shape | (3, 224, 224) |
| Output Shape | (2048,) |
| Parameters | ~23.5M trainable |
| Model Size | 283 MB |
| Precision | FP32 |
| Framework | PyTorch 2.1+ |
Usage
Loading the Model
import torch
import torch.nn as nn
import timm

# Define the encoder architecture
class RadImageNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model('resnet50', pretrained=False, num_classes=0)
        self.feature_dim = 2048

    def forward(self, x):
        return self.encoder(x)

# Load the checkpoint
checkpoint = torch.load('pytorch_model.bin', map_location='cpu')

# Extract encoder weights (strip the 'encoder.encoder.' prefix used during training)
model = RadImageNetEncoder()
encoder_state_dict = {
    k.replace('encoder.encoder.', ''): v
    for k, v in checkpoint['model_state_dict'].items()
    if k.startswith('encoder.encoder.')
}
model.encoder.load_state_dict(encoder_state_dict)
model.eval()
Extracting Embeddings
# Preprocess DICOM (see preprocessing code above)
processor = DICOMProcessor()
image = processor.preprocess('path/to/mammogram.dcm')
# Convert to tensor and add batch dimension
image_tensor = torch.from_numpy(image).unsqueeze(0) # Shape: (1, 3, 224, 224)
# Extract embeddings
with torch.no_grad():
    embeddings = model(image_tensor)  # Shape: (1, 2048)
# L2 normalize (recommended)
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
Performance: PRIMER vs RadImageNet Baseline
PRIMER demonstrates significant improvements over baseline RadImageNet embeddings on mammography-specific evaluation metrics:
| Metric | RadImageNet (Baseline) | PRIMER (Finetuned) | Improvement |
|---|---|---|---|
| Silhouette Score | 0.127 | 0.289 | +127% |
| Davies-Bouldin Score | 2.847 | 1.653 | -42% (lower is better) |
| Calinski-Harabasz Score | 1,834 | 3,621 | +97% |
| Embedding Variance | 0.012 | 0.024 | +100% |
| Intra-cluster Distance | 1.92 | 1.34 | -30% |
| Inter-cluster Distance | 2.15 | 2.87 | +33% |
Key Improvements:
- Better Clustering: Silhouette score increased from 0.127 to 0.289, indicating much tighter and more separated clusters
- Enhanced Discrimination: Davies-Bouldin score decreased by 42%, showing better cluster separation
- Richer Representations: Embedding variance doubled, indicating more diverse and informative features
- Mammography-Specific: Features learned are specialized for mammographic patterns (masses, calcifications, tissue density)
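For reference, clustering metrics of this kind can be computed with scikit-learn (listed in the requirements). The KMeans labeling below is a placeholder, not the original evaluation protocol:

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score

embeddings = np.random.randn(500, 2048).astype(np.float32)  # placeholder for real embeddings
labels = KMeans(n_clusters=5, n_init=10).fit_predict(embeddings)

print("Silhouette:", silhouette_score(embeddings, labels))                # higher is better
print("Davies-Bouldin:", davies_bouldin_score(embeddings, labels))        # lower is better
print("Calinski-Harabasz:", calinski_harabasz_score(embeddings, labels))  # higher is better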
Visualization Improvements
Dimensionality reduction visualizations (t-SNE, UMAP, PCA) show:
- PRIMER embeddings form distinct, well-separated clusters
- RadImageNet embeddings show more overlap and diffuse boundaries
- PRIMER captures mammography-specific visual patterns more effectively
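A minimal t-SNE projection with scikit-learn for producing such plots (the perplexity value is an assumption; UMAP requires the separate umap-learn package):

import numpy as np
from sklearn.manifold import TSNE

embeddings = np.random.randn(500, 2048).astype(np.float32)  # placeholder for real embeddings
coords_2d = TSNE(n_components=2, perplexity=30, init='pca').fit_transform(embeddings)
# coords_2d has shape (500, 2) and can be plotted with any plotting library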
Requirements
torch>=2.1.0
torchvision>=0.16.0
pydicom>=2.4.4
opencv-python>=4.8.1.78
numpy>=1.26.0
timm>=0.9.12
albumentations>=1.3.1
scikit-learn>=1.3.2
Dataset
CMMD (Chinese Mammography Mass Database)
- Modality: Full-field digital mammography (FFDM)
- Format: DICOM files
- Views: CC (craniocaudal), MLO (mediolateral oblique)
- Resolution: Variable (typically 2048×3328 or similar)
Limitations
Domain Specificity: Model is trained on CMMD dataset (Chinese population). Performance may vary on other populations or imaging protocols.
DICOM Format: Requires proper DICOM preprocessing. Non-DICOM images (PNG/JPG) should be run through the same pipeline from the intensity-normalization step onward for best results.
Image Quality: Performance depends on proper CLAHE enhancement and normalization. Poor quality or corrupted DICOM files may produce suboptimal embeddings.
Resolution: Model expects 224×224 input, so fine detail in the original high-resolution mammograms (e.g., small calcifications) may be lost during downsampling.
Self-Supervised: Model uses contrastive learning without labels. It does not perform classification directly; embeddings must be consumed by downstream tasks (clustering, retrieval, classification).
Photometric Interpretation: Critical to handle MONOCHROME1 vs MONOCHROME2 correctly. Failure to invert MONOCHROME1 images will result in poor embeddings.
Intended Use
Primary Use Cases
- Feature Extraction: Generate embeddings for mammography images
- Similarity Search: Find similar mammograms based on visual features (see the retrieval sketch after this list)
- Clustering: Group mammograms by visual characteristics
- Transfer Learning: Use as pretrained backbone for downstream tasks (classification, segmentation)
- Retrieval Systems: Content-based mammography image retrieval
- Quality Control: Identify outlier or anomalous mammograms
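As an illustration of the similarity-search use case, a minimal retrieval sketch over L2-normalized embeddings (function and variable names are hypothetical):

import torch
import torch.nn.functional as F

def top_k_similar(query_emb, gallery_embs, k=5):
    """Return indices of the k gallery embeddings most similar to the query (cosine)."""
    q = F.normalize(query_emb, p=2, dim=1)     # (1, 2048)
    g = F.normalize(gallery_embs, p=2, dim=1)  # (N, 2048)
    sims = q @ g.t()                           # (1, N) cosine similarities
    return torch.topk(sims, k, dim=1).indices.squeeze(0)

# Example: indices of the 5 mammograms most similar to a query embedding.
# nearest = top_k_similar(query_embedding, gallery_embeddings, k=5)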
Out-of-Scope Use Cases
- Direct Diagnosis: Model does not provide diagnostic predictions
- Standalone Clinical Use: Requires integration with clinical workflows and expert interpretation
- Non-Mammography Images: Optimized for mammography; may not generalize to other modalities
- Real-time Processing: Model size (283MB) and preprocessing may not be suitable for real-time applications without optimization
Model Card Contact
For questions or issues, please open an issue on the GitHub repository or contact via HuggingFace discussions.