BNNR

Python API Reference

What you will find here

The user-facing Python API for integrating BNNR with your own models and dataloaders.

When to use this page

Use this page when the CLI presets are not enough and you need full programmatic control.

Source of truth

This page documents only the symbols publicly exported from src/bnnr/__init__.py.

Core training API

  • BNNRConfig
  • BNNRTrainer
  • quick_run
  • BNNRRunResult
  • CheckpointInfo

Model adapter API

  • ModelAdapter
  • XAICapableModel
  • SimpleTorchAdapter

Reporting and events API

  • Reporter
  • load_report
  • compare_runs
  • JsonlEventSink
  • EVENT_SCHEMA_VERSION
  • replay_events

Config helpers

  • load_config
  • save_config
  • validate_config
  • merge_configs
  • apply_xai_preset
  • get_xai_preset
  • list_xai_presets
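
As a rough sketch of how these helpers compose (the exact signatures are not documented here; the paths and the override key below are hypothetical):

```python
from bnnr import load_config, merge_configs, validate_config, save_config

# Load a base config from disk, then layer experiment-specific overrides on top.
base = load_config("configs/base.yaml")        # hypothetical path
overrides = {"max_iterations": 4}              # hypothetical override key
config = merge_configs(base, overrides)

validate_config(config)                        # fail early on bad values
save_config(config, "runs/exp1/config.yaml")   # hypothetical path
```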

Augmentation API

  • BaseAugmentation
  • AugmentationRegistry
  • AugmentationRunner
  • TorchvisionAugmentation
  • KorniaAugmentation
  • AlbumentationsAugmentation
  • create_kornia_pipeline
  • kornia_available
  • albumentations_available

Built-in classification augmentations:

  • ChurchNoise
  • BasicAugmentation
  • DifPresets
  • Drust
  • LuxferGlass
  • ProCAM
  • Smugs
  • TeaStains

Preset helpers:

  • auto_select_augmentations
  • get_preset
  • list_presets
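
A hedged sketch of how the preset helpers are likely used together (the preset name below is illustrative, not a documented value):

```python
from bnnr import list_presets, get_preset, auto_select_augmentations

print(list_presets())            # discover the available preset names
augs = get_preset("standard")    # "standard" is a hypothetical preset name

# Or let BNNR pick a default augmentation set:
augs = auto_select_augmentations()
```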

XAI API (classification)

Explainers and generation:

  • BaseExplainer
  • OptiCAMExplainer
  • NMFConceptExplainer
  • CRAFTExplainer
  • RealCRAFTExplainer
  • RecursiveCRAFTExplainer
  • generate_saliency_maps
  • generate_craft_concepts
  • generate_nmf_concepts
  • save_xai_visualization

Analysis and scoring:

  • analyze_xai_batch
  • analyze_xai_batch_rich
  • compute_xai_quality_score
  • generate_class_diagnosis
  • generate_class_insight
  • generate_epoch_summary
  • generate_rich_epoch_summary

Cache:

  • XAICache

ICD variants:

  • ICD
  • AICD

Dashboard helper

  • start_dashboard

Minimal classification integration

import torch
import torch.nn as nn
from bnnr import BNNRConfig, BNNRTrainer, SimpleTorchAdapter, auto_select_augmentations
 
model = ...         # any torch.nn.Module classifier
train_loader = ...  # yields (image, label, index) batches
val_loader = ...    # same batch format as train_loader
 
adapter = SimpleTorchAdapter(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    target_layers=[...],
    device="auto",
)
 
config = BNNRConfig(m_epochs=3, max_iterations=2, device="auto")
trainer = BNNRTrainer(adapter, train_loader, val_loader, auto_select_augmentations(), config)
result = trainer.run()
print(result.best_metrics)

quick_run() helper

quick_run() builds a SimpleTorchAdapter internally, so you only need to supply a model and dataloaders.

from bnnr import quick_run
 
result = quick_run(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
)

Useful arguments include augmentations, config/overrides, criterion, optimizer, target_layers, and eval_metrics.
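
Continuing with the model and loaders from the minimal example above, those optional arguments might be combined like this (the override key and target layer are assumptions, not documented defaults):

```python
import torch.nn as nn
from bnnr import quick_run, auto_select_augmentations

result = quick_run(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    augmentations=auto_select_augmentations(),
    criterion=nn.CrossEntropyLoss(),
    target_layers=[model.layer4],      # assumes a ResNet-style model
    overrides={"max_iterations": 2},   # hypothetical override key
)
```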

Detection

Model Adapters

  • DetectionAdapter(model, optimizer, target_layers=None, device="cuda", scheduler=None, use_amp=False, score_threshold=0.05) — wraps torchvision-style detectors (Faster R-CNN, RetinaNet, SSD, FCOS). In train mode calls model(images, targets) for losses; in eval mode calls model(images) for prediction dicts.
  • UltralyticsDetectionAdapter(model_name="yolov8n.pt", device="cuda", score_threshold=0.05, num_classes=None, lr=1e-3, optimizer=None, use_amp=False) — wraps Ultralytics YOLO. Exposes predict_detection_dicts(batch_bchw) for XAI and probe snapshots.

Both adapters implement train_step, eval_step, epoch_end_eval, epoch_end, state_dict, load_state_dict, get_target_layers, and get_model.
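
A minimal sketch of wiring DetectionAdapter to a torchvision detector, assuming the adapter is importable from the top-level bnnr package; the target layer chosen here is an assumption for a ResNet-50 FPN backbone:

```python
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from bnnr import DetectionAdapter

model = fasterrcnn_resnet50_fpn(num_classes=3)
adapter = DetectionAdapter(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9),
    target_layers=[model.backbone.body.layer4],  # hooks for saliency
    device="cuda",
    use_amp=True,
    score_threshold=0.05,
)
```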

Collate Functions

  • detection_collate_fn(batch) → (Tensor[B,C,H,W], list[dict])
  • detection_collate_fn_with_index(batch) → (Tensor[B,C,H,W], list[dict], Tensor[B])
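
For example, the indexed variant plugs into a standard DataLoader like this (the dataset item format is assumed from the collate signature above):

```python
from torch.utils.data import DataLoader
from bnnr import detection_collate_fn_with_index

train_loader = DataLoader(
    detection_dataset,   # assumed to yield (image, target_dict, index) items
    batch_size=4,
    shuffle=True,
    collate_fn=detection_collate_fn_with_index,
)

# One collated batch: stacked images, per-image target dicts, sample indices.
images, targets, indices = next(iter(train_loader))
```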

Detection Augmentations

Bbox-aware transforms (subclass BboxAwareAugmentation):

  • DetectionHorizontalFlip, DetectionVerticalFlip, DetectionRandomRotate90
  • DetectionRandomScale(scale_range=(0.8, 1.2))
  • MosaicAugmentation(output_size=(640, 640)), DetectionMixUp(alpha_range=(0.3, 0.7))
  • AlbumentationsBboxAugmentation(transform)

XAI-driven: DetectionICD, DetectionAICD — saliency-based tile masking for detection.

Presets: get_detection_preset(name) with name ∈ {"light", "standard", "aggressive"}.
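
A hedged sketch of using the presets or composing transforms by hand (whether a preset returns a plain list of transforms is an assumption):

```python
from bnnr import get_detection_preset, DetectionHorizontalFlip, MosaicAugmentation

# Use a named preset...
augs = get_detection_preset("standard")

# ...or build a custom list of bbox-aware transforms:
augs = [
    DetectionHorizontalFlip(),
    MosaicAugmentation(output_size=(640, 640)),
]
```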

Detection Metrics

  • calculate_detection_metrics(predictions, targets, iou_thresholds=None, score_threshold=0.0) → {"map_50", "map_50_95"}
  • calculate_per_class_ap(predictions, targets, iou_threshold=0.5, class_names=None) → per-class AP dict
  • calculate_detection_confusion_matrix(predictions, targets, num_classes=None, iou_threshold=0.5) → {"labels", "matrix"}
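
A minimal sketch of scoring one image, assuming predictions and targets follow the torchvision-style detection dict convention (boxes in xyxy, plus scores/labels) implied by the adapters above:

```python
import torch
from bnnr import calculate_detection_metrics, calculate_per_class_ap

predictions = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
    "labels": torch.tensor([1]),
}]

metrics = calculate_detection_metrics(predictions, targets)
print(metrics["map_50"], metrics["map_50_95"])

per_class = calculate_per_class_ap(predictions, targets, class_names=["bg", "cat"])
```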

Detection XAI

  • generate_detection_saliency(...) — backbone activation–based class-agnostic saliency
  • compute_detection_box_saliency_occlusion(...) — per-box occlusion grid saliency
  • draw_boxes_on_image(...) — draw xyxy boxes with labels and scores
  • overlay_saliency_heatmap(...) — blend saliency with colormap
  • save_detection_xai_panels(...) — writes ground-truth, saliency, and prediction triptych

See the Detection guide for the full walkthrough with examples.