rpmjp/projects/swin-transformer-study/swin_training.py
Completed: May to Dec 2025
Swin Transformer: Empirical Evaluation on Small Fine-Grained Data
A controlled four-family architecture comparison (Swin T/S/B, RegNetY CNNs, EfficientNet B3-B7, ViT-B/16) on the Oxford-IIIT Pet Dataset under RTX 4090 constraints. Three findings: Swin's hierarchical attention transfers cleanly to small datasets (93.8-96.35% accuracy); EfficientNet's compound scaling breaks down (B3 beats B7 by 8.66 percentage points); and ViT fails catastrophically (7.17% accuracy, barely above the 2.7% random-chance baseline for 37 classes, i.e. 1/37).
PyTorch · timm · Swin-T/S/B · RegNetY · EfficientNet B3-B7 · ViT-B/16 · Oxford-IIIT Pet · RTX 4090
Languages
Jupyter Notebook 98%
Python 2%
swin_training.py
"""
Swin Transformer training loop: Swin-T / Swin-S / Swin-B on the Oxford-IIIT
Pet Dataset. Standard transfer learning from ImageNet pretrained weights via
timm, with the classifier head replaced for 37 classes.
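The full configuration was 5 epochs, AdamW, CrossEntropyLoss, and batch size 64,
with the same hyperparameters across all Swin variants: the only things that
changed between runs were the model name and its matching input resolution
(224² for the *_window7_224 models, 384² for swin_base_patch4_window12_384).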
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import timm
from torchvision import transforms
# Use the default CUDA device when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Size of the replacement classifier head (Oxford-IIIT Pet classes).
NUM_CLASSES = 37

# ImageNet channel statistics. Every model family in the study normalizes with
# these same values, so each one starts transfer learning from an identical
# input distribution.
_NORMALIZE = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
def build_swin(model_name: str, num_classes: int = NUM_CLASSES) -> nn.Module:
    """Create an ImageNet-pretrained Swin model with a fresh classifier head.

    The backbone weights come from timm (``pretrained=True``); passing
    ``num_classes`` makes timm build a new head of that width instead of the
    original ImageNet one.

    Args:
        model_name: timm identifier. The variants used in this study are
            'swin_tiny_patch4_window7_224', 'swin_small_patch4_window7_224',
            'swin_base_patch4_window7_224' and 'swin_base_patch4_window12_384'.
        num_classes: output width of the new head (defaults to NUM_CLASSES).

    Returns:
        The model, moved to ``DEVICE``.
    """
    swin = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=num_classes,
    )
    swin = swin.to(DEVICE)
    return swin
def build_transforms(image_size: int) -> transforms.Compose:
    """Return the ImageNet-style preprocessing pipeline for square inputs.

    Steps: resize to ``image_size`` x ``image_size``, convert to a tensor,
    then normalize with the ImageNet statistics in ``_NORMALIZE``. The resize
    targets an exact square, so the aspect ratio is not preserved (there is
    no crop). No random augmentation is included.
    """
    target = (image_size, image_size)
    pipeline = [
        transforms.Resize(target),
        transforms.ToTensor(),
        _NORMALIZE,
    ]
    return transforms.Compose(pipeline)
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one training pass over ``loader``, updating ``model`` in place.

    Args:
        model: module to train; switched to train mode first.
        loader: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss callable ``criterion(logits, labels) -> scalar tensor``.
        device: where each batch is moved. Defaults to the module-level
            ``DEVICE``; it should match the device ``model`` lives on.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen this epoch. Both are
        taken from the same forward pass used for the gradient step, so they
        reflect train-mode behavior and the weights *before* each step.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = DEVICE
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # Re-weight by batch size so a short final batch is not over-counted.
        # Assumes criterion averages over the batch, as nn.CrossEntropyLoss
        # does with its default reduction.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_seen += batch_size
    if total_seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen
@torch.no_grad()
def validate(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Runs under ``torch.no_grad()``, so no autograd graph is built, with the
    model switched to eval mode.

    Args:
        model: module to evaluate; left in eval mode on return.
        loader: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        criterion: loss callable ``criterion(logits, labels) -> scalar tensor``.
        device: where each batch is moved. Defaults to the module-level
            ``DEVICE``; it should match the device ``model`` lives on.

    Returns:
        ``(mean_loss, accuracy)`` over every sample in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = DEVICE
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(images)
        loss = criterion(logits, labels)
        batch_size = labels.size(0)
        # Re-weight by batch size so a short final batch is not over-counted.
        # Assumes criterion averages over the batch, as nn.CrossEntropyLoss
        # does with its default reduction.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_seen += batch_size
    if total_seen == 0:
        raise ValueError("validate: loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen
def run(model_name: str, image_size: int, train_ds, val_ds, epochs: int = 5,
        batch_size: int = 64, lr: float = 1e-3, num_workers: int = 4):
    """Fine-tune one Swin variant and print train/val metrics every epoch.

    The configuration is identical across Swin variants; only ``model_name``
    and ``image_size`` change between runs.

    Args:
        model_name: timm identifier forwarded to ``build_swin``.
        image_size: input resolution the datasets were prepared at. This
            driver does not resize images itself: ``train_ds`` and ``val_ds``
            must already apply their preprocessing (e.g.
            ``build_transforms(image_size)``). The value is checked against the
            model's pretrained input size, and a warning is emitted on a
            mismatch.
        train_ds: training dataset; its items are collated into
            ``(images, labels)`` batches.
        val_ds: validation dataset with the same item format.
        epochs: number of passes over ``train_ds``.
        batch_size: batch size for both loaders.
        lr: AdamW learning rate; all other AdamW settings are PyTorch defaults.
        num_workers: worker processes per DataLoader (default 4).

    Returns:
        The trained model, on ``DEVICE``.
    """
    import warnings

    model = build_swin(model_name)

    # image_size is not applied anywhere else in this driver (the datasets
    # carry their own transforms), so at least flag a resolution that
    # disagrees with what the pretrained weights expect.
    cfg = getattr(model, "pretrained_cfg", None) or getattr(model, "default_cfg", None)
    expected = cfg.get("input_size") if isinstance(cfg, dict) else None
    if expected is not None and tuple(expected[-2:]) != (image_size, image_size):
        warnings.warn(
            f"[{model_name}] image_size={image_size} does not match the "
            f"pretrained input size {tuple(expected[-2:])}; make sure the "
            f"datasets are resized to what the model expects.",
            stacklevel=2,
        )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Pinned host memory only pays off for host-to-CUDA copies (the
    # non_blocking transfers in the epoch loops). On CPU-only machines recent
    # PyTorch versions just warn and ignore it.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=DEVICE.type == "cuda",
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
        va_loss, va_acc = validate(model, val_loader, criterion)
        print(
            f"[{model_name}] epoch {epoch}: "
            f"train_loss={tr_loss:.4f} train_acc={tr_acc:.4f} "
            f"val_loss={va_loss:.4f} val_acc={va_acc:.4f}"
        )
    return model