Навчання RL-агента (PPO/SAC/DQN) для торгової стратегії
Три найчастіше використовуваних RL-алгоритми в алготрейдингу мають різні переваги. Вибір алгоритму залежить від архітектури стратегії: дискретний або неперервний простір дій, навчання on-policy або off-policy.
DQN (Deep Q-Network)
Підходить для: дискретні дії (купівля/утримання/продаж), прості стратегії, достатня стабільність.
DQN навчає Q-функцію: Q(state, action) — очікувана дисконтована нагорода при виборі дії в стані.
import torch
import torch.nn as nn
from collections import deque
import random
class DQNNetwork(nn.Module):
def __init__(self, state_dim, n_actions, hidden_dim=256):
super().__init__()
# Dueling архітектура: окремі потоки Value та Advantage
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.value_stream = nn.Linear(hidden_dim, 1)
self.advantage_stream = nn.Linear(hidden_dim, n_actions)
def forward(self, x):
shared = self.shared(x)
value = self.value_stream(shared)
advantage = self.advantage_stream(shared)
# Dueling: Q = V + (A - mean(A))
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_values
class PrioritizedReplayBuffer:
"""Prioritized Experience Replay — частіше вибираємо важливі переходи"""
def __init__(self, capacity=50000, alpha=0.6):
self.buffer = deque(maxlen=capacity)
self.priorities = deque(maxlen=capacity)
self.alpha = alpha
def push(self, state, action, reward, next_state, done, td_error=1.0):
priority = (abs(td_error) + 1e-5) ** self.alpha
self.buffer.append((state, action, reward, next_state, done))
self.priorities.append(priority)
def sample(self, batch_size, beta=0.4):
probs = np.array(self.priorities) / sum(self.priorities)
indices = np.random.choice(len(self.buffer), batch_size, p=probs)
# Importance sampling ваги
weights = (len(self.buffer) * probs[indices]) ** (-beta)
weights /= weights.max()
batch = [self.buffer[i] for i in indices]
return batch, indices, weights
Double DQN: усуває переоцінку Q-значень. Використовуємо online мережу для вибору дії, target мережу для оцінки.
# Double DQN розрахунок цілі
with torch.no_grad():
next_actions = online_net(next_states).argmax(dim=1) # online мережа вибирає
next_q = target_net(next_states).gather(1, next_actions.unsqueeze(1)) # target оцінює
targets = rewards + gamma * next_q * (1 - dones)
PPO (Proximal Policy Optimization)
Підходить для: дискретні та неперервні дії, on-policy, стабільне навчання.
PPO обмежує розмір оновлення політики через clipping:
class PPOActor(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh()
)
self.policy_head = nn.Linear(hidden_dim, action_dim)
self.value_head = nn.Linear(hidden_dim, 1)
def forward(self, x):
features = self.network(x)
logits = self.policy_head(features)
value = self.value_head(features)
return logits, value
def ppo_update(model, optimizer, states, actions, old_log_probs,
advantages, returns, clip_eps=0.2, n_epochs=4):
for _ in range(n_epochs):
logits, values = model(states)
dist = torch.distributions.Categorical(logits=logits)
new_log_probs = dist.log_prob(actions)
entropy = dist.entropy()
# PPO clipped objective
ratio = (new_log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
actor_loss = -torch.min(surr1, surr2).mean()
critic_loss = (returns - values.squeeze()).pow(2).mean()
entropy_loss = -entropy.mean()
total_loss = actor_loss + 0.5 * critic_loss + 0.01 * entropy_loss
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
SAC (Soft Actor-Critic)
Підходить для: неперервний простір дій (позиціонування 0%–100% капіталу), off-policy, максимальна ефективність вибірки.
SAC максимізує: J(π) = E[Σ γ^t (r_t + α H(π(·|s_t)))]
Додатковий член з ентропією H спонукає агента до дослідження і запобігає передчасній конвергенції.
class SACActorContinuous(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
)
self.mean_head = nn.Linear(hidden_dim, action_dim)
self.log_std_head = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
features = self.network(x)
mean = self.mean_head(features)
log_std = self.log_std_head(features).clamp(-20, 2)
std = log_std.exp()
dist = torch.distributions.Normal(mean, std)
action = dist.rsample() # reparameterization trick
# Стискаємо до [-1, 1]
action_tanh = torch.tanh(action)
log_prob = dist.log_prob(action) - torch.log(1 - action_tanh.pow(2) + 1e-6)
return action_tanh, log_prob.sum(-1, keepdim=True)
Порівняння алгоритмів для крипто-торгівлі
| Алгоритм | Простір дій | Ефективність вибірки | Стабільність | Найкраще застосування |
|---|---|---|---|---|
| DQN | Дискретний | Середня | Середня | Прості стратегії купівлі/продажу |
| PPO | Обидва | Низька (on-policy) | Висока | Загальне застосування, надійна |
| SAC | Неперервний | Висока | Висока | Розмір позиції як дія |
Мультиагентна торгівля
Кілька RL-агентів на різних таймфреймах:
- Macro-агент (1D): визначає загальний напрямок
- Micro-агент (1H): синхронізація входу/виходу
- Виконавчий агент (15M): оптимальне виконання
Macro-агент передає сигнал як частину стану micro-агента.
Ключові складності
Нестаціонарність ринку: агент, навчений на 2020–2021 роках, може працювати погано на 2022–2023 роках. Безперервне навчання / періодичне переналаштування обов'язкові.
Reward hacking: агент може знайти способи отримувати нагороду, що не відповідають фактичному торговому прибутку. Ретельний дизайн нагород критичний.
Переналаштування на навчальні дані: агент може запам'ятати конкретні паттерни навчального періоду. Оцінка на повністю відокремлених тестових даних.
Розроблення RL торгового агента з оптимальним вибором алгоритму (DQN/PPO/SAC) для завдання, користувацьким торговим середовищем, shaping нагород, walk-forward оцінкою та MLflow відстеженням.







