
Creating a Trainer Plugin Script

This guide explains how to adapt your existing training scripts to work with Transformer Lab using the tlab_trainer decorator class. By integrating with Transformer Lab, your training scripts gain progress tracking, parameter management, dataset handling, and integrated logging with minimal code changes. This is part of the active development of the Transformer Lab Plugin SDK, which aims to make integrating third-party plugins easier.

What is tlab_trainer?

tlab_trainer is a decorator class that helps integrate your training script with Transformer Lab's job management system. It provides:

  • Argument parsing and configuration loading
  • Dataset loading helpers
  • Progress tracking and reporting
  • Job status management
  • Integration with TensorBoard and Weights & Biases

Getting Started

1. Import the decorator

Add this import to your training script:

from transformerlab.sdk.v1.train import tlab_trainer

2. Decorate your main training function

Wrap your main training function with the job_wrapper decorator:

@tlab_trainer.job_wrapper(
    wandb_project_name="my_project",  # Optional: set a custom Weights & Biases project name
    manual_logging=False,             # Optional: set to True for manual metric logging
)
def train_model():
    # Your training code here
    pass

The decorator parameters include:

  • progress_start and progress_end: Define the progress range reported for the job (typically 0-100). These are optional; if omitted, progress is tracked from 0 to 100.
  • wandb_project_name: Optional custom name for your Weights & Biases project. Default is TLAB_Training.
  • manual_logging: Set to True for training scripts without automatic logging integration. Default is False.

Note: There is also an async version of the job wrapper for functions that need to run asynchronously. Use it by simply changing @tlab_trainer.job_wrapper to @tlab_trainer.async_job_wrapper.
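
For example, a minimal sketch of an async entry point. The function body here is purely illustrative; everything else stays the same as in the synchronous example:

import asyncio

from transformerlab.sdk.v1.train import tlab_trainer

@tlab_trainer.async_job_wrapper(wandb_project_name="my_project")
async def train_model():
    # Await any asynchronous setup or I/O your script needs (placeholder)
    await asyncio.sleep(0)
    # ... run your training as usual ...
    return True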

3. Use helper methods

Replace parts of your code with tlab_trainer helper methods:

  • For dataset loading: tlab_trainer.load_dataset()
  • For progress tracking: tlab_trainer.create_progress_callback()
  • For storing values in the job data (optional): tlab_trainer.add_job_data(key, value) (see the sketch below)
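
A minimal sketch of add_job_data, recording extra metadata on the job record (the key names and values here are illustrative, not required names):

# Store illustrative metadata alongside the job record
tlab_trainer.add_job_data("adapter_name", "my-lora-adapter")
tlab_trainer.add_job_data("final_loss", 0.42)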

Complete example

Here's how a typical training script can be adapted to use tlab_trainer:

import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset

# Parse command line arguments
def parse_args():
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("--model_name", type=str, required=True, help="Model to train")
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset to use")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length")
    return parser.parse_args()

def train_model():
    # 1. Parse arguments
    args = parse_args()

    # 2. Load dataset
    dataset = load_dataset(args.dataset_name)["train"]

    # 3. Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # 4. Setup training arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        # other arguments (e.g. use args.max_length when tokenizing)...
    )

    # 5. Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    # 6. Train and save
    trainer.train()
    trainer.save_model(args.output_dir)

    print(f"Model saved to {args.output_dir}")

# Call the function
if __name__ == "__main__":
    train_model()

Adapted Script with tlab_trainer

from transformerlab.sdk.v1.train import tlab_trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer


@tlab_trainer.job_wrapper(progress_start=0, progress_end=100)
def train_model():
    # 1. Load dataset with helper
    datasets = tlab_trainer.load_dataset()
    dataset = datasets["train"]

    # 2. Load model and tokenizer (same as before)
    model = AutoModelForCausalLM.from_pretrained(tlab_trainer.params.model_name)
    tokenizer = AutoTokenizer.from_pretrained(tlab_trainer.params.model_name)

    # 3. Setup training arguments with parameters from Transformer Lab
    training_args = TrainingArguments(
        output_dir=tlab_trainer.params.output_dir,
        learning_rate=float(tlab_trainer.params.learning_rate),
        num_train_epochs=int(tlab_trainer.params.num_train_epochs),
        report_to=tlab_trainer.report_to,
        # other arguments...
    )

    # 4. Create progress callback
    progress_callback = tlab_trainer.create_progress_callback(framework="huggingface")

    # 5. Create trainer with callback
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        callbacks=[progress_callback],
    )

    # 6. Train and save
    trainer.train()
    trainer.save_model(tlab_trainer.params.output_dir)

    return True

# Call the function
train_model()

Key Differences

  1. Decorator: Added @tlab_trainer.job_wrapper to wrap the function
  2. Dataset Loading: Used tlab_trainer.load_dataset() instead of direct loading
  3. Parameter Access: Accessed parameters via tlab_trainer.params.parameter_name or tlab_trainer.params.get("parameter_name", default_value)
  4. Progress Tracking: Added tlab_trainer.create_progress_callback(framework="huggingface") for reporting progress
  5. Return Value: The return value can be anything, but it's recommended to return a boolean indicating success or failure. The job wrapper catches any exceptions and reports them accordingly.

Parameter Access

Parameters are automatically loaded from the Transformer Lab configuration. You can access them in several ways:

  1. Direct access (if sure the parameter exists): tlab_trainer.params.<parameter_name>
  2. Safe access with default (recommended): tlab_trainer.params.get(<parameter_name>, <default_value>)

Common parameters include:

  • tlab_trainer.params.model_name: Model to use for training
  • tlab_trainer.params.dataset_name: Dataset to use
  • tlab_trainer.params.output_dir: Directory for saving outputs
  • tlab_trainer.params.num_train_epochs: Number of training epochs
  • tlab_trainer.params.batch_size: Batch size for training
  • tlab_trainer.params.learning_rate: Learning rate
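
For instance, a short sketch of reading these parameters inside a decorated function, using safe access with defaults (the default values shown are illustrative):

from transformerlab.sdk.v1.train import tlab_trainer

@tlab_trainer.job_wrapper()
def train_model():
    # Required parameters: direct access (assumes they exist in the configuration)
    model_name = tlab_trainer.params.model_name
    dataset_name = tlab_trainer.params.dataset_name

    # Optional parameters: safe access with illustrative defaults
    learning_rate = float(tlab_trainer.params.get("learning_rate", 2e-5))
    num_train_epochs = int(tlab_trainer.params.get("num_train_epochs", 3))
    batch_size = int(tlab_trainer.params.get("batch_size", 8))
    # ... rest of the training code ...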

Progress Reporting

Transformer Lab expects progress updates from 0 to 100. Use these methods:

  1. Create callback: Create a progress callback with tlab_trainer.create_progress_callback(framework="huggingface") and pass it to your trainer's callbacks.
  2. Manual updates: For custom loops, use tlab_trainer.progress_update(progress) where progress is 0-100

Manual Metric Logging

For training scripts that don't have the automatic logging integration that the Hugging Face Trainer provides, you can use manual logging:

  1. Enable manual logging: Set manual_logging=True in the decorator
  2. Log metrics: Use tlab_trainer.log_metric(name, value, step) to log metrics during training

Example with a custom training loop:

@tlab_trainer.job_wrapper(manual_logging=True)
def train_model():
    # Setup model, data, etc.

    total_steps = 1000
    for step in range(total_steps):
        # Training logic here
        loss = model.train_step(batch)

        # Log metrics manually
        tlab_trainer.log_metric("train/loss", loss.item(), step)
        tlab_trainer.log_metric("train/lr", scheduler.get_last_lr()[0], step)

        # Update progress
        progress = (step / total_steps) * 100
        tlab_trainer.progress_update(progress)

The log_metric function automatically handles logging to both TensorBoard and Weights & Biases (if enabled), so you don't need separate code paths for different logging backends.

Best Practices

  1. Error Handling: The decorator handles basic error reporting, but include try/except blocks for specific operations
  2. Parameter Access: Always use tlab_trainer.params.get() with sensible defaults for optional parameters
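
For the first point, a minimal sketch: wrap an individual step such as dataset loading in try/except, record what failed, and re-raise so the job wrapper can report the error (the key name "dataset_error" is illustrative):

try:
    datasets = tlab_trainer.load_dataset()
    dataset = datasets["train"]
except Exception as e:
    # Record what failed before letting the job wrapper report the error
    tlab_trainer.add_job_data("dataset_error", str(e))
    raise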

Summary

By following this guide, you can quickly adapt your existing training scripts to work within the Transformer Lab ecosystem, gaining parameter management, progress tracking, and integrated logging with minimal code changes.