# Creating a Trainer Plugin Script
This guide explains how to adapt your existing training scripts to work with Transformer Lab using the `tlab_trainer` decorator class. By integrating with Transformer Lab, your training scripts gain progress tracking, parameter management, dataset handling, and integrated logging with minimal code changes. This guide is part of our active development of the Transformer Lab Plugin SDK, which aims to make integrating third-party plugins easier.
## What is `tlab_trainer`?

`tlab_trainer` is a decorator class that helps integrate your training script with Transformer Lab's job management system. It provides:
- Argument parsing and configuration loading
- Dataset loading helpers
- Progress tracking and reporting
- Job status management
- Integration with TensorBoard and Weights & Biases
## Getting Started
### 1. Import the decorator
Add this import to your training script:
```python
from transformerlab.sdk.v1.train import tlab_trainer
```
### 2. Decorate your main training function
Wrap your main training function with the `job_wrapper` decorator:
```python
@tlab_trainer.job_wrapper(
    wandb_project_name="my_project",  # Optional: set a custom Weights & Biases project name
    manual_logging=False              # Optional: set to True for manual metric logging
)
def train_model():
    # Your training code here
    pass
```
The decorator accepts the following parameters (a sketch that sets all of them explicitly follows this list):

- `progress_start` and `progress_end`: Define the progress range reported for the job (typically 0 to 100). Both are optional and default to tracking from 0 to 100.
- `wandb_project_name`: Optional custom name for your Weights & Biases project. Default is `TLAB_Training`.
- `manual_logging`: Set to `True` for training scripts without automatic logging integration. Default is `False`.
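For reference, here is a minimal sketch of a decorator call that passes every documented parameter; in practice, you only need to pass the options you want to override:

```python
from transformerlab.sdk.v1.train import tlab_trainer

@tlab_trainer.job_wrapper(
    progress_start=0,                 # start of the reported progress range
    progress_end=100,                 # end of the reported progress range
    wandb_project_name="my_project",  # custom Weights & Biases project name
    manual_logging=True               # enable tlab_trainer.log_metric() for custom loops
)
def train_model():
    # Your training code here
    pass
```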
Note: An async version of the job wrapper is available for functions that need to run asynchronously. To use it, change `@tlab_trainer.job_wrapper` to `@tlab_trainer.async_job_wrapper`.
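A minimal sketch, assuming `async_job_wrapper` accepts the same optional parameters as `job_wrapper`:

```python
from transformerlab.sdk.v1.train import tlab_trainer

@tlab_trainer.async_job_wrapper(wandb_project_name="my_project")  # parameters assumed to match job_wrapper
async def train_model():
    # await asynchronous setup, data loading, training, etc.
    ...
```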
### 3. Use helper methods
Replace parts of your code with `tlab_trainer` helper methods:

- Dataset loading: `tlab_trainer.load_dataset()`
- Progress tracking: `tlab_trainer.create_progress_callback()`
- Storing anything to the job data (optional): `tlab_trainer.add_job_data(key, value)` (example below)
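For example, `add_job_data` can attach extra metadata to the job record. A minimal sketch (the key names here are arbitrary):

```python
# Inside a function wrapped with @tlab_trainer.job_wrapper:
tlab_trainer.add_job_data("base_model", tlab_trainer.params.get("model_name", "unknown"))
tlab_trainer.add_job_data("notes", "first fine-tuning run on the new dataset")
```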
## Complete example
Here's how a typical training script can be adapted to use `tlab_trainer`. First, the original script:
```python
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset


# Parse command line arguments
def parse_args():
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("--model_name", type=str, required=True, help="Model to train")
    parser.add_argument("--dataset_name", type=str, required=True, help="Dataset to use")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training")
    parser.add_argument("--max_length", type=int, default=512, help="Max sequence length")
    return parser.parse_args()


def train_model():
    # 1. Parse arguments
    args = parse_args()

    # 2. Load dataset
    dataset = load_dataset(args.dataset_name)["train"]

    # 3. Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # 4. Setup training arguments
    # (args.max_length would be used when tokenizing the dataset; it is not a TrainingArguments field)
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        # other arguments...
    )

    # 5. Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    # 6. Train and save
    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"Model saved to {args.output_dir}")


# Call the function
if __name__ == "__main__":
    train_model()
```
### Adapted Script with `tlab_trainer`
```python
from transformerlab.sdk.v1.train import tlab_trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer


@tlab_trainer.job_wrapper(progress_start=0, progress_end=100)
def train_model():
    # 1. Load dataset with helper
    datasets = tlab_trainer.load_dataset()
    dataset = datasets["train"]

    # 2. Load model and tokenizer (same as before)
    model = AutoModelForCausalLM.from_pretrained(tlab_trainer.params.model_name)
    tokenizer = AutoTokenizer.from_pretrained(tlab_trainer.params.model_name)

    # 3. Setup training arguments with parameters from Transformer Lab
    training_args = TrainingArguments(
        output_dir=tlab_trainer.params.output_dir,
        learning_rate=float(tlab_trainer.params.learning_rate),
        num_train_epochs=int(tlab_trainer.params.num_train_epochs),
        report_to=tlab_trainer.report_to,
        # other arguments...
    )

    # 4. Create progress callback
    progress_callback = tlab_trainer.create_progress_callback(framework="huggingface")

    # 5. Create trainer with callback
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        callbacks=[progress_callback],
    )

    # 6. Train and save
    trainer.train()
    trainer.save_model(tlab_trainer.params.output_dir)
    return True


# Call the function
train_model()
```
## Key Differences
- Decorator: Added `@tlab_trainer.job_wrapper` to wrap the function
- Dataset Loading: Used `tlab_trainer.load_dataset()` instead of loading the dataset directly
- Parameter Access: Accessed parameters via `tlab_trainer.params.parameter_name` or `tlab_trainer.params.get("parameter_name", default_value)`
- Progress Tracking: Added `tlab_trainer.create_progress_callback(framework="huggingface")` to report progress
- Return Value: The return value can be anything, but returning a boolean to indicate success or failure is recommended. The job wrapper catches any errors and reports them accordingly.
## Parameter Access
Parameters are automatically loaded from the Transformer Lab configuration. You can access them in several ways:

- Direct access (if you're sure the parameter exists): `tlab_trainer.params.<parameter_name>`
- Safe access with a default (recommended): `tlab_trainer.params.get(<parameter_name>, <default_value>)`
Common parameters include (a short access sketch follows this list):

- `tlab_trainer.params.model_name`: Model to use for training
- `tlab_trainer.params.dataset_name`: Dataset to use
- `tlab_trainer.params.output_dir`: Directory for saving outputs
- `tlab_trainer.params.num_train_epochs`: Number of training epochs
- `tlab_trainer.params.batch_size`: Batch size for training
- `tlab_trainer.params.learning_rate`: Learning rate
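A minimal sketch of both access patterns, assuming parameter names are passed to `.get()` as strings (the default values below are arbitrary):

```python
# Direct access -- use when you know the parameter is set
model_name = tlab_trainer.params.model_name
output_dir = tlab_trainer.params.output_dir

# Safe access with defaults -- recommended for optional parameters
learning_rate = float(tlab_trainer.params.get("learning_rate", 2e-5))
batch_size = int(tlab_trainer.params.get("batch_size", 8))
```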
## Progress Reporting
Transformer Lab expects progress updates from 0 to 100. Use these methods:

- Create callback: Create a progress callback with `tlab_trainer.create_progress_callback(framework="huggingface")` and pass it to your trainer.
- Manual updates: For custom training loops, call `tlab_trainer.progress_update(progress)`, where `progress` is a value from 0 to 100.
## Manual Metric Logging
For training scripts that don't have automatic logging integration (as the Hugging Face `Trainer` does), you can log metrics manually:

- Enable manual logging: Set `manual_logging=True` in the decorator
- Log metrics: Use `tlab_trainer.log_metric(name, value, step)` to log metrics during training
Example with a custom training loop:
```python
@tlab_trainer.job_wrapper(manual_logging=True)
def train_model():
    # Setup model, data, optimizer, scheduler, etc.
    total_steps = 1000

    for step in range(total_steps):
        # Training logic here
        loss = model.train_step(batch)

        # Log metrics manually
        tlab_trainer.log_metric("train/loss", loss.item(), step)
        tlab_trainer.log_metric("train/lr", scheduler.get_last_lr()[0], step)

        # Update progress (0-100)
        progress = ((step + 1) / total_steps) * 100
        tlab_trainer.progress_update(progress)
```
The `log_metric` function automatically handles logging to both TensorBoard and Weights & Biases (if enabled), so you don't need separate code paths for different logging backends.
## Best Practices
- Error Handling: The decorator handles basic error reporting, but include try/except blocks around operations that need specific handling (see the sketch below)
- Parameter Access: Always use `tlab_trainer.params.get()` with sensible defaults for optional parameters
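Putting both together, a guarded section inside a wrapped training function might look like this (a minimal sketch; the job-data key and fallback value are arbitrary):

```python
# Optional parameter with a sensible default
num_train_epochs = int(tlab_trainer.params.get("num_train_epochs", 3))

try:
    datasets = tlab_trainer.load_dataset()
    dataset = datasets["train"]
except Exception as e:
    # The job wrapper will report the failure; attach extra context first if useful
    tlab_trainer.add_job_data("error_detail", f"dataset loading failed: {e}")
    raise
```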
## Summary
By following this guide, you can quickly adapt your existing training scripts to work within the Transformer Lab ecosystem, gaining parameter management, progress tracking, and integrated logging with minimal code changes.