Trading Agent with SAC (Soft Actor-Critic)
SAC is an off-policy, maximum-entropy reinforcement learning algorithm. It optimizes not only expected reward but also the entropy of the policy: the agent learns to act well while staying as stochastic as possible. For trading this means the agent does not lock into a single strategy and explores market regimes more broadly.
Maximum Entropy RL Principle
Standard RL: max E[R]. SAC: max E[R + α·H(π)].
H(π) = -E[log π(a|s)] is the policy entropy; α is the temperature (auto-tuned in SAC v2).
In practice: of two equally profitable strategies, the agent prefers the more stochastic one. For trading, this gives robustness against overfitting to a specific market regime.
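The effect of the entropy bonus can be shown numerically. A minimal sketch using the closed-form entropy of a Gaussian policy; the reward value and the two standard deviations are illustrative assumptions:

```python
import math

def gaussian_entropy(sigma):
    # differential entropy of N(mu, sigma^2): 0.5 * log(2*pi*e*sigma^2)
    return 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)

alpha = 0.2     # temperature (fixed here; auto-tuned in SAC v2)
reward = 1.0    # assume both policies earn the same expected reward

# soft objective E[R + alpha * H]: a wider (more stochastic) policy scores higher
soft_value_narrow = reward + alpha * gaussian_entropy(0.1)
soft_value_wide = reward + alpha * gaussian_entropy(0.5)
assert soft_value_wide > soft_value_narrow
```

With equal rewards, the entropy term alone decides: SAC keeps the policy spread out unless extra determinism actually pays.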
SAC vs PPO for Trading
| Characteristic | SAC | PPO |
|---|---|---|
| Type | Off-policy | On-policy |
| Replay buffer | Yes (1M+) | No |
| Sample efficiency | High | Medium |
| Learning stability | High | High |
| Action space | Continuous (better) | Continuous/Discrete |
| Infrastructure | Harder (replay) | Easier |
SAC is preferable when historical data is limited, actions are continuous (portfolio weights), and sample-efficient learning matters.
SAC Architecture
Three networks:
- Policy network π_θ(a|s): Gaussian policy with reparameterization trick
- Two Q-networks Q_φ1, Q_φ2: double Q trick for reducing overestimation bias
- Target Q-networks (EMA copies): training stabilization
```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class SACPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mean_layer = nn.Linear(hidden, action_dim)
        self.log_std_layer = nn.Linear(hidden, action_dim)
        self.LOG_STD_MIN, self.LOG_STD_MAX = -20, 2

    def forward(self, state):
        feat = self.net(state)
        mean = self.mean_layer(feat)
        log_std = self.log_std_layer(feat).clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)
        dist = Normal(mean, log_std.exp())
        # reparameterization: a = tanh(mean + std * ε), ε ~ N(0, 1)
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)
        # log-prob is evaluated at the pre-squash sample...
        log_prob = dist.log_prob(pre_tanh).sum(-1, keepdim=True)
        # ...with the change-of-variables correction for the tanh squashing
        log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(-1, keepdim=True)
        return action, log_prob
```
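The other two items from the architecture list, the twin Q-networks and their EMA target copies, can be sketched as follows. The dimensions (10 state features, 5 assets) are illustrative assumptions:

```python
import copy
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, state, action):
        # Q(s, a): concatenate state and action along the feature axis
        return self.net(torch.cat([state, action], dim=-1))

q1, q2 = QNetwork(10, 5), QNetwork(10, 5)               # double-Q trick
q1_target, q2_target = copy.deepcopy(q1), copy.deepcopy(q2)

def soft_update(net, target, tau=0.005):
    # Polyak/EMA update: target ← (1 - tau) * target + tau * online
    with torch.no_grad():
        for p, tp in zip(net.parameters(), target.parameters()):
            tp.mul_(1.0 - tau).add_(p, alpha=tau)
```

`soft_update` is called once per gradient step; the small `tau` keeps the Q-targets slowly moving, which stabilizes training.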
Automatic Temperature α Tuning
SAC v2 removes manual α tuning. Target entropy = -dim(action_space):
```python
import torch

target_entropy = -action_dim  # for 5 assets: -5
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

# alpha loss, updated every gradient step (log_pi comes from the current policy)
alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
alpha = log_alpha.exp().item()
```
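Where does α actually act? It enters the critic target via the soft Bellman backup. A minimal sketch; all inputs are placeholders for quantities produced by the target Q-networks and the policy at the next state:

```python
import torch

def soft_q_target(reward, done, next_q1, next_q2, next_log_pi,
                  alpha=0.2, gamma=0.99):
    # y = r + γ(1 - d) * (min(Q'_1, Q'_2) - α * log π(a'|s'))
    min_q = torch.min(next_q1, next_q2)   # clipped double-Q vs. overestimation
    return reward + gamma * (1.0 - done) * (min_q - alpha * next_log_pi)
```

A higher α shrinks the target for low-entropy (high log π) actions, which is exactly the mechanism that pushes the agent toward exploration.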
Replay Buffer for Financial Time Series
A standard uniform replay buffer ignores the temporal structure of market data. Prioritized Experience Replay (PER) samples transitions with high TD-error more frequently.
A temporal replay buffer stores sequences rather than i.i.d. transitions (needed for an LSTM policy):
- Sequence length = 20 (a 20-day context)
- Sampling draws a random contiguous segment
- BPTT runs through the entire sequence
```python
from collections import deque
import numpy as np

class SequenceReplayBuffer:
    def __init__(self, capacity, seq_len):
        self.buffer = deque(maxlen=capacity)
        self.seq_len = seq_len

    def append(self, transition):
        self.buffer.append(transition)

    def sample_sequences(self, batch_size):
        # draw batch_size random contiguous segments of length seq_len
        data = list(self.buffer)
        starts = np.random.randint(0, len(data) - self.seq_len, batch_size)
        return [data[s:s + self.seq_len] for s in starts]
```
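For the BPTT step, sampled sequences must be stacked into an LSTM-ready batch of shape (batch, seq_len, features). A self-contained sketch; the transition layout and the feature dimension (8) are assumptions:

```python
from collections import deque
import numpy as np

# fill a minimal buffer with dummy (state, action, reward, next_state, done)
# transitions; 8 features per state is a hypothetical choice
buf = deque(maxlen=1000)
for t in range(100):
    state = np.random.randn(8).astype(np.float32)
    buf.append((state, np.zeros(5), 0.0, state, False))

seq_len, batch_size = 20, 4
data = list(buf)
starts = np.random.randint(0, len(data) - seq_len, batch_size)
batch = [data[s:s + seq_len] for s in starts]

# stack the state component into (batch, seq_len, features) for an LSTM policy
states = np.stack([[step[0] for step in seq] for seq in batch])
assert states.shape == (4, 20, 8)
```

The same stacking applies to actions and rewards; the LSTM's hidden state is then unrolled across all 20 steps of each segment.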
Implementation via Stable Baselines3
```python
from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    buffer_size=1_000_000,
    learning_starts=10_000,   # warmup without updates
    batch_size=256,
    tau=0.005,                # EMA coefficient for target networks
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef="auto",          # automatic α tuning
    target_entropy="auto",
    verbose=1,
)
model.learn(total_timesteps=500_000)
```
learning_starts is critical for trading: the first 10K steps are pure random exploration with no network updates, which populates the replay buffer with diverse experience before learning begins.
Performance Comparison
All else being equal, SAC typically outperforms PPO by 10–15% on Sharpe ratio thanks to better exploration and sample efficiency. The trade-offs: more memory for the 1M-transition replay buffer and a more complex training loop to debug.
Timeline: 6–10 weeks
Basic SAC on OHLCV data — 3–5 weeks. PER + sequence replay, LSTM policy, live broker connection — 8–10 weeks.