354 lines
17 KiB
Python
354 lines
17 KiB
Python
import logging
|
|
import time
|
|
import json
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from hyperliquid.info import Info
|
|
from hyperliquid.utils import constants
|
|
|
|
from strategies.base_strategy import BaseStrategy
|
|
|
|
class CopyTraderStrategy(BaseStrategy):
    """
    An event-driven strategy that monitors a target wallet address and
    copies its trades for a specific set of allowed coins.

    The strategy keeps no position state of its own. It translates each of the
    target's fill directions (e.g. "Open Long") directly into an explicit
    signal (e.g. "OPEN_LONG") that is put on ``trade_signal_queue`` for the
    PositionManager.

    ``params`` keys read by this class:
        target_address: wallet address to follow (required, case-insensitive).
        coins_to_copy:  mapping of coin -> per-coin config dict. The per-coin
                        dict is merged over ``params`` for every signal and may
                        carry ``leverage_long`` / ``leverage_short`` (default 2).
        agent:          forwarded unchanged in each signal's ``config``.
    """

    # Fill directions that map one-to-one onto a single explicit signal.
    _SIMPLE_DIRECTIONS = {
        "Open Long": "OPEN_LONG",
        "Close Long": "CLOSE_LONG",
        "Open Short": "OPEN_SHORT",
        "Close Short": "CLOSE_SHORT",
    }

    # Position flips: (close signal, open signal, dashboard label).
    _FLIP_DIRECTIONS = {
        "Short > Long": ("CLOSE_SHORT", "OPEN_LONG", "FLIP_TO_LONG"),
        "Long > Short": ("CLOSE_LONG", "OPEN_SHORT", "FLIP_TO_SHORT"),
    }

    # Upper bound on remembered fill ids used for de-duplication.
    _MAX_SEEN_FILLS = 10_000

    def __init__(self, strategy_name: str, params: dict, trade_signal_queue, shared_status: dict = None):
        """
        Validate parameters and initialise dashboard state.

        Raises:
            ValueError: if no ``target_address`` is configured.
        """
        # The queue passed in is handed to the parent unchanged; the caller is
        # expected to wire it to whatever consumes the explicit signals.
        super().__init__(strategy_name, params, trade_signal_queue, shared_status)

        # `or` guards against keys that are present but explicitly null in config.
        self.target_address = (self.params.get("target_address") or "").lower()
        raw_coins = self.params.get("coins_to_copy") or {}
        # Upper-case coin keys so they match the upper-cased coin of each fill.
        self.coins_to_copy = {k.upper(): v for k, v in raw_coins.items()}
        self.allowed_coins = list(self.coins_to_copy.keys())

        if not self.target_address:
            logging.error("No 'target_address' specified in parameters for copy trader.")
            raise ValueError("target_address is required")
        if not self.allowed_coins:
            logging.warning("No 'coins_to_copy' configured. This strategy will not copy any trades.")

        self.info = None  # Created by _connect_and_subscribe() from run_event_loop().

        if self.shared_status is None:
            logging.warning("No shared_status dictionary provided. Initializing a new one.")
            self.shared_status = {}

        # Restore dashboard fields from any previously shared status.
        self.current_signal = self.shared_status.get("current_signal", "WAIT")
        self.signal_price = self.shared_status.get("signal_price")
        self.last_signal_change_utc = self.shared_status.get("last_signal_change_utc")

        # Ids of fills already acted on. A dict is used as an insertion-ordered
        # set so the oldest id can be evicted once _MAX_SEEN_FILLS is exceeded.
        self._seen_fill_ids = {}

        self.start_time_utc = datetime.now(timezone.utc)
        logging.info(f"Strategy initialized. Ignoring all trades before {self.start_time_utc.isoformat()}")

    def calculate_signals(self, df):
        """No-op: this strategy is event-driven and does not use polling-based signals."""
        pass

    def send_explicit_signal(self, signal: str, coin: str, price: float, trade_params: dict, size: float):
        """
        Put a formatted signal for the PositionManager on ``trade_signal_queue``.

        Args:
            signal: explicit action, e.g. "OPEN_LONG", "CLOSE_SHORT",
                "UPDATE_LEVERAGE_LONG".
            coin: upper-case coin symbol.
            price: reference price for the signal.
            trade_params: merged strategy + per-coin parameters.
            size: position size for trade signals; for UPDATE_LEVERAGE_*
                signals this carries the leverage value instead.
        """
        config = {
            "agent": self.params.get("agent"),
            "parameters": trade_params,
        }

        self.trade_signal_queue.put({
            "strategy_name": self.strategy_name,
            "signal": signal,
            "coin": coin,
            "signal_price": price,
            "config": config,
            "size": size,
        })
        logging.info(f"Explicit signal SENT: {signal} {coin} @ {price}, Size: {size}")

    def on_fill_message(self, message):
        """
        WebSocket callback for the monitored address's subscription.

        Validates the message envelope, then hands each fill to
        ``_process_fill``. Each fill is processed in isolation so one malformed
        fill does not prevent the rest of the batch from being copied. Errors
        are logged, never raised to the caller.
        """
        try:
            logging.debug(f"Received WebSocket message: {message}")

            channel = message.get("channel")
            if channel not in ("user", "userFills", "userEvents"):
                logging.debug(f"Ignoring message from unhandled channel: {channel}")
                return

            data = message.get("data")
            if not data:
                logging.debug("Message received with no 'data' field. Ignoring.")
                return
            if not isinstance(data, dict):
                logging.debug(f"Ignoring message whose 'data' is not an object: {type(data).__name__}")
                return

            user_address = (data.get("user") or "").lower()
            if not user_address:
                logging.debug("Received message with 'data' but no 'user'. Ignoring.")
                return

            fills = data.get("fills")
            if not fills:
                # A user event that is not a fill (e.g. order placement, cancel, withdrawal).
                event_type = data.get("type")
                if event_type:
                    logging.debug(f"Received non-fill user event: '{event_type}'. Ignoring.")
                else:
                    logging.debug("Received 'data' message with no 'fills'. Ignoring.")
                return

            if user_address != self.target_address:
                # Should not happen with a correct subscription, but be defensive.
                logging.warning(f"Received fill for wrong user: {user_address}")
                return

            logging.debug(f"Received {len(fills)} fill(s) for user {user_address}")

            for fill in fills:
                try:
                    self._process_fill(fill)
                except Exception as e:
                    logging.error(f"Error processing fill {fill}: {e}", exc_info=True)

        except Exception as e:
            logging.error(f"Error in on_fill_message: {e}", exc_info=True)

    def _fill_id(self, fill):
        """Return a hashable key identifying *fill* for de-duplication."""
        tid = fill.get("tid")
        if tid is not None:
            return ("tid", tid)
        # Fallback when the fill carries no trade id.
        return (
            "fields",
            fill.get("hash"),
            fill.get("oid"),
            fill.get("time"),
            fill.get("coin"),
            fill.get("px"),
            fill.get("sz"),
            fill.get("startPosition"),
        )

    def _mark_fill_seen(self, fill):
        """Record *fill* as handled; return False if it had already been handled."""
        key = self._fill_id(fill)
        seen = self._seen_fill_ids
        if key in seen:
            return False
        seen[key] = None
        if len(seen) > self._MAX_SEEN_FILLS:
            # Evict the oldest id (dicts iterate in insertion order).
            del seen[next(iter(seen))]
        return True

    def _process_fill(self, fill):
        """Translate one of the target's fills into explicit signal(s) and update the dashboard."""
        trade_time = datetime.fromtimestamp(fill['time'] / 1000, tz=timezone.utc)
        if trade_time < self.start_time_utc:
            logging.info(f"Ignoring stale/historical trade from {trade_time.isoformat()}")
            return

        # Re-subscribing after a reconnect replays recent fills as a snapshot,
        # so a fill copied earlier can arrive again; never act on it twice.
        if not self._mark_fill_seen(fill):
            logging.info(f"Ignoring already-processed fill from {trade_time.isoformat()} (tid={fill.get('tid')})")
            return

        coin = fill.get('coin').upper()
        if coin not in self.allowed_coins:
            logging.info(f"Ignoring fill for unmonitored coin: {coin}")
            return

        price = float(fill.get('px'))
        fill_size = float(fill.get('sz'))  # The target's size, copied as-is.
        if fill_size == 0:
            logging.warning("Ignoring fill with size 0.")
            return

        # e.g. "Open Long", "Close Long", "Open Short", "Close Short",
        # "Short > Long", "Long > Short".
        fill_direction = fill.get("dir")
        if not fill_direction:
            logging.warning(f"Fill message missing 'dir'. Ignoring fill: {fill}")
            return

        # The target's position before this fill; used to split a flip into close + open.
        start_pos_size = float(fill.get('startPosition', 0.0))

        trade_params = self.params.copy()
        coin_config = self.coins_to_copy.get(coin)
        if coin_config:
            trade_params.update(coin_config)

        simple_signal = self._SIMPLE_DIRECTIONS.get(fill_direction)
        flip = self._FLIP_DIRECTIONS.get(fill_direction)

        if simple_signal is not None:
            logging.warning(f"[{coin}] Target action: {fill_direction}. Sending signal: {simple_signal}")
            self.send_explicit_signal(simple_signal, coin, price, trade_params, fill_size)
            dashboard_signal = simple_signal
        elif flip is not None:
            close_signal, open_signal, dashboard_signal = flip
            logging.warning(f"[{coin}] Target action: {fill_direction}. Sending {close_signal} then {open_signal}.")
            # The flip fill first closes the entire starting position; whatever
            # is left of the fill opens the opposite side.
            close_size = abs(start_pos_size)
            open_size = fill_size - close_size
            if close_size > 0:
                self.send_explicit_signal(close_signal, coin, price, trade_params, close_size)
            if open_size > 0:
                self.send_explicit_signal(open_signal, coin, price, trade_params, open_size)
        else:
            logging.info(f"[{coin}] Ignoring unhandled fill direction: {fill_direction}")
            return

        # Update dashboard status with the action just taken.
        self.current_signal = dashboard_signal
        self.signal_price = price
        self.last_signal_change_utc = trade_time.isoformat()
        self.coin = coin
        self.size = fill_size
        self._save_status()

        logging.info(f"Source trade logged: {json.dumps(fill)}")

    def _connect_and_subscribe(self):
        """
        Create a fresh Info client and subscribe to ``userFills`` for the target.

        Returns:
            True on success; False (with ``self.info`` reset to None) on failure.
        """
        try:
            logging.info("Connecting to Hyperliquid WebSocket...")
            self.info = Info(constants.MAINNET_API_URL, skip_ws=False)

            subscription = {"type": "userFills", "user": self.target_address}
            self.info.subscribe(subscription, self.on_fill_message)
            logging.info(f"Subscribed to 'userFills' for target address: {self.target_address}")
            return True
        except Exception as e:
            logging.error(f"Failed to connect or subscribe: {e}")
            self.info = None
            return False

    def _ws_alive(self):
        """Return True if there is a WebSocket manager and it reports itself alive."""
        # getattr on None (no client yet) also yields None.
        ws_manager = getattr(self.info, "ws_manager", None)
        return ws_manager is not None and ws_manager.is_alive()

    def _reconnect(self):
        """Stop any stale WebSocket manager and open a fresh subscription."""
        logging.error("WebSocket connection lost. Attempting to reconnect...")

        old_ws = getattr(self.info, "ws_manager", None)
        if old_ws is not None:
            try:
                old_ws.stop()
            except Exception as e:
                logging.error(f"Error stopping old ws_manager: {e}")

        if self._connect_and_subscribe():
            logging.info("Successfully reconnected to WebSocket.")
            self._save_status()
        else:
            logging.error("Reconnect failed, will retry in 15s.")

    def _send_initial_leverage(self):
        """Queue UPDATE_LEVERAGE_LONG/SHORT signals for every monitored coin."""
        logging.info("Setting initial leverage for all monitored coins...")
        try:
            all_mids = self.info.all_mids()
        except Exception as e:
            logging.error(f"Failed to set initial leverage: {e}", exc_info=True)
            return

        for coin_key, raw_config in self.coins_to_copy.items():
            # Isolate each coin so one bad config does not skip the others.
            try:
                coin = coin_key.upper()
                coin_config = raw_config or {}
                # Failsafe price of 1.0 if the coin has no mid yet (e.g. new listing).
                current_price = float(all_mids.get(coin, 1.0))

                leverage_long = coin_config.get('leverage_long', 2)
                leverage_short = coin_config.get('leverage_short', 2)

                trade_params = self.params.copy()
                trade_params.update(coin_config)

                # For leverage signals the 'size' argument carries the leverage value.
                self.send_explicit_signal("UPDATE_LEVERAGE_LONG", coin, current_price, trade_params, leverage_long)
                self.send_explicit_signal("UPDATE_LEVERAGE_SHORT", coin, current_price, trade_params, leverage_short)

                logging.info(f"Sent initial leverage signals for {coin} (Long: {leverage_long}x, Short: {leverage_short}x)")
            except Exception as e:
                logging.error(f"Failed to set initial leverage for {coin_key}: {e}", exc_info=True)

    def run_event_loop(self):
        """
        Replace the default polling loop with a WebSocket subscription.

        Connects once, sends initial leverage signals, then runs a 15-second
        watchdog that reconnects whenever the WebSocket manager is not alive.
        Returns after a failed initial connect (following a 60s sleep) or on
        KeyboardInterrupt. The WebSocket is stopped on exit.
        """
        try:
            if not self._connect_and_subscribe():
                # If connection fails on start, wait 60s before letting the process restart.
                time.sleep(60)
                return

            # Small delay so the Info client is ready for REST calls.
            logging.info("Connection established. Waiting 2 seconds for Info client to be ready...")
            time.sleep(2)

            self._send_initial_leverage()

            # Save the initial status for the dashboard.
            self._save_status()

            while True:
                try:
                    time.sleep(15)  # Check the connection every 15 seconds.
                    if self._ws_alive():
                        logging.debug("Watchdog check: WebSocket connection is active.")
                    else:
                        self._reconnect()
                except Exception as e:
                    logging.error(f"An error occurred in the watchdog loop: {e}", exc_info=True)

        except KeyboardInterrupt:
            # No local positions are held, so there is nothing to unwind.
            logging.warning(f"Shutdown signal received. Exiting strategy '{self.strategy_name}'.")

        except Exception as e:
            logging.error(f"An unhandled error occurred in run_event_loop: {e}", exc_info=True)

        finally:
            ws_manager = getattr(self.info, "ws_manager", None)
            if ws_manager is not None and ws_manager.is_alive():
                try:
                    ws_manager.stop()
                    logging.info("WebSocket connection stopped.")
                except Exception as e:
                    logging.error(f"Error stopping ws_manager on exit: {e}")
|