---
name: ray-train
description: "Ray Train — TorchTrainer, distributed training across GPUs/nodes, fault tolerance, local mode with torchrun, HuggingFace integration, Ray Data pipelines, and Tune. Use when running distributed training with Ray. NOT for single-GPU training."
---
Ray Train
Ray Train provides distributed training built on Ray, abstracting away distributed setup and adding fault tolerance, checkpointing, and integration with Ray Data and Ray Tune. Version: 2.54.0 — Local mode (backfilled from 2.50.0) with torchrun for multi-process debugging, enhanced checkpoint upload modes, Train V2 as default API.
TorchTrainer Configuration
Core Settings
```python
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig, CheckpointConfig, FailureConfig

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(...),
    run_config=RunConfig(...),
    torch_config=TorchConfig(...),
    datasets={"train": ray_dataset},    # optional Ray Data integration
    dataset_config=DataConfig(...),     # optional data config
    resume_from_checkpoint=checkpoint,  # resume from saved checkpoint
)
result = trainer.fit()
```
ScalingConfig
| Setting | Purpose | Default |
|---|---|---|
| num_workers | Total training workers (each gets a GPU) | required |
| use_gpu | Assign GPUs to workers | False |
| resources_per_worker | Resource dict per worker | {} |
| placement_strategy | SPREAD, PACK, STRICT_SPREAD, STRICT_PACK | PACK |
| trainer_resources | Resources for the trainer coordinator | {"CPU": 0} |
| accelerator_type | Required accelerator (e.g., A100, H100) | None |
```python
scaling_config = ScalingConfig(
    num_workers=8,
    use_gpu=True,
    resources_per_worker={"CPU": 8, "GPU": 1},
    placement_strategy="PACK",  # co-locate workers for faster communication
    accelerator_type="A100",    # require specific GPU type
)
```
Placement strategies:
- PACK — co-locate workers on the same node (best for TP, faster communication)
- SPREAD — distribute across nodes (better fault isolation)
- STRICT_PACK / STRICT_SPREAD — hard constraints (fail if they can't be satisfied)
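The intuition behind PACK vs SPREAD can be shown with a toy placement function. This is a simplified sketch for illustration only, not Ray's actual scheduler; the function name and node representation are invented here:

```python
def place_workers(num_workers, free_gpus, strategy):
    """Toy worker placement: PACK fills one node before moving on,
    SPREAD round-robins across nodes. Not Ray's real scheduler --
    just an intuition aid. free_gpus: dict of node name -> free slots."""
    placement = []
    names = list(free_gpus)
    if strategy == "PACK":
        for name in names:
            while free_gpus[name] > 0 and len(placement) < num_workers:
                placement.append(name)
                free_gpus[name] -= 1
    elif strategy == "SPREAD":
        i = 0
        while len(placement) < num_workers and sum(free_gpus.values()) > 0:
            name = names[i % len(names)]
            if free_gpus[name] > 0:
                placement.append(name)
                free_gpus[name] -= 1
            i += 1
    return placement

# PACK co-locates: all 4 workers land on node-a (it has 4 free GPUs)
print(place_workers(4, {"node-a": 4, "node-b": 4}, "PACK"))
# SPREAD alternates: workers split 2/2 across the two nodes
print(place_workers(4, {"node-a": 4, "node-b": 4}, "SPREAD"))
```

The STRICT_ variants behave the same but fail the job instead of falling back when the constraint can't be met.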
Fault Tolerance and Resiliency
CheckpointConfig
```python
checkpoint_config = CheckpointConfig(
    num_to_keep=3,                           # keep only last N checkpoints
    checkpoint_score_attribute="eval_loss",  # metric to rank checkpoints
    checkpoint_score_order="min",            # "min" or "max"
)
```
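The keep-best-N retention policy can be mimicked in plain Python. This is a sketch of the ranking idea only, not Ray's implementation; the function and checkpoint dicts are invented for illustration:

```python
def retain_checkpoints(checkpoints, num_to_keep, score_attribute, order="min"):
    """Keep the num_to_keep best checkpoints ranked by a score attribute.
    order="min" keeps the lowest scores (e.g. loss), "max" the highest."""
    ranked = sorted(checkpoints, key=lambda c: c[score_attribute],
                    reverse=(order == "max"))
    return ranked[:num_to_keep]

ckpts = [
    {"path": "ckpt_0", "eval_loss": 0.9},
    {"path": "ckpt_1", "eval_loss": 0.4},
    {"path": "ckpt_2", "eval_loss": 0.6},
    {"path": "ckpt_3", "eval_loss": 0.5},
]
kept = retain_checkpoints(ckpts, num_to_keep=3,
                          score_attribute="eval_loss", order="min")
print([c["path"] for c in kept])  # the three lowest-loss checkpoints survive
```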
FailureConfig
```python
from ray.train import FailureConfig

failure_config = FailureConfig(
    max_failures=3,   # auto-restart up to 3 times on failure
    fail_fast=False,  # True = fail immediately on any worker error
)
```
When a worker fails, Ray Train:
- Stops all workers
- Restores from the latest checkpoint
- Restarts all workers from that checkpoint
- Continues training
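The recovery steps above amount to a retry loop around a resumable training function. A simplified sketch of the max_failures semantics (not Ray's actual recovery code; the helper and state dict are invented here):

```python
def run_with_recovery(train_once, max_failures):
    """Retry a resumable training function up to max_failures times.
    On failure, the next attempt resumes from the last checkpoint
    (tracked by train_once itself) rather than from scratch."""
    failures = 0
    while True:
        try:
            return train_once()
        except RuntimeError:
            failures += 1
            if failures > max_failures:
                raise  # exceeded the budget: surface the error

# Toy resumable trainer: "checkpoints" its step after every iteration
# and crashes twice before completing all 10 steps.
state = {"step": 0, "attempts": 0}

def flaky_train():
    state["attempts"] += 1
    start = state["step"]              # resume from the last checkpoint
    for step in range(start, 10):
        state["step"] = step + 1       # checkpoint after each step
        if state["attempts"] < 3 and step == start + 2:
            raise RuntimeError("worker died")
    return state["step"]

print(run_with_recovery(flaky_train, max_failures=3))  # 10
print(state["attempts"])  # 3 attempts: 2 failures + 1 successful resume
```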
Checkpointing in the Training Loop
```python
import os
import tempfile

import torch
import ray.train
from ray.train import Checkpoint

def train_func(config):
    model = ...
    optimizer = ...

    # Resume from checkpoint if available
    checkpoint = ray.train.get_checkpoint()
    start_epoch = 0
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            state = torch.load(os.path.join(ckpt_dir, "model.pt"))
            model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optimizer"])
            start_epoch = state["epoch"] + 1

    for epoch in range(start_epoch, config["num_epochs"]):
        loss = train_one_epoch(model, optimizer)

        # Save checkpoint (only rank 0 saves, but all ranks must call report)
        with tempfile.TemporaryDirectory() as tmp:
            checkpoint = None
            if ray.train.get_context().get_world_rank() == 0:
                torch.save({
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                }, os.path.join(tmp, "model.pt"))
                checkpoint = Checkpoint.from_directory(tmp)
            ray.train.report(
                metrics={"loss": loss, "epoch": epoch},
                checkpoint=checkpoint,
            )
```
Storage Configuration
⚠️ DATA LOSS: Default storage_path destroys checkpoints on K8s.
RunConfig(storage_path=...) defaults to ~/ray_results — the head pod's ephemeral disk. With shutdownAfterJobFinishes: true, KubeRay deletes the entire RayCluster when the job completes. The head pod is evicted, and every checkpoint, metric log, and artifact is gone. The job reports SUCCEEDED. There is no warning. Weeks of GPU time can vanish silently.
Always set storage_path explicitly to S3/GCS or a ReadWriteMany PVC:
```python
# Option 1: S3 (recommended — no PVC wiring needed)
RunConfig(storage_path="s3://my-bucket/ray-results")

# Option 2: Shared PVC mounted on ALL pods (head + workers)
RunConfig(storage_path="/mnt/shared-checkpoints")
```
See Checkpoint-Safe PVC Wiring for KubeRay for the full K8s manifest.
```python
from ray.train import RunConfig

run_config = RunConfig(
    name="my-training-run",
    storage_path="s3://my-bucket/ray-results",  # or /mnt/shared-pvc
    checkpoint_config=checkpoint_config,
    failure_config=failure_config,
    log_to_file=True,
)
```
Storage requirements:
- All workers must be able to read/write to storage_path
- On K8s: use a shared PVC or S3/GCS with credentials in the container env
- Checkpoint upload is async by default
RunConfig
| Setting | Purpose | Default |
|---|---|---|
| name | Run name (directory name) | Auto-generated |
| storage_path | Where to save checkpoints/results | ~/ray_results |
| checkpoint_config | Checkpoint settings | CheckpointConfig() |
| failure_config | Failure handling | FailureConfig() |
| log_to_file | Redirect stdout/stderr to files | False |
| stop | Stopping criteria | None |
| callbacks | List of callbacks | [] |
| sync_config | Sync settings for cloud storage | Auto |
TorchConfig
| Setting | Purpose | Default |
|---|---|---|
| backend | Distributed backend | "nccl" (GPU), "gloo" (CPU) |
| timeout_s | Timeout for collective operations (seconds) | 1800 |
```python
from ray.train.torch import TorchConfig

torch_config = TorchConfig(
    backend="nccl",
    timeout_s=3600,  # increase for large models or slow checkpointing
)
```
Local Mode (Development & Debugging)
Run training without launching Ray worker actors — execute directly in the current process for rapid iteration, unit testing, and debugging. Enable by setting num_workers=0:
```python
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=0),  # local mode
)
result = trainer.fit()
```
All ray.train APIs (report(), get_checkpoint(), get_context()) work identically in local mode — no code changes needed when switching to distributed.
Multi-Process Local Mode with torchrun
Test multi-GPU distributed logic using torchrun's process management — useful for debugging NCCL, DDP, and FSDP without Ray actors:
```python
# train_script.py — same Ray Train code, launched via torchrun
import torch
from torch.utils.data import DataLoader, TensorDataset

import ray.train.torch
from ray.train import ScalingConfig, get_context
from ray.train.torch import TorchTrainer

def train_func(config):
    model = torch.nn.Linear(10, 1)  # replace with your model
    dataset = TensorDataset(torch.randn(100, 10), torch.randn(100, 1))
    loader = DataLoader(dataset, batch_size=16)
    model = ray.train.torch.prepare_model(model)
    loader = ray.train.torch.prepare_data_loader(loader)
    # get_context() returns correct world_size/rank from torchrun env
    print(f"Rank {get_context().get_world_rank()} of {get_context().get_world_size()}")
    # ... training loop with ray.train.report(...)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=0, use_gpu=True),
)
result = trainer.fit()
```
```bash
# Launch multi-GPU training on 4 GPUs (RAY_TRAIN_V2_ENABLED=1 required for Train V2 API)
RAY_TRAIN_V2_ENABLED=1 torchrun --nproc-per-node=4 train_script.py

# Multi-node with torchrun
RAY_TRAIN_V2_ENABLED=1 torchrun --nnodes=2 --nproc-per-node=8 --rdzv-endpoint=head:29500 train_script.py
```
Ray Train auto-detects the torchrun environment and configures distributed training accordingly. Note: Ray Data is not supported with torchrun multi-process mode — use standard DataLoader with DistributedSampler.
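DistributedSampler gives each rank a strided, disjoint slice of the dataset indices. The assignment can be sketched in plain Python (a simplification that ignores the shuffling and shard-padding the real sampler does; the function name is invented here):

```python
def shard_indices(dataset_len, world_size, rank):
    """Strided index assignment, in the spirit of torch's
    DistributedSampler (without shuffling or padding to even
    out shard sizes): rank r gets indices r, r+world_size, ..."""
    return list(range(rank, dataset_len, world_size))

# 10 samples across 4 ranks: each rank sees a disjoint strided slice
for rank in range(4):
    print(rank, shard_indices(10, world_size=4, rank=rank))
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6], rank 3 -> [3, 7]
```

Together the shards cover every index exactly once, which is why each rank can run the same loop body without coordinating on data.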
HuggingFace Transformers Integration
```python
from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback

def train_func(config):
    from transformers import Trainer, TrainingArguments

    # model, train_ds, eval_ds are loaded/defined elsewhere in train_func
    training_args = TrainingArguments(
        output_dir="/tmp/hf-output",
        per_device_train_batch_size=config["batch_size"],
        num_train_epochs=config["epochs"],
        learning_rate=config["lr"],
        bf16=True,
        gradient_accumulation_steps=config.get("grad_accum", 1),
        save_strategy="epoch",
        logging_steps=10,
        # Do NOT set deepspeed or fsdp here — Ray Train handles distribution
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        callbacks=[RayTrainReportCallback()],  # reports metrics + checkpoints to Ray
    )
    trainer = prepare_trainer(trainer)  # sets up distributed training
    trainer.train()
```
Key: Use prepare_trainer() instead of configuring DeepSpeed/FSDP directly in TrainingArguments. Ray Train handles the distributed setup.
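Note that with data parallelism, per_device_train_batch_size and gradient_accumulation_steps multiply with the worker count: the effective global batch size per optimizer step is their product. A quick arithmetic check (the helper function is just for illustration):

```python
def effective_batch_size(per_device_batch, grad_accum, num_workers):
    """Global batch size seen by the optimizer per update step
    under data parallelism: per-device batch x accumulation x workers."""
    return per_device_batch * grad_accum * num_workers

# e.g. batch 16 per GPU, 4 accumulation steps, 8 workers -> 512 samples/update
print(effective_batch_size(16, 4, 8))  # 512
```

This matters when scaling num_workers up or down: to keep the optimizer's effective batch size (and thus your tuned learning rate) constant, adjust one of the other two factors.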
Data Ingestion with Ray Data
Streaming Data Pipeline
```python
import ray

# Create a streaming dataset (doesn't load all data into memory)
ds = ray.data.read_parquet("s3://my-bucket/training-data/")

# Preprocessing pipeline
ds = ds.map(tokenize_function)
ds = ds.random_shuffle()

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    datasets={"train": ds, "eval": eval_ds},
)
```
DataConfig
```python
from ray.train import DataConfig

data_config = DataConfig(
    datasets_to_split="all",  # split across workers (default for train)
)
```
In the Training Loop
```python
import ray.train

def train_func(config):
    train_ds = ray.train.get_dataset_shard("train")
    for epoch in range(config["num_epochs"]):
        # Iterate with batches — streaming, memory-efficient
        for batch in train_ds.iter_torch_batches(batch_size=32, device="cuda"):
            loss = model(batch["input_ids"], batch["labels"])
            ...
```
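The chunking behavior of batch iteration can be sketched with a plain generator (a toy version that ignores the device transfer, Arrow-to-tensor conversion, and prefetching the real iterator does; the function name is invented here):

```python
def iter_batches(rows, batch_size):
    """Yield successive fixed-size chunks of rows; the final batch
    may be smaller when len(rows) isn't a multiple of batch_size."""
    for i in range(0, len(rows), batch_size):
        yield rows[i:i + batch_size]

batches = list(iter_batches(list(range(10)), batch_size=4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Because each worker iterates only its own shard, memory use stays bounded by the batch size and prefetch depth rather than the dataset size.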
Ray Tune Integration
```python
from ray import tune
from ray.tune.schedulers import ASHAScheduler

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)

tuner = tune.Tuner(
    trainer,
    param_space={
        "train_loop_config": {
            "lr": tune.loguniform(1e-5, 1e-3),
            "batch_size": tune.choice([16, 32, 64]),
            "epochs": 3,
        },
    },
    tune_config=tune.TuneConfig(
        num_samples=10,
        metric="eval_loss",
        mode="min",
        scheduler=ASHAScheduler(max_t=3, grace_period=1),
    ),
)
results = tuner.fit()
```
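A note on the search space: loguniform(1e-5, 1e-3) samples uniformly in log space, so each decade of learning rates is equally likely — the right behavior for a parameter whose effect is multiplicative. A sketch of what that sampling means (a toy stand-in, not Tune's implementation):

```python
import math
import random

def loguniform(low, high, rng=random.random):
    """Sample uniformly in log space between low and high:
    exp(U * (ln(high) - ln(low)) + ln(low)) for U ~ Uniform(0, 1)."""
    return math.exp(rng() * (math.log(high) - math.log(low)) + math.log(low))

random.seed(0)
samples = [loguniform(1e-5, 1e-3) for _ in range(1000)]
assert all(1e-5 <= s <= 1e-3 for s in samples)
# About half the samples fall in each decade, [1e-5, 1e-4) and [1e-4, 1e-3]
below = sum(s < 1e-4 for s in samples) / len(samples)
print(f"fraction in lower decade: {below:.2f}")
```

A plain uniform(1e-5, 1e-3) would instead put ~90% of samples in the top decade, starving the small-learning-rate region.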
Multi-Node Training on Kubernetes
Deploy as a RayJob with shutdownAfterJobFinishes: true:
```yaml
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: distributed-training
spec:
  entrypoint: python train.py
  shutdownAfterJobFinishes: true
  activeDeadlineSeconds: 86400
  submissionMode: K8sJobMode
  rayClusterSpec:
    headGroupSpec:
      template:
        spec:
          containers:
            - name: ray-head
              resources:
                limits:
                  cpu: "4"
                  memory: 16Gi
    workerGroupSpecs:
      - groupName: gpu-workers
        replicas: 8
        template:
          spec:
            containers:
              - name: ray-worker
                resources:
                  limits:
                    cpu: "8"
                    memory: 64Gi
                    nvidia.com/gpu: "1"
```
Shared Storage for Checkpoints
Mount a shared PVC or use S3/GCS:
```python
run_config = RunConfig(
    storage_path="/mnt/shared-checkpoints",  # PVC mounted on all pods
    # or: storage_path="s3://bucket/ray-results",
)
```
Checkpoint-Safe PVC Wiring for KubeRay
Complete RayJob manifest with a ReadWriteMany PVC mounted on head and all worker pods. This prevents checkpoint loss when shutdownAfterJobFinishes: true deletes the cluster.
```yaml
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: ray-checkpoints
spec:
  accessModes: ["ReadWriteMany"]  # required — all pods must write
  storageClassName: efs-sc        # EFS, Longhorn, CephFS, or any RWX provider
  resources:
    requests:
      storage: 200Gi
---
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: distributed-training
spec:
  entrypoint: python train.py
  shutdownAfterJobFinishes: true
  activeDeadlineSeconds: 86400
  submissionMode: K8sJobMode
  rayClusterSpec:
    headGroupSpec:
      template:
        spec:
          containers:
            - name: ray-head
              resources:
                limits:
                  cpu: "4"
                  memory: 16Gi
              volumeMounts:
                - name: checkpoints
                  mountPath: /mnt/shared-checkpoints
          volumes:
            - name: checkpoints
              persistentVolumeClaim:
                claimName: ray-checkpoints
    workerGroupSpecs:
      - groupName: gpu-workers
        replicas: 8
        template:
          spec:
            containers:
              - name: ray-worker
                resources:
                  limits:
                    cpu: "8"
                    memory: 64Gi
                    nvidia.com/gpu: "1"
                volumeMounts:
                  - name: checkpoints
                    mountPath: /mnt/shared-checkpoints
            volumes:
              - name: checkpoints
                persistentVolumeClaim:
                  claimName: ray-checkpoints
```
In the training script:
```python
run_config = RunConfig(
    storage_path="/mnt/shared-checkpoints",  # survives cluster teardown
    checkpoint_config=CheckpointConfig(num_to_keep=3),
)
```
Debugging
For communication errors, OOM, slow data loading, and common training issues, see references/troubleshooting.md.
References
- references/troubleshooting.md — Distributed training failures, checkpoint issues, and scaling problems
Cross-References
- ray-core — Ray tasks, actors, and object store fundamentals
- ray-data — Streaming data pipelines for training
- aws-efa — EFA networking for multi-node Ray Train on EKS
- aws-fsx — FSx storage for training data and checkpoints
- kuberay — Deploy training jobs on Kubernetes via RayJob CRD
- pytorch — PyTorch distributed training concepts
- fsdp — FSDP for model parallelism within Ray Train
- deepspeed — DeepSpeed integration with Ray Train
- wandb — Experiment tracking for Ray Train runs
- nccl — NCCL tuning for multi-node Ray Train GPU communication
- gpu-operator — GPU driver and device plugin for Ray Train workers
- kueue — Queue Ray Train jobs via KubeRay integration
Reference
- Ray Train docs
- TorchTrainer API
- Ray Train examples
- references/troubleshooting.md — common errors and fixes
- assets/train_llm.py — complete Ray Train LLM fine-tuning example with FSDP, checkpointing, fault tolerance, and Ray Data integration
- assets/architecture.md — Mermaid architecture diagrams