# Source code for deep_transit.train

"""
Main file for training the YOLOv3 model on the transit data set
(``transit_train.csv`` / ``transit_val.csv`` under ``config.DATASET``)
"""

from . import config
import torch
import torch.optim as optim

torch.backends.cudnn.benchmark = True
from .model import YOLOv3
from tqdm.autonotebook import tqdm
from ._utils import (
    seed_everything,
    average_precision,
    get_evaluation_bboxes,
    save_checkpoint,
    load_checkpoint,
    get_loaders,
)
from ._loss import YoloLoss

if not config.ENABLE_WANDB:
    class wandb:
        """No-op stand-in exposing the two ``wandb`` calls this module uses.

        Lets the training code call ``wandb.log`` / ``wandb.watch``
        unconditionally when Weights & Biases logging is disabled.
        """

        @staticmethod
        def log(*args, **kwargs):
            """Accept any arguments and do nothing."""
            return None

        @staticmethod
        def watch(*args, **kwargs):
            """Accept any arguments and do nothing."""
            return None
else:
    import os

    # The mode is written to the environment before wandb is imported and
    # initialised, so the run is started in offline mode.
    os.environ["WANDB_MODE"] = "offline"
    import wandb

    wandb.init(
        project='deep_transit',
        config={
            "LEARNING_RATE": config.LEARNING_RATE,
            "WEIGHT_DECAY": config.WEIGHT_DECAY,
            "BATCH_SIZE": config.BATCH_SIZE,
        },
    )


def _multi_scale_loss(loss_fn, out, targets, scaled_anchors):
    """Sum ``loss_fn`` over the three prediction scales (indices 0, 1, 2)."""
    return (
        loss_fn(out[0], targets[0], scaled_anchors[0])
        + loss_fn(out[1], targets[1], scaled_anchors[1])
        + loss_fn(out[2], targets[2], scaled_anchors[2])
    )


def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors):
    """
    Run one pass over ``train_loader``, updating ``model`` after every batch.

    Parameters
    ----------
    train_loader: iterable
        Yields ``(x, y)`` batches; ``x`` and ``y[0]``, ``y[1]``, ``y[2]`` are
        moved to ``config.DEVICE``.
    model: callable
        Called as ``model(x)``; its output is indexed at 0, 1 and 2.
    optimizer:
        Optimizer that is zeroed and stepped once per batch.
    loss_fn: callable
        Called as ``loss_fn(prediction, target, anchors)`` for each scale.
    scaler:
        Gradient scaler; only used when ``config.ENABLE_AMP is True``.
    scaled_anchors:
        Indexable with 0, 1 and 2; one entry is passed to ``loss_fn`` per scale.

    Returns
    -------
    None
        The per-batch loss and its exponential moving average are reported
        through the progress bar and ``wandb.log``.
    """
    loop = tqdm(train_loader)
    # Running average is kept as a plain float. Keeping it as a tensor would
    # chain every batch's autograd history into one ever-growing graph.
    avg_loss = None
    for x, y in loop:
        x = x.to(config.DEVICE)
        targets = (
            y[0].to(config.DEVICE),
            y[1].to(config.DEVICE),
            y[2].to(config.DEVICE),
        )

        if config.ENABLE_AMP is True:
            # Mixed precision: forward pass under autocast, backward pass and
            # optimizer step routed through the gradient scaler.
            with torch.cuda.amp.autocast():
                out = model(x)
                loss = _multi_scale_loss(loss_fn, out, targets, scaled_anchors)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(x)
            loss = _multi_scale_loss(loss_fn, out, targets, scaled_anchors)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Detach to a Python number once and reuse it everywhere below.
        loss_value = loss.item()

        # update progress bar
        loop.set_postfix(loss=loss_value)
        if avg_loss is None:
            # First batch: the average starts at the first observed loss.
            avg_loss = loss_value
        else:
            avg_loss = avg_loss * 0.95 + loss_value * 0.05
        wandb.log({"loss": loss_value, "avg loss": avg_loss})


def train(patience=2, cooldown=3, enable_seed_everything=True):
    """
    Function for training your own data set.

    Builds the model, optimizer, loss and learning-rate scheduler from
    ``config``, optionally resumes from ``config.CHECKPOINT_FILE``, and then
    trains up to and including epoch ``config.NUM_EPOCHS``. Every fifth epoch
    (except epoch 0) a checkpoint is saved if ``config.SAVE_MODEL`` is set and
    the model is evaluated on the validation loader.

    Parameters
    ----------
    patience: int
        The parameter of `~torch.optim.lr_scheduler.ReduceLROnPlateau`
    cooldown: int
        The parameter of `~torch.optim.lr_scheduler.ReduceLROnPlateau`
    enable_seed_everything: bool
        If true, ``seed_everything()`` is called before anything else so that
        the training is intended to be deterministic
    """
    if enable_seed_everything:
        seed_everything()

    model = YOLOv3().to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    )
    loss_fn = YoloLoss()
    # A gradient scaler is only created when mixed precision is enabled.
    scaler = torch.cuda.amp.GradScaler() if config.ENABLE_AMP is True else None

    # mode='max': the scheduler is stepped with the mean AP computed below,
    # where larger is better.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        patience=patience,
        factor=0.5,
        verbose=True,
        cooldown=cooldown,
    )

    tqdm.write(config.DATASET)
    train_loader, validation_loader = get_loaders(
        train_csv_path=config.DATASET + "/transit_train.csv",
        validation_csv_path=config.DATASET + "/transit_val.csv",
    )

    epoch_old = 0
    if config.LOAD_MODEL:
        # Resume from the epoch after the one returned by load_checkpoint.
        epoch_old = load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE, lr_scheduler
        ) + 1

    # Multiply each group of anchors by its entry of config.S. The repeat
    # count is the total number of anchors divided by 3.
    anchors_per_scale = sum(len(x) for x in config.ANCHORS) // 3
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, anchors_per_scale, 2)
    ).to(config.DEVICE)

    wandb.watch(model)
    for epoch in range(epoch_old, config.NUM_EPOCHS + 1):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors)

        # NOTE(review): everything below is placed under the every-5-epochs
        # check; confirm this nesting against the upstream repository.
        if epoch % 5 == 0 and epoch > 0:
            if config.SAVE_MODEL:
                save_checkpoint(
                    model,
                    optimizer,
                    epoch,
                    lr_scheduler,
                    file_path=f"{config.CHECKPOINT_FILE}_{epoch}.tar",
                )

            tqdm.write("On Validation loader:")
            pred_boxes, true_boxes = get_evaluation_bboxes(
                validation_loader,
                model,
                iou_threshold=config.NMS_IOU_THRESH,
                anchors=config.ANCHORS,
                threshold=config.CONF_THRESHOLD,
            )

            # Average precision at three IoU thresholds.
            AP50 = average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
            )
            AP70 = average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=0.70,
                box_format="midpoint",
            )
            AP90 = average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=0.9,
                box_format="midpoint",
            )
            mAP = (AP50 + AP70 + AP90) / 3
            lr_scheduler.step(mAP)

            tqdm.write(f"AP50: {AP50:.3f}, AP70: {AP70:.3f}, AP90: {AP90:.3f}")
            wandb.log({
                'epoch': epoch,
                'ap50': AP50,
                'ap70': AP70,
                'ap90': AP90,
                'mAP': mAP,
                'lr': optimizer.param_groups[0]['lr'],
            })
if __name__ == "__main__":
    # Script entry point: start training with the default arguments.
    train()