rpmjp/portfolio
rpmjp/projects/swin-transformer-study/vit_evaluation.py
Completed May to Dec 2025

Swin Transformer: Empirical Evaluation on Small Fine-Grained Data

A controlled four-family architecture comparison (Swin-T/S/B, RegNetY CNNs, EfficientNet B3-B7, ViT-B/16) on the Oxford-IIIT Pet dataset under the memory constraints of a single 24 GB RTX 4090. Three findings: Swin's hierarchical attention transfers cleanly to this small dataset (93.8-96.35% accuracy), EfficientNet's compound scaling breaks down (B3 beats B7 by 8.66 points), and ViT fails catastrophically (7.17%, barely above the 2.7% random baseline).

PyTorch · timm · Swin-T/S/B · RegNetY · EfficientNet B3-B7 · ViT-B/16 · Oxford-IIIT Pet · RTX 4090
Languages
Jupyter Notebook 98%
Python 2%
vit_evaluation.py
"""
ViT-B/16 evaluation on Oxford-IIIT Pet: the catastrophic-failure case.

ViT-B/16 was trained with ImageNet-21K → 1K pretrained weights, 384² input,
the same AdamW + CrossEntropyLoss + 5-epoch protocol as every other model in
the study. The result: 7.17% validation accuracy against a 2.7% random
baseline. Training accuracy peaked at 8.03%: the model did not even fit the
training set, let alone generalize.

ViT-L/16 was attempted with the same protocol and ran out of memory
immediately on the 24GB RTX 4090. It does not appear in results because it
could not be trained.

See `baseline-analysis.md` for why this happens structurally.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import timm
from torchvision import transforms


# Prefer the default CUDA device when one is visible; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Oxford-IIIT Pet category count; chance-level accuracy is 1/37 (about 2.7%).
NUM_CLASSES = 37


# Display label -> timm model name plus the square input resolution it is fed.
# Both entries are the 384x384 checkpoints. ViT-L/16 is kept for
# documentation only: it OOMs on a single 24GB card.
VIT_MODELS = {
    label: {"timm_name": timm_name, "image_size": 384}
    for label, timm_name in (
        ("ViT-B/16", "vit_base_patch16_384"),
        ("ViT-L/16", "vit_large_patch16_384"),
    )
}


def build_transforms(image_size: int, *,
                     mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)) -> transforms.Compose:
    """Build the image preprocessing pipeline.

    Resizes every image to a square ``image_size`` x ``image_size`` (aspect
    ratio is not preserved), converts it with ``ToTensor`` and normalizes each
    channel with ``mean`` / ``std``.

    Args:
        image_size: Side length, in pixels, of the square resize target.
        mean: Per-channel means passed to ``Normalize``. Defaults to the
            standard ImageNet statistics this pipeline has always used.
        std: Per-channel standard deviations passed to ``Normalize``.
            Defaults to the matching ImageNet statistics.

    Returns:
        A ``transforms.Compose`` of Resize -> ToTensor -> Normalize.
    """
    # NOTE(review): timm checkpoints record their own expected normalization
    # (``model.pretrained_cfg["mean"]`` / ``["std"]``). For some ViT weights
    # this may differ from the ImageNet stats (e.g. 0.5/0.5). Passing those
    # values here would rule preprocessing mismatch out as a confound.
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # list(...) keeps the exact argument types the original passed.
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])


def train_vit(model_label: str, train_ds, val_ds, epochs: int = 5,
              batch_size: int = 64, lr: float = 1e-3):
    """Fine-tune one ``VIT_MODELS`` entry and print per-epoch accuracy.

    Same training protocol as the rest of the study. The point is that
    nothing about the protocol favors the CNN families: ViT got the same
    config and still failed.

    Args:
        model_label: Key into ``VIT_MODELS`` (e.g. ``"ViT-B/16"``).
        train_ds: Training dataset; batches are unpacked as
            ``(images, labels)``.
        val_ds: Validation dataset; batches are unpacked as
            ``(images, labels)``.
        epochs: Number of passes over ``train_ds``.
        batch_size: Batch size used for both the training and validation
            loaders.
        lr: AdamW learning rate.

    Returns:
        The model as it stands after the final epoch. Returns ``None`` if
        CUDA ran out of memory while loading the model or at any point
        during training/validation.

    Raises:
        KeyError: If ``model_label`` is not a key of ``VIT_MODELS``.
    """
    spec = VIT_MODELS[model_label]

    try:
        model = timm.create_model(spec["timm_name"], pretrained=True,
                                  num_classes=NUM_CLASSES).to(DEVICE)
    except torch.cuda.OutOfMemoryError:
        print(f"[{model_label}] OOM on model load: did not train.")
        return None

    try:
        _fit(model, model_label, train_ds, val_ds,
             epochs=epochs, batch_size=batch_size, lr=lr)
    except torch.cuda.OutOfMemoryError:
        # Weights that load fine can still exhaust memory on the first
        # forward/backward pass (activations, gradients, optimizer state).
        # Previously this escaped uncaught, so the documented ViT-L/16
        # "did not train" outcome was never actually reached.
        print(f"[{model_label}] OOM during training: did not train.")
    else:
        return model

    # Only reached after an OOM. Cleanup happens after the except clause has
    # exited, so the traceback no longer pins the failed step's tensors.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return None


def _fit(model, model_label, train_ds, val_ds, *, epochs, batch_size, lr):
    """Run the shared AdamW + cross-entropy loop, printing accuracy each epoch."""
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): lr=1e-3 is large for full fine-tuning of pretrained
    # transformer weights. A lower-lr control run would help confirm the
    # failure is architectural rather than an optimisation artifact.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Pinned host memory only helps host->GPU copies; on CPU it just warns.
    pin = DEVICE.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=pin)

    for epoch in range(1, epochs + 1):
        train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_acc = _evaluate(model, val_loader)

        # Diagnostic print: for ViT, this number will sit near random.
        print(
            f"[{model_label}] epoch {epoch}: "
            f"train_acc={train_acc:.4f} val_acc={val_acc:.4f}"
        )


def _train_one_epoch(model, loader, criterion, optimizer) -> float:
    """Do one optimisation pass over ``loader``; return running train accuracy."""
    model.train()
    correct, seen = 0, 0
    for images, labels in loader:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Accuracy is measured on the pre-update logits of each batch.
        correct += (logits.argmax(1) == labels).sum().item()
        seen += labels.size(0)
    # An empty loader reports 0.0 instead of raising ZeroDivisionError.
    return correct / seen if seen else 0.0


def _evaluate(model, loader) -> float:
    """Return top-1 accuracy of ``model`` over ``loader`` without gradients."""
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            correct += (model(images).argmax(1) == labels).sum().item()
            seen += labels.size(0)
    # An empty loader reports 0.0 instead of raising ZeroDivisionError.
    return correct / seen if seen else 0.0