# DPO

Documentation for the Deep Policy Optimization (DPO) module.
## Overview

Deep Policy Optimization (DPO) is a PyTorch module for optimizing policies in decision-making models. It compares a trainable policy model against a reference model and computes a loss that pushes the policy toward assigning higher probability to preferred sequences than to unpreferred ones.
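For orientation, methods in this family typically minimize a loss of the following form (a sketch, assuming the module follows the standard preference-loss formulation; here $\pi_\theta$ is the trainable policy model, $\pi_{\mathrm{ref}}$ the reference model, $y_w$ and $y_l$ the preferred and unpreferred sequences, $\sigma$ the logistic sigmoid, and $\beta$ the scaling argument described below):

$$
\mathcal{L}(\theta) = -\log \sigma\!\left(\beta\left[\log\frac{\pi_\theta(y_w)}{\pi_{\mathrm{ref}}(y_w)} - \log\frac{\pi_\theta(y_l)}{\pi_{\mathrm{ref}}(y_l)}\right]\right)
$$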
## Class Definition

### Arguments

| Argument | Type | Description | Default |
|---|---|---|---|
| `model` | `nn.Module` | The policy model to be optimized. | - |
| `beta` | `float` | A parameter controlling the influence of the log-ratios in the loss. | `0.1` |
## Methods

### `forward(preferred_seq: Tensor, unpreferred_seq: Tensor) -> Tensor`
Computes the loss based on the difference in log probabilities between preferred and unpreferred sequences.
Arguments¶
Argument | Type | Description |
---|---|---|
preferred_seq |
Tensor |
The sequence of actions/decisions preferred. |
unpreferred_seq |
Tensor |
The sequence of actions/decisions unpreferred. |
#### Returns

A `torch.Tensor` representing the computed loss.
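To make the computation concrete, the following is a minimal, self-contained sketch of how a preference loss of this kind can be formed from per-sequence log-probabilities. It is illustrative only and does not reproduce the module's internals; the function name `dpo_loss` and its arguments are hypothetical.

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_preferred_logps: torch.Tensor,
    policy_unpreferred_logps: torch.Tensor,
    ref_preferred_logps: torch.Tensor,
    ref_unpreferred_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Log-ratios of the policy model relative to the reference model.
    preferred_logratios = policy_preferred_logps - ref_preferred_logps
    unpreferred_logratios = policy_unpreferred_logps - ref_unpreferred_logps
    # Preference loss: penalize cases where the policy does not favor the
    # preferred sequence more strongly than the reference model does.
    return -F.logsigmoid(beta * (preferred_logratios - unpreferred_logratios)).mean()


# Toy usage with random per-sequence log-probabilities (batch of 4).
loss = dpo_loss(torch.randn(4), torch.randn(4), torch.randn(4), torch.randn(4))
```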
## Usage Examples

### Example 1: Basic Setup and Usage
```python
import torch
from torch import nn

from zeta.rl import DPO


# Define a simple policy model
class PolicyModel(nn.Module):
    def __init__(self, dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(dim, output_dim)

    def forward(self, x):
        return self.fc(x)


dim = 10
output_dim = 5
policy_model = PolicyModel(dim, output_dim)

# Initialize DPO with the policy model
dpo_model = DPO(model=policy_model, beta=0.1)

# Sample preferred and unpreferred sequences
preferred_seq = torch.randn(1, 10, 10)
unpreferred_seq = torch.randn(1, 10, 10)

# Compute the loss
loss = dpo_model(preferred_seq, unpreferred_seq)
print(loss)
```
### Example 2: Integrating with an Optimizer
```python
optimizer = torch.optim.Adam(dpo_model.parameters(), lr=0.001)

# Training loop
for epoch in range(100):
    optimizer.zero_grad()
    loss = dpo_model(preferred_seq, unpreferred_seq)
    loss.backward()
    optimizer.step()
```
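After training, the same module can be run without gradient tracking to monitor the loss on held-out preference pairs. This is a standard PyTorch pattern rather than a DPO-specific API:

```python
# Evaluate without updating parameters.
dpo_model.eval()
with torch.no_grad():
    eval_loss = dpo_model(preferred_seq, unpreferred_seq)
print(f"Evaluation loss: {eval_loss.item():.4f}")
```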
## Notes

- Ensure that `preferred_seq` and `unpreferred_seq` have the same shape and are compatible with the input dimensions of the policy model (a quick sanity check is sketched after this list).
- `beta` is a hyperparameter and may require tuning for different applications.
- The policy model should be structured to output logits compatible with the sequences being evaluated.
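As referenced above, a minimal sanity check along these lines (illustrative only, not part of the DPO API) can catch shape mismatches before the forward pass:

```python
# Both inputs must share a shape, and their last dimension must match the
# policy model's input dimension (10 in the examples above).
assert preferred_seq.shape == unpreferred_seq.shape
assert preferred_seq.shape[-1] == dim
```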
This documentation covers the DPO module's interface and demonstrates basic usage as well as integration within a training loop.