feat: implement strategy metadata and dashboard simulation panel

- Added display_name and description to BaseStrategy
- Updated MA44 and MA125 strategies with metadata
- Added /api/v1/strategies endpoint for dynamic discovery
- Added Strategy Simulation panel to dashboard with date picker and tooltips
- Implemented JS polling for backtest results in dashboard
- Added performance test scripts and DB connection guide
- Expanded indicator config to all 15 timeframes
This commit is contained in:
BTC Bot
2026-02-13 09:50:08 +01:00
parent 38f0a21f56
commit d7bdfcf716
23 changed files with 3623 additions and 241 deletions

View File

@ -396,13 +396,173 @@
font-size: 14px;
}
@media (max-width: 1200px) {
/* Simulation Panel Styles */
.sim-strategies {
display: flex;
flex-direction: column;
gap: 4px;
margin: 8px 0;
}
.sim-strategy-option {
display: flex;
align-items: center;
gap: 8px;
padding: 8px;
border-radius: 4px;
cursor: pointer;
transition: background 0.2s;
}
.sim-strategy-option:hover {
background: var(--tv-hover);
}
.sim-strategy-option input[type="radio"] {
cursor: pointer;
}
.sim-strategy-option label {
cursor: pointer;
flex: 1;
font-size: 13px;
}
.sim-strategy-info {
color: var(--tv-text-secondary);
cursor: help;
position: relative;
font-size: 14px;
width: 20px;
height: 20px;
display: flex;
align-items: center;
justify-content: center;
border-radius: 50%;
transition: background 0.2s;
}
.sim-strategy-info:hover {
background: var(--tv-hover);
color: var(--tv-text);
}
/* Tooltip */
.sim-strategy-info:hover::after {
content: attr(data-tooltip);
position: absolute;
bottom: 100%;
right: 0;
background: var(--tv-panel-bg);
border: 1px solid var(--tv-border);
padding: 8px 12px;
border-radius: 4px;
font-size: 12px;
width: 250px;
z-index: 100;
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
color: var(--tv-text);
margin-bottom: 4px;
line-height: 1.4;
pointer-events: none;
}
.sim-input-group {
margin-bottom: 12px;
}
.sim-input-group label {
display: block;
font-size: 11px;
color: var(--tv-text-secondary);
margin-bottom: 4px;
text-transform: uppercase;
}
.sim-input {
width: 100%;
padding: 8px 12px;
background: var(--tv-bg);
border: 1px solid var(--tv-border);
border-radius: 4px;
color: var(--tv-text);
font-size: 13px;
font-family: inherit;
}
.sim-input:focus {
outline: none;
border-color: var(--tv-blue);
}
.sim-run-btn {
width: 100%;
padding: 10px;
background: var(--tv-blue);
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 13px;
font-weight: 600;
transition: opacity 0.2s;
margin-top: 8px;
}
.sim-run-btn:hover:not(:disabled) {
opacity: 0.9;
}
.sim-run-btn:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.sim-results {
margin-top: 16px;
padding-top: 16px;
border-top: 1px solid var(--tv-border);
}
.sim-stat-row {
display: flex;
justify-content: space-between;
padding: 6px 0;
font-size: 13px;
}
.sim-stat-row span:first-child {
color: var(--tv-text-secondary);
}
.sim-value {
font-weight: 600;
font-family: 'Courier New', monospace;
}
.sim-value.positive { color: var(--tv-green); }
.sim-value.negative { color: var(--tv-red); }
.loading-strategies {
padding: 12px;
text-align: center;
color: var(--tv-text-secondary);
font-size: 12px;
}
@media (max-width: 1400px) {
.ta-content {
grid-template-columns: repeat(3, 1fr);
}
}
@media (max-width: 1000px) {
.ta-content {
grid-template-columns: repeat(2, 1fr);
}
}
@media (max-width: 768px) {
@media (max-width: 600px) {
.ta-content {
grid-template-columns: 1fr;
}
@ -834,18 +994,47 @@
</div>
</div>
<div class="ta-section">
<div class="ta-section-title">Price Info</div>
<div class="ta-level">
<span class="ta-level-label">Current</span>
<span class="ta-level-value">$${data.current_price.toLocaleString()}</span>
<div class="ta-section" id="simulationPanel">
<div class="ta-section-title">Strategy Simulation</div>
<!-- Date picker -->
<div class="sim-input-group" style="margin: 0 0 8px 0;">
<label style="font-size: 10px; text-transform: uppercase; color: var(--tv-text-secondary);">Start Date:</label>
<input type="datetime-local" id="simStartDate" class="sim-input" style="margin-top: 2px;">
</div>
<div style="font-size: 12px; color: var(--tv-text-secondary); margin-top: 8px;">
Based on last 200 candles<br>
Strategy: Trend following with MA crossovers
<!-- Strategies loaded dynamically here -->
<div id="strategyList" class="sim-strategies" style="max-height: 100px; overflow-y: auto;">
<div class="loading-strategies">Loading strategies...</div>
</div>
<button class="sim-run-btn" onclick="runSimulation()" id="runSimBtn" disabled style="padding: 6px; font-size: 12px; margin-top: 6px;">
Run Simulation
</button>
<!-- Results -->
<div id="simResults" class="sim-results" style="display: none; margin-top: 8px; padding-top: 8px;">
<div class="sim-stat-row" style="padding: 2px 0; font-size: 11px;">
<span>Trades:</span>
<span id="simTrades" class="sim-value">--</span>
</div>
<div class="sim-stat-row" style="padding: 2px 0; font-size: 11px;">
<span>Win Rate:</span>
<span id="simWinRate" class="sim-value">--</span>
</div>
<div class="sim-stat-row" style="padding: 2px 0; font-size: 11px;">
<span>Total P&L:</span>
<span id="simPnL" class="sim-value">--</span>
</div>
</div>
</div>
`;
// Load strategies after simulation panel is rendered
setTimeout(() => {
loadStrategies();
setDefaultStartDate();
}, 0);
}
updateStats(candle) {
@ -891,6 +1080,226 @@
window.open(geminiUrl, '_blank');
}
// Load strategies on page load
async function loadStrategies() {
try {
console.log('Fetching strategies from API...');
// Add timeout to fetch
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 5000); // 5 second timeout
const response = await fetch('/api/v1/strategies', {
signal: controller.signal
});
clearTimeout(timeoutId);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data = await response.json();
console.log('Strategies loaded:', data);
if (!data.strategies) {
throw new Error('Invalid response format: missing strategies array');
}
renderStrategies(data.strategies);
} catch (error) {
console.error('Error loading strategies:', error);
let errorMessage = error.message;
if (error.name === 'AbortError') {
errorMessage = 'Request timeout - API server not responding';
} else if (error.message.includes('Failed to fetch')) {
errorMessage = 'Cannot connect to API server - is it running?';
}
document.getElementById('strategyList').innerHTML =
`<div class="loading-strategies" style="color: var(--tv-red);">
${errorMessage}<br>
<small>Check console (F12) for details</small>
</div>`;
}
}
// Render strategy list
function renderStrategies(strategies) {
const container = document.getElementById('strategyList');
if (!strategies || strategies.length === 0) {
container.innerHTML = '<div class="loading-strategies">No strategies available</div>';
return;
}
container.innerHTML = strategies.map((s, index) => `
<div class="sim-strategy-option">
<input type="radio" name="strategy" id="strat_${s.id}"
value="${s.id}" ${index === 0 ? 'checked' : ''}>
<label for="strat_${s.id}">${s.name}</label>
<span class="sim-strategy-info" data-tooltip="${s.description}">ⓘ</span>
</div>
`).join('');
// Enable run button
document.getElementById('runSimBtn').disabled = false;
}
// Run simulation
async function runSimulation() {
const selectedStrategy = document.querySelector('input[name="strategy"]:checked');
if (!selectedStrategy) {
alert('Please select a strategy');
return;
}
const strategyId = selectedStrategy.value;
const startDateInput = document.getElementById('simStartDate').value;
if (!startDateInput) {
alert('Please select a start date');
return;
}
// Format date for API
const startDate = new Date(startDateInput).toISOString().split('T')[0];
// Disable button during simulation
const runBtn = document.getElementById('runSimBtn');
runBtn.disabled = true;
runBtn.textContent = 'Running...';
try {
// Trigger backtest via API
const response = await fetch('/api/v1/backtests', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
symbol: 'BTC',
intervals: [window.dashboard?.currentInterval || '1d'],
start_date: startDate
})
});
const result = await response.json();
if (result.message) {
// Show that simulation is running
runBtn.textContent = 'Running...';
// Poll for results
setTimeout(() => {
pollForBacktestResults(strategyId, startDate);
}, 2000); // Wait 2 seconds then poll
} else {
alert('Failed to start simulation');
}
} catch (error) {
console.error('Error running simulation:', error);
alert('Error running simulation: ' + error.message);
// Reset button only on error
runBtn.disabled = false;
runBtn.textContent = 'Run Simulation';
}
// Button stays as "Running..." until polling completes or times out
}
// Poll for backtest results
async function pollForBacktestResults(strategyId, startDate, attempts = 0) {
const runBtn = document.getElementById('runSimBtn');
if (attempts > 30) { // Stop after 30 attempts (60 seconds)
console.log('Backtest polling timeout');
runBtn.textContent = 'Run Simulation';
runBtn.disabled = false;
// Show timeout message in results area
const simResults = document.getElementById('simResults');
if (simResults) {
simResults.innerHTML = `
<div class="sim-stat-row" style="color: var(--tv-text-secondary); font-size: 11px; text-align: center;">
<span>Simulation timeout - no results found after 60s.<br>Check server logs or try again.</span>
</div>
`;
simResults.style.display = 'block';
}
return;
}
try {
const response = await fetch('/api/v1/backtests?limit=5');
const backtests = await response.json();
// Find the most recent backtest that matches our criteria
const recentBacktest = backtests.find(bt =>
bt.strategy && bt.strategy.includes(strategyId) ||
bt.created_at > new Date(Date.now() - 60000).toISOString() // Created in last minute
);
if (recentBacktest && recentBacktest.results) {
// Parse JSON string if needed (database stores results as text)
const parsedBacktest = {
...recentBacktest,
results: typeof recentBacktest.results === 'string'
? JSON.parse(recentBacktest.results)
: recentBacktest.results
};
// Results found! Display them
displayBacktestResults(parsedBacktest);
runBtn.textContent = 'Run Simulation';
runBtn.disabled = false;
return;
}
// No results yet, poll again in 2 seconds
setTimeout(() => {
pollForBacktestResults(strategyId, startDate, attempts + 1);
}, 2000);
} catch (error) {
console.error('Error polling for backtest results:', error);
runBtn.textContent = 'Run Simulation';
runBtn.disabled = false;
}
}
// Display backtest results in the UI
function displayBacktestResults(backtest) {
// Parse JSON string if needed (database stores results as text)
const results = typeof backtest.results === 'string'
? JSON.parse(backtest.results)
: backtest.results;
// Update the results display
document.getElementById('simTrades').textContent = results.total_trades || '--';
document.getElementById('simWinRate').textContent = results.win_rate ? results.win_rate.toFixed(1) + '%' : '--';
const pnlElement = document.getElementById('simPnL');
const pnl = results.total_pnl || 0;
pnlElement.textContent = (pnl >= 0 ? '+' : '') + '$' + pnl.toFixed(2);
pnlElement.className = 'sim-value ' + (pnl >= 0 ? 'positive' : 'negative');
// Show results section
document.getElementById('simResults').style.display = 'block';
console.log('Backtest results:', backtest);
}
// Set default start date (7 days ago)
function setDefaultStartDate() {
const startDateInput = document.getElementById('simStartDate');
if (startDateInput) {
const sevenDaysAgo = new Date();
sevenDaysAgo.setDate(sevenDaysAgo.getDate() - 7);
// Format as datetime-local: YYYY-MM-DDTHH:mm
startDateInput.value = sevenDaysAgo.toISOString().slice(0, 16);
}
}
document.addEventListener('DOMContentLoaded', () => {
window.dashboard = new TradingDashboard();
});

View File

@ -6,17 +6,28 @@ Removes the complex WebSocket manager that was causing issues
import os
import asyncio
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Query
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
from fastapi.staticfiles import StaticFiles
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import asyncpg
import csv
import io
from pydantic import BaseModel, Field
# Imports for backtest runner
from src.data_collector.database import DatabaseManager
from src.data_collector.indicator_engine import IndicatorEngine, IndicatorConfig
from src.data_collector.brain import Brain
from src.data_collector.backtester import Backtester
# Imports for strategy discovery
import importlib
from src.strategies.base import BaseStrategy
logging.basicConfig(level=logging.INFO)
@ -88,6 +99,41 @@ async def root():
}
@app.get("/api/v1/strategies")
async def list_strategies():
"""List all available trading strategies with metadata"""
# Strategy registry from brain.py
strategy_registry = {
"ma44_strategy": "src.strategies.ma44_strategy.MA44Strategy",
"ma125_strategy": "src.strategies.ma125_strategy.MA125Strategy",
}
strategies = []
for strategy_id, class_path in strategy_registry.items():
try:
module_path, class_name = class_path.rsplit('.', 1)
module = importlib.import_module(module_path)
strategy_class = getattr(module, class_name)
# Instantiate to get metadata
strategy_instance = strategy_class()
strategies.append({
"id": strategy_id,
"name": strategy_instance.display_name,
"description": strategy_instance.description,
"required_indicators": strategy_instance.required_indicators
})
except Exception as e:
logger.error(f"Failed to load strategy {strategy_id}: {e}")
return {
"strategies": strategies,
"count": len(strategies)
}
@app.get("/api/v1/candles")
async def get_candles(
symbol: str = Query("BTC", description="Trading pair symbol"),
@ -215,6 +261,155 @@ async def health_check():
raise HTTPException(status_code=503, detail=f"Health check failed: {str(e)}")
@app.get("/api/v1/indicators")
async def get_indicators(
symbol: str = Query("BTC", description="Trading pair symbol"),
interval: str = Query("1d", description="Candle interval"),
name: str = Query(None, description="Filter by indicator name (e.g., ma44)"),
start: Optional[datetime] = Query(None, description="Start time"),
end: Optional[datetime] = Query(None, description="End time"),
limit: int = Query(1000, le=5000)
):
"""Get indicator values"""
async with pool.acquire() as conn:
query = """
SELECT time, indicator_name, value
FROM indicators
WHERE symbol = $1 AND interval = $2
"""
params = [symbol, interval]
if name:
query += f" AND indicator_name = ${len(params) + 1}"
params.append(name)
if start:
query += f" AND time >= ${len(params) + 1}"
params.append(start)
if end:
query += f" AND time <= ${len(params) + 1}"
params.append(end)
query += f" ORDER BY time DESC LIMIT ${len(params) + 1}"
params.append(limit)
rows = await conn.fetch(query, *params)
# Group by time for easier charting
grouped = {}
for row in rows:
ts = row['time'].isoformat()
if ts not in grouped:
grouped[ts] = {'time': ts}
grouped[ts][row['indicator_name']] = float(row['value'])
return {
"symbol": symbol,
"interval": interval,
"data": list(grouped.values())
}
@app.get("/api/v1/decisions")
async def get_decisions(
symbol: str = Query("BTC"),
interval: Optional[str] = Query(None),
backtest_id: Optional[str] = Query(None),
limit: int = Query(100, le=1000)
):
"""Get brain decisions"""
async with pool.acquire() as conn:
query = """
SELECT time, interval, decision_type, strategy, confidence,
price_at_decision, indicator_snapshot, reasoning, backtest_id
FROM decisions
WHERE symbol = $1
"""
params = [symbol]
if interval:
query += f" AND interval = ${len(params) + 1}"
params.append(interval)
if backtest_id:
query += f" AND backtest_id = ${len(params) + 1}"
params.append(backtest_id)
else:
query += " AND backtest_id IS NULL"
query += f" ORDER BY time DESC LIMIT ${len(params) + 1}"
params.append(limit)
rows = await conn.fetch(query, *params)
return [dict(row) for row in rows]
@app.get("/api/v1/backtests")
async def list_backtests(symbol: Optional[str] = None, limit: int = 20):
"""List historical backtests"""
async with pool.acquire() as conn:
query = """
SELECT id, strategy, symbol, start_time, end_time,
intervals, results, created_at
FROM backtest_runs
"""
params = []
if symbol:
query += " WHERE symbol = $1"
params.append(symbol)
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
params.append(limit)
rows = await conn.fetch(query, *params)
return [dict(row) for row in rows]
class BacktestRequest(BaseModel):
symbol: str = "BTC"
intervals: list[str] = ["37m"]
start_date: str = "2025-01-01" # ISO date
end_date: Optional[str] = None
async def run_backtest_task(req: BacktestRequest):
"""Background task to run backtest"""
db = DatabaseManager(
host=DB_HOST, port=DB_PORT, database=DB_NAME,
user=DB_USER, password=DB_PASSWORD
)
await db.connect()
try:
# Load configs (hardcoded for now to match main.py)
configs = [
IndicatorConfig("ma44", "sma", 44, req.intervals),
IndicatorConfig("ma125", "sma", 125, req.intervals)
]
engine = IndicatorEngine(db, configs)
brain = Brain(db, engine)
backtester = Backtester(db, engine, brain)
start = datetime.fromisoformat(req.start_date).replace(tzinfo=timezone.utc)
end = datetime.fromisoformat(req.end_date).replace(tzinfo=timezone.utc) if req.end_date else datetime.now(timezone.utc)
await backtester.run(req.symbol, req.intervals, start, end)
except Exception as e:
logger.error(f"Backtest failed: {e}")
finally:
await db.disconnect()
@app.post("/api/v1/backtests")
async def trigger_backtest(req: BacktestRequest, background_tasks: BackgroundTasks):
"""Start a backtest in the background"""
background_tasks.add_task(run_backtest_task, req)
return {"message": "Backtest started", "params": req.dict()}
@app.get("/api/v1/ta")
async def get_technical_analysis(
symbol: str = Query("BTC", description="Trading pair symbol"),
@ -222,42 +417,44 @@ async def get_technical_analysis(
):
"""
Get technical analysis for a symbol
Calculates MA 44, MA 125, trend, support/resistance
Uses stored indicators from DB if available, falls back to on-the-fly calc
"""
try:
async with pool.acquire() as conn:
# Get enough candles for MA 125 calculation
rows = await conn.fetch("""
SELECT time, open, high, low, close, volume
# 1. Get latest price
latest = await conn.fetchrow("""
SELECT close, time
FROM candles
WHERE symbol = $1 AND interval = $2
ORDER BY time DESC
LIMIT 200
LIMIT 1
""", symbol, interval)
if len(rows) < 50:
return {
"symbol": symbol,
"interval": interval,
"error": "Not enough data for technical analysis",
"min_required": 50,
"available": len(rows)
}
if not latest:
return {"error": "No candle data found"}
current_price = float(latest['close'])
timestamp = latest['time']
# Reverse to chronological order
candles = list(reversed(rows))
closes = [float(c['close']) for c in candles]
# 2. Get latest indicators from DB
indicators = await conn.fetch("""
SELECT indicator_name, value
FROM indicators
WHERE symbol = $1 AND interval = $2
AND time <= $3
ORDER BY time DESC
""", symbol, interval, timestamp)
# Calculate Moving Averages
def calculate_ma(data, period):
if len(data) < period:
return None
return sum(data[-period:]) / period
# Convert list to dict, e.g. {'ma44': 65000, 'ma125': 64000}
# We take the most recent value for each indicator
ind_map = {}
for row in indicators:
name = row['indicator_name']
if name not in ind_map:
ind_map[name] = float(row['value'])
ma_44 = calculate_ma(closes, 44)
ma_125 = calculate_ma(closes, 125)
current_price = closes[-1]
ma_44 = ind_map.get('ma44')
ma_125 = ind_map.get('ma125')
# Determine trend
if ma_44 and ma_125:
@ -274,24 +471,35 @@ async def get_technical_analysis(
trend = "Unknown"
trend_strength = "Insufficient data"
# Find support and resistance (recent swing points)
highs = [float(c['high']) for c in candles[-20:]]
lows = [float(c['low']) for c in candles[-20:]]
# 3. Find support/resistance (simple recent high/low)
rows = await conn.fetch("""
SELECT high, low
FROM candles
WHERE symbol = $1 AND interval = $2
ORDER BY time DESC
LIMIT 20
""", symbol, interval)
resistance = max(highs)
support = min(lows)
# Calculate price position
price_range = resistance - support
if price_range > 0:
position = (current_price - support) / price_range * 100
if rows:
highs = [float(r['high']) for r in rows]
lows = [float(r['low']) for r in rows]
resistance = max(highs)
support = min(lows)
price_range = resistance - support
if price_range > 0:
position = (current_price - support) / price_range * 100
else:
position = 50
else:
resistance = current_price
support = current_price
position = 50
return {
"symbol": symbol,
"interval": interval,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": timestamp.isoformat(),
"current_price": round(current_price, 2),
"moving_averages": {
"ma_44": round(ma_44, 2) if ma_44 else None,

View File

@ -4,6 +4,9 @@ from .candle_buffer import CandleBuffer
from .database import DatabaseManager
from .backfill import HyperliquidBackfill
from .custom_timeframe_generator import CustomTimeframeGenerator
from .indicator_engine import IndicatorEngine, IndicatorConfig
from .brain import Brain, Decision
from .backtester import Backtester
__all__ = [
'HyperliquidWebSocket',
@ -11,5 +14,10 @@ __all__ = [
'CandleBuffer',
'DatabaseManager',
'HyperliquidBackfill',
'CustomTimeframeGenerator'
]
'CustomTimeframeGenerator',
'IndicatorEngine',
'IndicatorConfig',
'Brain',
'Decision',
'Backtester'
]

View File

@ -0,0 +1,391 @@
"""
Backtester - Historical replay driver for IndicatorEngine + Brain
Iterates over stored candle data to simulate live trading decisions
"""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any
from uuid import uuid4
from .database import DatabaseManager
from .indicator_engine import IndicatorEngine, IndicatorConfig
from .brain import Brain, Decision
from .simulator import Account
from src.strategies.base import SignalType
logger = logging.getLogger(__name__)
class Backtester:
"""
Replays historical candle data through IndicatorEngine and Brain.
Uses Simulator (Account) to track PnL, leverage, and fees.
"""
def __init__(
self,
db: DatabaseManager,
indicator_engine: IndicatorEngine,
brain: Brain,
):
self.db = db
self.indicator_engine = indicator_engine
self.brain = brain
self.account = Account(initial_balance=1000.0)
async def run(
self,
symbol: str,
intervals: List[str],
start: datetime,
end: datetime,
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Run a full backtest over the given time range.
"""
backtest_id = str(uuid4())
logger.info(
f"Starting backtest {backtest_id}: {symbol} "
f"{intervals} from {start} to {end}"
)
# Reset brain state
self.brain.reset_state()
# Reset account for this run
self.account = Account(initial_balance=1000.0)
# Store the run metadata
await self._save_run_start(
backtest_id, symbol, intervals, start, end, config
)
total_decisions = 0
for interval in intervals:
# Only process intervals that have indicators configured
configured = self.indicator_engine.get_configured_intervals()
if interval not in configured:
logger.warning(
f"Skipping interval {interval}: no indicators configured"
)
continue
# Get all candle timestamps in range
timestamps = await self._get_candle_timestamps(
symbol, interval, start, end
)
if not timestamps:
logger.warning(
f"No candles found for {symbol}/{interval} in range"
)
continue
logger.info(
f"Backtest {backtest_id}: processing {len(timestamps)} "
f"{interval} candles..."
)
for i, ts in enumerate(timestamps):
# 1. Compute indicators
raw_indicators = await self.indicator_engine.compute_at(
symbol, interval, ts
)
indicators = {k: v for k, v in raw_indicators.items() if v is not None}
# 2. Get Current Position info for Strategy
current_pos = self.account.get_position_dict()
# 3. Brain Evaluate
decision: Decision = await self.brain.evaluate(
symbol=symbol,
interval=interval,
timestamp=ts,
indicators=indicators,
backtest_id=backtest_id,
current_position=current_pos
)
# 4. Execute Decision in Simulator
self._execute_decision(decision)
total_decisions += 1
if (i + 1) % 200 == 0:
logger.info(
f"Backtest {backtest_id}: {i + 1}/{len(timestamps)} "
f"{interval} candles processed. Eq: {self.account.equity:.2f}"
)
await asyncio.sleep(0.01)
# Compute and store summary results from Simulator
results = self.account.get_stats()
results['total_evaluations'] = total_decisions
await self._save_run_results(backtest_id, results)
logger.info(
f"Backtest {backtest_id} complete. Final Balance: {results['final_balance']:.2f}"
)
return backtest_id
def _execute_decision(self, decision: Decision):
"""Translate Brain decision into Account action"""
price = decision.price_at_decision
time = decision.time
# Open Long
if decision.decision_type == SignalType.OPEN_LONG.value:
self.account.open_position(time, 'long', price, leverage=1.0) # Todo: Configurable leverage
# Open Short
elif decision.decision_type == SignalType.OPEN_SHORT.value:
self.account.open_position(time, 'short', price, leverage=1.0)
# Close Long (only if we are long)
elif decision.decision_type == SignalType.CLOSE_LONG.value:
if self.account.current_position and self.account.current_position.side == 'long':
self.account.close_position(time, price)
# Close Short (only if we are short)
elif decision.decision_type == SignalType.CLOSE_SHORT.value:
if self.account.current_position and self.account.current_position.side == 'short':
self.account.close_position(time, price)
# Update equity mark-to-market
self.account.update_equity(price)
async def _get_candle_timestamps(
self,
symbol: str,
interval: str,
start: datetime,
end: datetime,
) -> List[datetime]:
"""Get all candle timestamps in a range, ordered chronologically"""
async with self.db.acquire() as conn:
rows = await conn.fetch("""
SELECT time FROM candles
WHERE symbol = $1 AND interval = $2
AND time >= $3 AND time <= $4
ORDER BY time ASC
""", symbol, interval, start, end)
return [row["time"] for row in rows]
async def _save_run_start(
self,
backtest_id: str,
symbol: str,
intervals: List[str],
start: datetime,
end: datetime,
config: Optional[Dict[str, Any]],
) -> None:
"""Store backtest run metadata at start"""
async with self.db.acquire() as conn:
await conn.execute("""
INSERT INTO backtest_runs (
id, strategy, symbol, start_time, end_time,
intervals, config
)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
backtest_id,
self.brain.strategy_name,
symbol,
start,
end,
intervals,
json.dumps(config) if config else None,
)
async def _compute_results(self, backtest_id, symbol):
"""Deprecated: Logic moved to Account class"""
return {}
async def _save_run_results(
self,
backtest_id: str,
results: Dict[str, Any],
) -> None:
"""Update backtest run with final results"""
# Remove trades list from stored results (can be large)
stored_results = {k: v for k, v in results.items() if k != "trades"}
async with self.db.acquire() as conn:
await conn.execute("""
UPDATE backtest_runs
SET results = $1
WHERE id = $2
""", json.dumps(stored_results), backtest_id)
async def get_run(self, backtest_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific backtest run with results"""
async with self.db.acquire() as conn:
row = await conn.fetchrow("""
SELECT id, strategy, symbol, start_time, end_time,
intervals, config, results, created_at
FROM backtest_runs
WHERE id = $1
""", backtest_id)
return dict(row) if row else None
async def list_runs(
self,
symbol: Optional[str] = None,
limit: int = 20,
) -> List[Dict[str, Any]]:
"""List recent backtest runs"""
async with self.db.acquire() as conn:
if symbol:
rows = await conn.fetch("""
SELECT id, strategy, symbol, start_time, end_time,
intervals, results, created_at
FROM backtest_runs
WHERE symbol = $1
ORDER BY created_at DESC
LIMIT $2
""", symbol, limit)
else:
rows = await conn.fetch("""
SELECT id, strategy, symbol, start_time, end_time,
intervals, results, created_at
FROM backtest_runs
ORDER BY created_at DESC
LIMIT $1
""", limit)
return [dict(row) for row in rows]
async def cleanup_run(self, backtest_id: str) -> int:
"""Delete all decisions and metadata for a backtest run"""
async with self.db.acquire() as conn:
result = await conn.execute("""
DELETE FROM decisions WHERE backtest_id = $1
""", backtest_id)
await conn.execute("""
DELETE FROM backtest_runs WHERE id = $1
""", backtest_id)
deleted = int(result.split()[-1]) if result else 0
logger.info(
f"Cleaned up backtest {backtest_id}: "
f"{deleted} decisions deleted"
)
return deleted
async def main():
"""CLI entry point for running backtests"""
import argparse
import os
parser = argparse.ArgumentParser(
description="Run backtest on historical data"
)
parser.add_argument(
"--symbol", default="BTC", help="Symbol (default: BTC)"
)
parser.add_argument(
"--intervals", nargs="+", default=["37m"],
help="Intervals to backtest (default: 37m)"
)
parser.add_argument(
"--start", required=True,
help="Start date (ISO format, e.g., 2025-01-01)"
)
parser.add_argument(
"--end", default=None,
help="End date (ISO format, default: now)"
)
parser.add_argument(
"--db-host", default=os.getenv("DB_HOST", "localhost"),
)
parser.add_argument(
"--db-port", type=int, default=int(os.getenv("DB_PORT", 5432)),
)
parser.add_argument(
"--db-name", default=os.getenv("DB_NAME", "btc_data"),
)
parser.add_argument(
"--db-user", default=os.getenv("DB_USER", "btc_bot"),
)
parser.add_argument(
"--db-password", default=os.getenv("DB_PASSWORD", ""),
)
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Parse dates
start = datetime.fromisoformat(args.start).replace(tzinfo=timezone.utc)
end = (
datetime.fromisoformat(args.end).replace(tzinfo=timezone.utc)
if args.end
else datetime.now(timezone.utc)
)
# Initialize components
db = DatabaseManager(
host=args.db_host,
port=args.db_port,
database=args.db_name,
user=args.db_user,
password=args.db_password,
)
await db.connect()
try:
# Default indicator configs (MA44 + MA125 on selected intervals)
configs = [
IndicatorConfig("ma44", "sma", 44, args.intervals),
IndicatorConfig("ma125", "sma", 125, args.intervals),
]
indicator_engine = IndicatorEngine(db, configs)
brain = Brain(db, indicator_engine)
backtester = Backtester(db, indicator_engine, brain)
# Run the backtest
backtest_id = await backtester.run(
symbol=args.symbol,
intervals=args.intervals,
start=start,
end=end,
)
# Print results
run = await backtester.get_run(backtest_id)
if run and run.get("results"):
results = json.loads(run["results"]) if isinstance(run["results"], str) else run["results"]
print("\n=== Backtest Results ===")
print(f"ID: {backtest_id}")
print(f"Strategy: {run['strategy']}")
print(f"Period: {run['start_time']} to {run['end_time']}")
print(f"Intervals: {run['intervals']}")
print(f"Total evaluations: {results.get('total_evaluations', 0)}")
print(f"Total trades: {results.get('total_trades', 0)}")
print(f"Win rate: {results.get('win_rate', 0)}%")
print(f"Total P&L: {results.get('total_pnl_pct', 0)}%")
print(f"Final Balance: {results.get('final_balance', 0)}")
finally:
await db.disconnect()
if __name__ == "__main__":
asyncio.run(main())

223
src/data_collector/brain.py Normal file
View File

@ -0,0 +1,223 @@
"""
Brain - Strategy evaluation and decision logging
Pure strategy logic separated from DB I/O for testability
"""
import json
import logging
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from typing import Dict, Optional, Any, List
import importlib
from .database import DatabaseManager
from .indicator_engine import IndicatorEngine
from src.strategies.base import BaseStrategy, StrategySignal, SignalType
logger = logging.getLogger(__name__)
# Registry of available strategies
STRATEGY_REGISTRY = {
"ma44_strategy": "src.strategies.ma44_strategy.MA44Strategy",
"ma125_strategy": "src.strategies.ma125_strategy.MA125Strategy",
}
def load_strategy(strategy_name: str) -> BaseStrategy:
"""Dynamically load a strategy class"""
if strategy_name not in STRATEGY_REGISTRY:
# Default fallback or error
logger.warning(f"Strategy {strategy_name} not found, defaulting to MA44")
strategy_name = "ma44_strategy"
module_path, class_name = STRATEGY_REGISTRY[strategy_name].rsplit('.', 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
return cls()
@dataclass
class Decision:
"""A single brain evaluation result"""
time: datetime
symbol: str
interval: str
decision_type: str # "buy", "sell", "hold" -> Now maps to SignalType
strategy: str
confidence: float
price_at_decision: float
indicator_snapshot: Dict[str, Any]
candle_snapshot: Dict[str, Any]
reasoning: str
backtest_id: Optional[str] = None
def to_db_tuple(self) -> tuple:
"""Convert to positional tuple for DB insert"""
return (
self.time,
self.symbol,
self.interval,
self.decision_type,
self.strategy,
self.confidence,
self.price_at_decision,
json.dumps(self.indicator_snapshot),
json.dumps(self.candle_snapshot),
self.reasoning,
self.backtest_id,
)
class Brain:
"""
Evaluates market conditions using a loaded Strategy.
"""
def __init__(
self,
db: DatabaseManager,
indicator_engine: IndicatorEngine,
strategy: str = "ma44_strategy",
):
self.db = db
self.indicator_engine = indicator_engine
self.strategy_name = strategy
self.active_strategy: BaseStrategy = load_strategy(strategy)
logger.info(f"Brain initialized with strategy: {self.active_strategy.name}")
async def evaluate(
self,
symbol: str,
interval: str,
timestamp: datetime,
indicators: Optional[Dict[str, float]] = None,
backtest_id: Optional[str] = None,
current_position: Optional[Dict[str, Any]] = None,
) -> Decision:
"""
Evaluate market conditions and produce a decision.
"""
# Get indicator values
if indicators is None:
indicators = await self.indicator_engine.get_values_at(
symbol, interval, timestamp
)
# Get the triggering candle
candle = await self._get_candle(symbol, interval, timestamp)
if not candle:
return self._create_empty_decision(timestamp, symbol, interval, indicators, backtest_id)
price = float(candle["close"])
candle_dict = {
"time": candle["time"].isoformat(),
"open": float(candle["open"]),
"high": float(candle["high"]),
"low": float(candle["low"]),
"close": price,
"volume": float(candle["volume"]),
}
# Delegate to Strategy
signal: StrategySignal = self.active_strategy.analyze(
candle_dict, indicators, current_position
)
# Build decision
decision = Decision(
time=timestamp,
symbol=symbol,
interval=interval,
decision_type=signal.type.value,
strategy=self.strategy_name,
confidence=signal.confidence,
price_at_decision=price,
indicator_snapshot=indicators,
candle_snapshot=candle_dict,
reasoning=signal.reasoning,
backtest_id=backtest_id,
)
# Store to DB
await self._store_decision(decision)
return decision
def _create_empty_decision(self, timestamp, symbol, interval, indicators, backtest_id):
return Decision(
time=timestamp,
symbol=symbol,
interval=interval,
decision_type="hold",
strategy=self.strategy_name,
confidence=0.0,
price_at_decision=0.0,
indicator_snapshot=indicators or {},
candle_snapshot={},
reasoning="No candle data available",
backtest_id=backtest_id,
)
async def _get_candle(
self,
symbol: str,
interval: str,
timestamp: datetime,
) -> Optional[Dict[str, Any]]:
"""Fetch a specific candle from the database"""
async with self.db.acquire() as conn:
row = await conn.fetchrow("""
SELECT time, open, high, low, close, volume
FROM candles
WHERE symbol = $1 AND interval = $2 AND time = $3
""", symbol, interval, timestamp)
return dict(row) if row else None
async def _store_decision(self, decision: Decision) -> None:
"""Write decision to the decisions table"""
# Note: We might want to skip writing every single HOLD to DB to save space if simulating millions of candles
# But keeping it for now for full traceability
async with self.db.acquire() as conn:
await conn.execute("""
INSERT INTO decisions (
time, symbol, interval, decision_type, strategy,
confidence, price_at_decision, indicator_snapshot,
candle_snapshot, reasoning, backtest_id
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
""", *decision.to_db_tuple())
async def get_recent_decisions(
self,
symbol: str,
limit: int = 20,
backtest_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Get recent decisions, optionally filtered by backtest_id"""
async with self.db.acquire() as conn:
if backtest_id is not None:
rows = await conn.fetch("""
SELECT time, symbol, interval, decision_type, strategy,
confidence, price_at_decision, indicator_snapshot,
candle_snapshot, reasoning, backtest_id
FROM decisions
WHERE symbol = $1 AND backtest_id = $2
ORDER BY time DESC
LIMIT $3
""", symbol, backtest_id, limit)
else:
rows = await conn.fetch("""
SELECT time, symbol, interval, decision_type, strategy,
confidence, price_at_decision, indicator_snapshot,
candle_snapshot, reasoning, backtest_id
FROM decisions
WHERE symbol = $1 AND backtest_id IS NULL
ORDER BY time DESC
LIMIT $2
""", symbol, limit)
return [dict(row) for row in rows]
def reset_state(self) -> None:
"""Reset internal state tracking"""
pass

View File

@ -0,0 +1,285 @@
"""
Indicator Engine - Computes and stores technical indicators
Stateless DB-backed design: same code for live updates and backtesting
"""
import asyncio
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any
from .database import DatabaseManager
logger = logging.getLogger(__name__)
@dataclass
class IndicatorConfig:
"""Configuration for a single indicator"""
name: str # e.g., "ma44"
type: str # e.g., "sma"
period: int # e.g., 44
intervals: List[str] # e.g., ["37m", "148m", "1d"]
@classmethod
def from_dict(cls, name: str, data: Dict[str, Any]) -> "IndicatorConfig":
"""Create config from YAML dict entry"""
return cls(
name=name,
type=data["type"],
period=data["period"],
intervals=data["intervals"],
)
@dataclass
class IndicatorResult:
"""Result of a single indicator computation"""
name: str
value: Optional[float]
period: int
timestamp: datetime
class IndicatorEngine:
"""
Computes technical indicators from candle data in the database.
Two modes, same math:
- on_interval_update(): called by live system after higher-TF candle update
- compute_at(): called by backtester for a specific point in time
Both query the DB for the required candle history and store results.
"""
def __init__(self, db: DatabaseManager, configs: List[IndicatorConfig]):
self.db = db
self.configs = configs
# Build lookup: interval -> list of configs that need computation
self._interval_configs: Dict[str, List[IndicatorConfig]] = {}
for cfg in configs:
for interval in cfg.intervals:
if interval not in self._interval_configs:
self._interval_configs[interval] = []
self._interval_configs[interval].append(cfg)
logger.info(
f"IndicatorEngine initialized with {len(configs)} indicators "
f"across intervals: {list(self._interval_configs.keys())}"
)
def get_configured_intervals(self) -> List[str]:
"""Return all intervals that have indicators configured"""
return list(self._interval_configs.keys())
async def on_interval_update(
self,
symbol: str,
interval: str,
timestamp: datetime,
) -> Dict[str, Optional[float]]:
"""
Compute all indicators configured for this interval.
Called by main.py after CustomTimeframeGenerator updates a higher TF.
Returns dict of indicator_name -> value (for use by Brain).
"""
configs = self._interval_configs.get(interval, [])
if not configs:
return {}
return await self._compute_and_store(symbol, interval, timestamp, configs)
async def compute_at(
self,
symbol: str,
interval: str,
timestamp: datetime,
) -> Dict[str, Optional[float]]:
"""
Compute indicators at a specific point in time.
Alias for on_interval_update -- used by backtester for clarity.
"""
return await self.on_interval_update(symbol, interval, timestamp)
async def compute_historical(
self,
symbol: str,
interval: str,
start: datetime,
end: datetime,
) -> int:
"""
Batch-compute indicators for a time range.
Iterates over every candle timestamp in [start, end] and computes.
Returns total number of indicator values stored.
"""
configs = self._interval_configs.get(interval, [])
if not configs:
logger.warning(f"No indicators configured for interval {interval}")
return 0
# Get all candle timestamps in range
async with self.db.acquire() as conn:
rows = await conn.fetch("""
SELECT time FROM candles
WHERE symbol = $1 AND interval = $2
AND time >= $3 AND time <= $4
ORDER BY time ASC
""", symbol, interval, start, end)
if not rows:
logger.warning(f"No candles found for {symbol}/{interval} in range")
return 0
timestamps = [row["time"] for row in rows]
total_stored = 0
logger.info(
f"Computing {len(configs)} indicators across "
f"{len(timestamps)} {interval} candles..."
)
for i, ts in enumerate(timestamps):
results = await self._compute_and_store(symbol, interval, ts, configs)
total_stored += sum(1 for v in results.values() if v is not None)
if (i + 1) % 100 == 0:
logger.info(f"Progress: {i + 1}/{len(timestamps)} candles processed")
await asyncio.sleep(0.01) # Yield to event loop
logger.info(
f"Historical compute complete: {total_stored} indicator values "
f"stored for {interval}"
)
return total_stored
async def _compute_and_store(
self,
symbol: str,
interval: str,
timestamp: datetime,
configs: List[IndicatorConfig],
) -> Dict[str, Optional[float]]:
"""Core computation: fetch candles, compute indicators, store results"""
# Determine max lookback needed
max_period = max(cfg.period for cfg in configs)
# Fetch enough candles for the longest indicator
async with self.db.acquire() as conn:
rows = await conn.fetch("""
SELECT time, open, high, low, close, volume
FROM candles
WHERE symbol = $1 AND interval = $2
AND time <= $3
ORDER BY time DESC
LIMIT $4
""", symbol, interval, timestamp, max_period)
if not rows:
return {cfg.name: None for cfg in configs}
# Reverse to chronological order
candles = list(reversed(rows))
closes = [float(c["close"]) for c in candles]
# Compute each indicator
results: Dict[str, Optional[float]] = {}
values_to_store: List[tuple] = []
for cfg in configs:
value = self._compute_indicator(cfg, closes)
results[cfg.name] = value
if value is not None:
values_to_store.append((
timestamp,
symbol,
interval,
cfg.name,
value,
json.dumps({"type": cfg.type, "period": cfg.period}),
))
# Batch upsert all computed values
if values_to_store:
async with self.db.acquire() as conn:
await conn.executemany("""
INSERT INTO indicators (time, symbol, interval, indicator_name, value, parameters)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (time, symbol, interval, indicator_name)
DO UPDATE SET
value = EXCLUDED.value,
parameters = EXCLUDED.parameters,
computed_at = NOW()
""", values_to_store)
logger.debug(
f"Stored {len(values_to_store)} indicator values for "
f"{symbol}/{interval} at {timestamp}"
)
return results
def _compute_indicator(
self,
config: IndicatorConfig,
closes: List[float],
) -> Optional[float]:
"""Dispatch to the correct computation function"""
if config.type == "sma":
return self.compute_sma(closes, config.period)
else:
logger.warning(f"Unknown indicator type: {config.type}")
return None
# ── Pure math functions (no DB, no async, easily testable) ──────────
@staticmethod
def compute_sma(closes: List[float], period: int) -> Optional[float]:
"""Simple Moving Average over the last `period` closes"""
if len(closes) < period:
return None
return sum(closes[-period:]) / period
async def get_latest_values(
self,
symbol: str,
interval: str,
) -> Dict[str, float]:
"""
Get the most recent indicator values for a symbol/interval.
Used by Brain to read current state.
"""
async with self.db.acquire() as conn:
rows = await conn.fetch("""
SELECT DISTINCT ON (indicator_name)
indicator_name, value, time
FROM indicators
WHERE symbol = $1 AND interval = $2
ORDER BY indicator_name, time DESC
""", symbol, interval)
return {row["indicator_name"]: float(row["value"]) for row in rows}
async def get_values_at(
self,
symbol: str,
interval: str,
timestamp: datetime,
) -> Dict[str, float]:
"""
Get indicator values at a specific timestamp.
Used by Brain during backtesting.
"""
async with self.db.acquire() as conn:
rows = await conn.fetch("""
SELECT indicator_name, value
FROM indicators
WHERE symbol = $1 AND interval = $2 AND time = $3
""", symbol, interval, timestamp)
return {row["indicator_name"]: float(row["value"]) for row in rows}

View File

@ -1,6 +1,6 @@
"""
Main entry point for data collector service
Integrates WebSocket client, buffer, and database
Integrates WebSocket client, buffer, database, indicators, and brain
"""
import asyncio
@ -8,13 +8,17 @@ import logging
import signal
import sys
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, List
import os
import yaml
from .websocket_client import HyperliquidWebSocket, Candle
from .candle_buffer import CandleBuffer
from .database import DatabaseManager
from .custom_timeframe_generator import CustomTimeframeGenerator
from .indicator_engine import IndicatorEngine, IndicatorConfig
from .brain import Brain
# Configure logging
@ -68,6 +72,17 @@ class DataCollector:
self.custom_tf_generator = CustomTimeframeGenerator(self.db)
await self.custom_tf_generator.initialize()
# Initialize indicator engine
# Hardcoded config for now, eventually load from yaml
indicator_configs = [
IndicatorConfig("ma44", "sma", 44, ["37m", "148m", "1d"]),
IndicatorConfig("ma125", "sma", 125, ["37m", "148m", "1d"])
]
self.indicator_engine = IndicatorEngine(self.db, indicator_configs)
# Initialize brain
self.brain = Brain(self.db, self.indicator_engine)
# Initialize buffer
self.buffer = CandleBuffer(
max_size=1000,
@ -166,12 +181,47 @@ class DataCollector:
raise # Re-raise to trigger buffer retry
async def _update_custom_timeframes(self, candles: list) -> None:
"""Update custom timeframes in background (non-blocking)"""
"""
Update custom timeframes in background, then trigger indicators/brain.
This chain ensures that indicators are computed on fresh candle data,
and the brain evaluates on fresh indicator data.
"""
try:
# 1. Update custom candles (37m, 148m, etc.)
await self.custom_tf_generator.update_realtime(candles)
logger.debug("Custom timeframes updated")
# 2. Trigger indicator updates for configured intervals
# We use the timestamp of the last 1m candle as the trigger point
trigger_time = candles[-1].time
if self.indicator_engine:
intervals = self.indicator_engine.get_configured_intervals()
for interval in intervals:
# Get the correct bucket start time for this interval
# e.g., if trigger_time is 09:48:00, 37m bucket might start at 09:25:00
if self.custom_tf_generator:
bucket_start = self.custom_tf_generator.get_bucket_start(trigger_time, interval)
else:
bucket_start = trigger_time
# Compute indicators for this bucket
raw_indicators = await self.indicator_engine.on_interval_update(
self.symbol, interval, bucket_start
)
# Filter out None values to satisfy type checker
indicators = {k: v for k, v in raw_indicators.items() if v is not None}
# 3. Evaluate brain if we have fresh indicators
if self.brain and indicators:
await self.brain.evaluate(
self.symbol, interval, bucket_start, indicators
)
except Exception as e:
logger.error(f"Failed to update custom timeframes: {e}")
logger.error(f"Failed to update custom timeframes/indicators: {e}")
# Don't raise - this is non-critical
async def _on_error(self, error: Exception) -> None:

View File

@ -0,0 +1,160 @@
"""
Simulator
Handles account accounting, leverage, fees, and position management for backtesting.
"""
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from datetime import datetime
from .brain import Decision # We might need to decouple this later, but reusing for now
@dataclass
class Trade:
entry_time: datetime
exit_time: Optional[datetime]
side: str # 'long' or 'short'
entry_price: float
exit_price: Optional[float]
size: float # Quantity of asset
leverage: float
pnl: float = 0.0
pnl_percent: float = 0.0
fees: float = 0.0
status: str = 'open' # 'open', 'closed'
class Account:
def __init__(self, initial_balance: float = 1000.0, maker_fee: float = 0.0002, taker_fee: float = 0.0005):
self.initial_balance = initial_balance
self.balance = initial_balance
self.equity = initial_balance
self.maker_fee = maker_fee
self.taker_fee = taker_fee
self.trades: List[Trade] = []
self.current_position: Optional[Trade] = None
self.margin_used = 0.0
def update_equity(self, current_price: float):
"""Update equity based on unrealized PnL of current position"""
if not self.current_position:
self.equity = self.balance
return
trade = self.current_position
if trade.side == 'long':
unrealized_pnl = (current_price - trade.entry_price) * trade.size
else:
unrealized_pnl = (trade.entry_price - current_price) * trade.size
self.equity = self.balance + unrealized_pnl
def open_position(self, time: datetime, side: str, price: float, leverage: float = 1.0, portion: float = 1.0):
"""
Open a position.
portion: 0.0 to 1.0 (percentage of available balance to use)
"""
if self.current_position:
# Already have a position, ignore for now (or could add to it)
return
# Calculate position size
# Margin = (Balance * portion)
# Position Value = Margin * Leverage
# Size = Position Value / Price
margin_to_use = self.balance * portion
position_value = margin_to_use * leverage
size = position_value / price
# Fee (Taker)
fee = position_value * self.taker_fee
self.balance -= fee # Deduct fee immediately
self.current_position = Trade(
entry_time=time,
exit_time=None,
side=side,
entry_price=price,
exit_price=None,
size=size,
leverage=leverage,
fees=fee
)
self.margin_used = margin_to_use
def close_position(self, time: datetime, price: float):
"""Close the current position"""
if not self.current_position:
return
trade = self.current_position
position_value = trade.size * price
# Calculate PnL
if trade.side == 'long':
pnl = (price - trade.entry_price) * trade.size
pnl_pct = (price - trade.entry_price) / trade.entry_price * trade.leverage * 100
else:
pnl = (trade.entry_price - price) * trade.size
pnl_pct = (trade.entry_price - price) / trade.entry_price * trade.leverage * 100
# Fee (Taker)
fee = position_value * self.taker_fee
self.balance -= fee
trade.fees += fee
# Update Balance
self.balance += pnl
self.margin_used = 0.0
# Update Trade Record
trade.exit_time = time
trade.exit_price = price
trade.pnl = pnl
trade.pnl_percent = pnl_pct
trade.status = 'closed'
self.trades.append(trade)
self.current_position = None
self.equity = self.balance
def get_position_dict(self) -> Optional[Dict[str, Any]]:
if not self.current_position:
return None
return {
'type': self.current_position.side,
'entry_price': self.current_position.entry_price,
'size': self.current_position.size,
'leverage': self.current_position.leverage
}
def get_stats(self) -> Dict[str, Any]:
wins = [t for t in self.trades if t.pnl > 0]
losses = [t for t in self.trades if t.pnl <= 0]
total_pnl = self.balance - self.initial_balance
total_pnl_pct = (total_pnl / self.initial_balance) * 100
return {
"initial_balance": self.initial_balance,
"final_balance": self.balance,
"total_pnl": total_pnl,
"total_pnl_pct": total_pnl_pct,
"total_trades": len(self.trades),
"win_count": len(wins),
"loss_count": len(losses),
"win_rate": (len(wins) / len(self.trades) * 100) if self.trades else 0.0,
"max_drawdown": 0.0, # Todo: implement DD tracking
"trades": [
{
"entry_time": t.entry_time.isoformat(),
"exit_time": t.exit_time.isoformat() if t.exit_time else None,
"side": t.side,
"entry_price": t.entry_price,
"exit_price": t.exit_price,
"pnl": t.pnl,
"pnl_pct": t.pnl_percent,
"fees": t.fees
}
for t in self.trades
]
}

68
src/strategies/base.py Normal file
View File

@ -0,0 +1,68 @@
"""
Base Strategy Interface
All strategies must inherit from this class.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from enum import Enum
class SignalType(Enum):
OPEN_LONG = "open_long"
OPEN_SHORT = "open_short"
CLOSE_LONG = "close_long"
CLOSE_SHORT = "close_short"
HOLD = "hold"
@dataclass
class StrategySignal:
type: SignalType
confidence: float
reasoning: str
class BaseStrategy(ABC):
def __init__(self, config: Optional[Dict[str, Any]] = None):
self.config = config or {}
@property
@abstractmethod
def name(self) -> str:
"""Unique identifier for the strategy"""
pass
@property
@abstractmethod
def required_indicators(self) -> List[str]:
"""List of indicator names required by this strategy (e.g. ['ma44'])"""
pass
@property
@abstractmethod
def display_name(self) -> str:
"""User-friendly name for display in UI (e.g. 'MA44 Crossover')"""
pass
@property
@abstractmethod
def description(self) -> str:
"""Detailed description of how the strategy works"""
pass
@abstractmethod
def analyze(
self,
candle: Dict[str, Any],
indicators: Dict[str, float],
current_position: Optional[Dict[str, Any]] = None
) -> StrategySignal:
"""
Analyze market data and return a trading signal.
Args:
candle: Dictionary containing 'close', 'open', 'high', 'low', 'volume', 'time'
indicators: Dictionary of pre-computed indicator values
current_position: Details about current open position (if any)
{'type': 'long'/'short', 'entry_price': float, 'size': float}
"""
pass

View File

@ -0,0 +1,63 @@
"""
MA125 Strategy
Simple trend following strategy.
- Long when Price > MA125
- Short when Price < MA125
"""
from typing import Dict, Any, List, Optional
from .base import BaseStrategy, StrategySignal, SignalType
class MA125Strategy(BaseStrategy):
@property
def name(self) -> str:
return "ma125_strategy"
@property
def required_indicators(self) -> List[str]:
return ["ma125"]
@property
def display_name(self) -> str:
return "MA125 Strategy"
@property
def description(self) -> str:
return "Long-term trend following using 125-period moving average. Better for identifying major trends."
def analyze(
self,
candle: Dict[str, Any],
indicators: Dict[str, float],
current_position: Optional[Dict[str, Any]] = None
) -> StrategySignal:
price = candle['close']
ma125 = indicators.get('ma125')
if ma125 is None:
return StrategySignal(SignalType.HOLD, 0.0, "MA125 not available")
# Current position state
is_long = current_position and current_position.get('type') == 'long'
is_short = current_position and current_position.get('type') == 'short'
# Logic: Price > MA125 -> Bullish
if price > ma125:
if is_long:
return StrategySignal(SignalType.HOLD, 1.0, f"Price {price:.2f} > MA125 {ma125:.2f}. Stay Long.")
elif is_short:
return StrategySignal(SignalType.CLOSE_SHORT, 1.0, f"Price {price:.2f} crossed above MA125 {ma125:.2f}. Close Short.")
else:
return StrategySignal(SignalType.OPEN_LONG, 1.0, f"Price {price:.2f} > MA125 {ma125:.2f}. Open Long.")
# Logic: Price < MA125 -> Bearish
elif price < ma125:
if is_short:
return StrategySignal(SignalType.HOLD, 1.0, f"Price {price:.2f} < MA125 {ma125:.2f}. Stay Short.")
elif is_long:
return StrategySignal(SignalType.CLOSE_LONG, 1.0, f"Price {price:.2f} crossed below MA125 {ma125:.2f}. Close Long.")
else:
return StrategySignal(SignalType.OPEN_SHORT, 1.0, f"Price {price:.2f} < MA125 {ma125:.2f}. Open Short.")
return StrategySignal(SignalType.HOLD, 0.0, "Price == MA125")

View File

@ -0,0 +1,63 @@
"""
MA44 Strategy
Simple trend following strategy.
- Long when Price > MA44
- Short when Price < MA44
"""
from typing import Dict, Any, List, Optional
from .base import BaseStrategy, StrategySignal, SignalType
class MA44Strategy(BaseStrategy):
@property
def name(self) -> str:
return "ma44_strategy"
@property
def required_indicators(self) -> List[str]:
return ["ma44"]
@property
def display_name(self) -> str:
return "MA44 Strategy"
@property
def description(self) -> str:
return "Buy when price crosses above MA44, sell when below. Good for trending markets."
def analyze(
self,
candle: Dict[str, Any],
indicators: Dict[str, float],
current_position: Optional[Dict[str, Any]] = None
) -> StrategySignal:
price = candle['close']
ma44 = indicators.get('ma44')
if ma44 is None:
return StrategySignal(SignalType.HOLD, 0.0, "MA44 not available")
# Current position state
is_long = current_position and current_position.get('type') == 'long'
is_short = current_position and current_position.get('type') == 'short'
# Logic: Price > MA44 -> Bullish
if price > ma44:
if is_long:
return StrategySignal(SignalType.HOLD, 1.0, f"Price {price:.2f} > MA44 {ma44:.2f}. Stay Long.")
elif is_short:
return StrategySignal(SignalType.CLOSE_SHORT, 1.0, f"Price {price:.2f} crossed above MA44 {ma44:.2f}. Close Short.")
else:
return StrategySignal(SignalType.OPEN_LONG, 1.0, f"Price {price:.2f} > MA44 {ma44:.2f}. Open Long.")
# Logic: Price < MA44 -> Bearish
elif price < ma44:
if is_short:
return StrategySignal(SignalType.HOLD, 1.0, f"Price {price:.2f} < MA44 {ma44:.2f}. Stay Short.")
elif is_long:
return StrategySignal(SignalType.CLOSE_LONG, 1.0, f"Price {price:.2f} crossed below MA44 {ma44:.2f}. Close Long.")
else:
return StrategySignal(SignalType.OPEN_SHORT, 1.0, f"Price {price:.2f} < MA44 {ma44:.2f}. Open Short.")
return StrategySignal(SignalType.HOLD, 0.0, "Price == MA44")