updated to Ray 2.7
@@ -1,6 +1,7 @@
 import datetime
 import json
 import os
+import tempfile
 from typing import Tuple

 import numpy as np
@@ -10,21 +11,23 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import typer
-from ray.air import session
-from ray.air.config import (
+from ray.air.integrations.mlflow import MLflowLoggerCallback
+from ray.data import Dataset
+from ray.train import (
+    Checkpoint,
     CheckpointConfig,
-    DatasetConfig,
+    DataConfig,
     RunConfig,
     ScalingConfig,
 )
-from ray.air.integrations.mlflow import MLflowLoggerCallback
-from ray.data import Dataset
-from ray.train.torch import TorchCheckpoint, TorchTrainer
+from ray.train.torch import TorchTrainer
+from torch.nn.parallel.distributed import DistributedDataParallel
 from transformers import BertModel
 from typing_extensions import Annotated

-from madewithml import data, models, utils
+from madewithml import data, utils
 from madewithml.config import EFS_DIR, MLFLOW_TRACKING_URI, logger
+from madewithml.models import FinetunedLLM

 # Initialize Typer CLI app
 app = typer.Typer()
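Note: the import churn above is the core of the Ray 2.7 migration. The ray.air.session module and the config classes under ray.air.config were consolidated into ray.train, which now also provides the framework-agnostic Checkpoint and the new DataConfig. A minimal sketch of the resulting entry point (the no-op worker loop and the run name are placeholders, not this repo's code):

from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config: dict) -> None:
    # Metrics are reported through ray.train instead of ray.air.session.
    train.report({"epoch": 0})

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(name="ray27-smoke-test"),  # illustrative name
)
result = trainer.fit()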
|
||||
@@ -106,18 +109,18 @@ def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via
     lr = config["lr"]
     lr_factor = config["lr_factor"]
     lr_patience = config["lr_patience"]
-    batch_size = config["batch_size"]
     num_epochs = config["num_epochs"]
+    batch_size = config["batch_size"]
     num_classes = config["num_classes"]

     # Get datasets
     utils.set_seeds()
-    train_ds = session.get_dataset_shard("train")
-    val_ds = session.get_dataset_shard("val")
+    train_ds = train.get_dataset_shard("train")
+    val_ds = train.get_dataset_shard("val")

     # Model
     llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
-    model = models.FinetunedLLM(llm=llm, dropout_p=dropout_p, embedding_dim=llm.config.hidden_size, num_classes=num_classes)
+    model = FinetunedLLM(llm=llm, dropout_p=dropout_p, embedding_dim=llm.config.hidden_size, num_classes=num_classes)
     model = train.torch.prepare_model(model)

     # Training components
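Note: train.get_dataset_shard replaces session.get_dataset_shard and returns this worker's shard as a DataIterator, which the worker batches itself. A hedged sketch of how such a shard is typically consumed (the batch size is illustrative; this repo's train_step presumably does the equivalent internally):

from ray import train

def consume_train_shard(batch_size: int) -> None:
    # Must run inside a Ray Train worker: fetches this worker's shard.
    train_ds = train.get_dataset_shard("train")
    # The shard is a DataIterator; iter_torch_batches yields dicts of tensors.
    for batch in train_ds.iter_torch_batches(batch_size=batch_size):
        pass  # forward/backward pass would go here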
|
||||
@@ -126,7 +129,8 @@ def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=lr_factor, patience=lr_patience)

     # Training
-    batch_size_per_worker = batch_size // session.get_world_size()
+    num_workers = train.get_context().get_world_size()
+    batch_size_per_worker = batch_size // num_workers
     for epoch in range(num_epochs):
         # Step
         train_loss = train_step(train_ds, batch_size_per_worker, model, num_classes, loss_fn, optimizer)
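Note: the world-size lookup moves from session.get_world_size() to a context object, train.get_context(), which also exposes the worker rank and node rank. Splitting a global batch size across workers then looks like this (the helper name is ours, not the repo's):

from ray import train

def per_worker_batch_size(global_batch_size: int) -> int:
    # Must run inside a worker; the context object replaces the old session API.
    ctx = train.get_context()
    # Floor division, matching the diff above; a remainder is silently dropped.
    return global_batch_size // ctx.get_world_size()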
|
||||
@@ -134,9 +138,14 @@ def train_loop_per_worker(config: dict) -> None: # pragma: no cover, tested via
         scheduler.step(val_loss)

         # Checkpoint
-        metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
-        checkpoint = TorchCheckpoint.from_model(model=model)
-        session.report(metrics, checkpoint=checkpoint)
+        with tempfile.TemporaryDirectory() as dp:
+            if isinstance(model, DistributedDataParallel):  # cpu
+                model.module.save(dp=dp)
+            else:
+                model.save(dp=dp)
+            metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
+            checkpoint = Checkpoint.from_directory(dp)
+            train.report(metrics, checkpoint=checkpoint)


 @app.command()
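Note: this hunk is the biggest behavioral change. TorchCheckpoint.from_model is gone in 2.7; instead the worker serializes the model itself into a temporary directory and wraps it with the generic Checkpoint.from_directory before reporting. The sketch below mirrors that pattern, with torch.save standing in for this repo's custom model.save(dp=...) method:

import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

def report_checkpoint(model: torch.nn.Module, metrics: dict) -> None:
    # Runs inside the worker's epoch loop.
    with tempfile.TemporaryDirectory() as dp:
        # Unwrap DistributedDataParallel so the saved weights are portable.
        net = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        torch.save(net.state_dict(), os.path.join(dp, "model.pt"))
        # Ray persists the directory contents to the run's storage_path.
        train.report(metrics, checkpoint=Checkpoint.from_directory(dp))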
|
||||
@@ -183,7 +192,6 @@ def train_model(
         num_workers=num_workers,
         use_gpu=bool(gpu_per_worker),
         resources_per_worker={"CPU": cpu_per_worker, "GPU": gpu_per_worker},
-        _max_cpu_fraction_per_node=0.8,
     )

     # Checkpoint config
|
||||
@@ -201,7 +209,7 @@ def train_model(
     )

     # Run config
-    run_config = RunConfig(callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, storage_path=EFS_DIR)
+    run_config = RunConfig(callbacks=[mlflow_callback], checkpoint_config=checkpoint_config, storage_path=EFS_DIR, local_dir=EFS_DIR)

     # Dataset
     ds = data.load_data(dataset_loc=dataset_loc, num_samples=train_loop_config["num_samples"])
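Note: alongside storage_path, the run now also points local_dir at the same shared location (EFS_DIR). In 2.7 storage_path is the supported persistence target, and directing both at shared storage keeps every node reading and writing results from one place. A sketch with a placeholder path standing in for EFS_DIR:

from ray.train import RunConfig

SHARED_DIR = "/mnt/shared/ray"  # placeholder for this repo's EFS_DIR

# Both arguments point at shared storage so checkpoints and results
# land somewhere every node can reach.
run_config = RunConfig(storage_path=SHARED_DIR, local_dir=SHARED_DIR)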
|
||||
@@ -210,14 +218,13 @@ def train_model(
     train_loop_config["num_classes"] = len(tags)

     # Dataset config
-    dataset_config = {
-        "train": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
-        "val": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
-    }
+    options = ray.data.ExecutionOptions(preserve_order=True)
+    dataset_config = DataConfig(datasets_to_split=["train"], execution_options=options)

     # Preprocess
     preprocessor = data.CustomPreprocessor()
-    train_ds = preprocessor.fit_transform(train_ds)
+    preprocessor = preprocessor.fit(train_ds)
+    train_ds = preprocessor.transform(train_ds)
     val_ds = preprocessor.transform(val_ds)
     train_ds = train_ds.materialize()
     val_ds = val_ds.materialize()
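Note: the per-dataset DatasetConfig dict is replaced by a single DataConfig: datasets_to_split names the datasets sharded across workers, and preserve_order trades some pipeline throughput for deterministic iteration. Because the trainer no longer fits preprocessors, fit/transform/materialize now happen up front. A self-contained sketch with a toy dataset and a built-in scaler standing in for this repo's CustomPreprocessor:

import ray
from ray.data.preprocessors import StandardScaler  # stand-in for CustomPreprocessor
from ray.train import DataConfig

ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(8)])
train_ds, val_ds = ds.train_test_split(test_size=0.25)

# Fit on train only, transform both splits, and materialize so the
# transformed blocks are computed once instead of on every epoch.
prep = StandardScaler(columns=["x"]).fit(train_ds)
train_ds = prep.transform(train_ds).materialize()
val_ds = prep.transform(val_ds).materialize()

# Only "train" is sharded across workers; "val" is read whole by each worker.
options = ray.data.ExecutionOptions(preserve_order=True)
dataset_config = DataConfig(datasets_to_split=["train"], execution_options=options)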
|
||||
@@ -230,7 +237,7 @@ def train_model(
         run_config=run_config,
         datasets={"train": train_ds, "val": val_ds},
         dataset_config=dataset_config,
-        preprocessor=preprocessor,
+        metadata={"class_to_index": preprocessor.class_to_index},
     )

     # Train
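Note: with the preprocessor argument removed from TorchTrainer, small state such as the label mapping travels via the new metadata argument; Ray attaches it to checkpoints reported during the run, so inference code can recover it later. A sketch of the retrieval side, assuming a Result object from trainer.fit() as in this file:

from ray.train import Result

def load_class_mapping(result: Result) -> dict:
    # Metadata passed to the trainer is attached to reported checkpoints
    # and read back with Checkpoint.get_metadata().
    return result.checkpoint.get_metadata()["class_to_index"]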