import datetime import json from typing import Tuple import numpy as np import ray import ray.train as train import torch import torch.nn as nn import torch.nn.functional as F import typer from ray.air import session from ray.air.config import ( CheckpointConfig, DatasetConfig, RunConfig, ScalingConfig, ) from ray.air.integrations.mlflow import MLflowLoggerCallback from ray.data import Dataset from ray.train.torch import TorchCheckpoint, TorchTrainer from transformers import BertModel from typing_extensions import Annotated from madewithml import data, models, utils from madewithml.config import MLFLOW_TRACKING_URI, logger # Initialize Typer CLI app app = typer.Typer() def train_step( ds: Dataset, batch_size: int, model: nn.Module, num_classes: int, loss_fn: torch.nn.modules.loss._WeightedLoss, optimizer: torch.optim.Optimizer, ) -> float: # pragma: no cover, tested via train workload """Train step. Args: ds (Dataset): dataset to iterate batches from. batch_size (int): size of each batch. model (nn.Module): model to train. num_classes (int): number of classes. loss_fn (torch.nn.loss._WeightedLoss): loss function to use between labels and predictions. optimizer (torch.optimizer.Optimizer): optimizer to use for updating the model's weights. Returns: float: cumulative loss for the dataset. """ model.train() loss = 0.0 ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=utils.collate_fn) for i, batch in enumerate(ds_generator): optimizer.zero_grad() # reset gradients z = model(batch) # forward pass targets = F.one_hot(batch["targets"], num_classes=num_classes).float() # one-hot (for loss_fn) J = loss_fn(z, targets) # define loss J.backward() # backward pass optimizer.step() # update weights loss += (J.detach().item() - loss) / (i + 1) # cumulative loss return loss def eval_step( ds: Dataset, batch_size: int, model: nn.Module, num_classes: int, loss_fn: torch.nn.modules.loss._WeightedLoss ) -> Tuple[float, np.array, np.array]: # pragma: no cover, tested via train workload """Eval step. Args: ds (Dataset): dataset to iterate batches from. batch_size (int): size of each batch. model (nn.Module): model to train. num_classes (int): number of classes. loss_fn (torch.nn.loss._WeightedLoss): loss function to use between labels and predictions. Returns: Tuple[float, np.array, np.array]: cumulative loss, ground truths and predictions. """ model.eval() loss = 0.0 y_trues, y_preds = [], [] ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=utils.collate_fn) with torch.inference_mode(): for i, batch in enumerate(ds_generator): z = model(batch) targets = F.one_hot(batch["targets"], num_classes=num_classes).float() # one-hot (for loss_fn) J = loss_fn(z, targets).item() loss += (J - loss) / (i + 1) y_trues.extend(batch["targets"].cpu().numpy()) y_preds.extend(torch.argmax(z, dim=1).cpu().numpy()) return loss, np.vstack(y_trues), np.vstack(y_preds) def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via train workload """Training loop that each worker will execute. Args: config (dict): arguments to use for training. """ # Hyperparameters dropout_p = config["dropout_p"] lr = config["lr"] lr_factor = config["lr_factor"] lr_patience = config["lr_patience"] batch_size = config["batch_size"] num_epochs = config["num_epochs"] num_classes = config["num_classes"] # Get datasets utils.set_seeds() train_ds = session.get_dataset_shard("train") val_ds = session.get_dataset_shard("val") # Model llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False) model = models.FinetunedLLM(llm=llm, dropout_p=dropout_p, embedding_dim=llm.config.hidden_size, num_classes=num_classes) model = train.torch.prepare_model(model) # Training components loss_fn = nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters(), lr=lr) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=lr_factor, patience=lr_patience) # Training batch_size_per_worker = batch_size // session.get_world_size() for epoch in range(num_epochs): # Step train_loss = train_step(train_ds, batch_size_per_worker, model, num_classes, loss_fn, optimizer) val_loss, _, _ = eval_step(val_ds, batch_size_per_worker, model, num_classes, loss_fn) scheduler.step(val_loss) # Checkpoint metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss) checkpoint = TorchCheckpoint.from_model(model=model) session.report(metrics, checkpoint=checkpoint) @app.command() def train_model( experiment_name: Annotated[str, typer.Option(help="name of the experiment for this training workload.")] = None, dataset_loc: Annotated[str, typer.Option(help="location of the dataset.")] = None, train_loop_config: Annotated[str, typer.Option(help="arguments to use for training.")] = None, num_workers: Annotated[int, typer.Option(help="number of workers to use for training.")] = 1, cpu_per_worker: Annotated[int, typer.Option(help="number of CPUs to use per worker.")] = 1, gpu_per_worker: Annotated[int, typer.Option(help="number of GPUs to use per worker.")] = 0, num_samples: Annotated[int, typer.Option(help="number of samples to use from dataset.")] = None, num_epochs: Annotated[int, typer.Option(help="number of epochs to train for.")] = 1, batch_size: Annotated[int, typer.Option(help="number of samples per batch.")] = 256, results_fp: Annotated[str, typer.Option(help="filepath to save results to.")] = None, ) -> ray.air.result.Result: """Main train function to train our model as a distributed workload. Args: experiment_name (str): name of the experiment for this training workload. dataset_loc (str): location of the dataset. train_loop_config (str): arguments to use for training. num_workers (int, optional): number of workers to use for training. Defaults to 1. cpu_per_worker (int, optional): number of CPUs to use per worker. Defaults to 1. gpu_per_worker (int, optional): number of GPUs to use per worker. Defaults to 0. num_samples (int, optional): number of samples to use from dataset. If this is passed in, it will override the config. Defaults to None. num_epochs (int, optional): number of epochs to train for. If this is passed in, it will override the config. Defaults to None. batch_size (int, optional): number of samples per batch. If this is passed in, it will override the config. Defaults to None. results_fp (str, optional): filepath to save results to. Defaults to None. Returns: ray.air.result.Result: training results. """ # Set up train_loop_config = json.loads(train_loop_config) train_loop_config["num_samples"] = num_samples train_loop_config["num_epochs"] = num_epochs train_loop_config["batch_size"] = batch_size # Scaling config scaling_config = ScalingConfig( num_workers=num_workers, use_gpu=bool(gpu_per_worker), resources_per_worker={"CPU": cpu_per_worker, "GPU": gpu_per_worker}, _max_cpu_fraction_per_node=0.8, ) # Checkpoint config checkpoint_config = CheckpointConfig( num_to_keep=1, checkpoint_score_attribute="val_loss", checkpoint_score_order="min", ) # MLflow callback mlflow_callback = MLflowLoggerCallback( tracking_uri=MLFLOW_TRACKING_URI, experiment_name=experiment_name, save_artifact=True, ) # Run config run_config = RunConfig( callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, ) # Dataset ds = data.load_data(dataset_loc=dataset_loc, num_samples=train_loop_config["num_samples"]) train_ds, val_ds = data.stratify_split(ds, stratify="tag", test_size=0.2) tags = train_ds.unique(column="tag") train_loop_config["num_classes"] = len(tags) # Dataset config dataset_config = { "train": DatasetConfig(fit=False, transform=False, randomize_block_order=False), "val": DatasetConfig(fit=False, transform=False, randomize_block_order=False), } # Preprocess preprocessor = data.CustomPreprocessor() train_ds = preprocessor.fit_transform(train_ds) val_ds = preprocessor.transform(val_ds) train_ds = train_ds.materialize() val_ds = val_ds.materialize() # Trainer trainer = TorchTrainer( train_loop_per_worker=train_loop_per_worker, train_loop_config=train_loop_config, scaling_config=scaling_config, run_config=run_config, datasets={"train": train_ds, "val": val_ds}, dataset_config=dataset_config, preprocessor=preprocessor, ) # Train results = trainer.fit() d = { "timestamp": datetime.datetime.now().strftime("%B %d, %Y %I:%M:%S %p"), "run_id": utils.get_run_id(experiment_name=experiment_name, trial_id=results.metrics["trial_id"]), "params": results.config["train_loop_config"], "metrics": utils.dict_to_list(results.metrics_dataframe.to_dict(), keys=["epoch", "train_loss", "val_loss"]), } logger.info(json.dumps(d, indent=2)) if results_fp: # pragma: no cover, saving results utils.save_dict(d, results_fp) return results if __name__ == "__main__": # pragma: no cover, application if ray.is_initialized(): ray.shutdown() ray.init() app()