From d9e03384ffd4403d3ec566c876fe742796d758b3 Mon Sep 17 00:00:00 2001
From: balgerion <133121849+balgerion@users.noreply.github.com>
Date: Tue, 30 Sep 2025 16:13:31 +0200
Subject: [PATCH] Create api_client.py

---
 custom_components/pstryk/api_client.py | 370 +++++++++++++++++++++++++
 1 file changed, 370 insertions(+)
 create mode 100644 custom_components/pstryk/api_client.py

diff --git a/custom_components/pstryk/api_client.py b/custom_components/pstryk/api_client.py
new file mode 100644
index 0000000..4d9e68c
--- /dev/null
+++ b/custom_components/pstryk/api_client.py
@@ -0,0 +1,370 @@
+"""Shared API client for Pstryk Energy integration with caching and rate limiting."""
+import logging
+import asyncio
+import random
+from datetime import datetime, timedelta
+from typing import Any, Dict, Optional
+from email.utils import parsedate_to_datetime
+
+import aiohttp
+from homeassistant.core import HomeAssistant
+from homeassistant.helpers.aiohttp_client import async_get_clientsession
+from homeassistant.helpers.update_coordinator import UpdateFailed
+from homeassistant.helpers.translation import async_get_translations
+
+from .const import API_URL, API_TIMEOUT, DOMAIN
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PstrykAPIClient:
+    """Shared API client with caching, rate limiting, and proper error handling."""
+
+    def __init__(self, hass: HomeAssistant, api_key: str):
+        """Initialize the API client."""
+        self.hass = hass
+        self.api_key = api_key
+        self._session: Optional[aiohttp.ClientSession] = None
+        self._translations: Dict[str, str] = {}
+        self._translations_loaded = False
+
+        # Rate limiting: {endpoint_key: {"retry_after": datetime, "backoff": float}}
+        self._rate_limits: Dict[str, Dict[str, Any]] = {}
+        self._rate_limit_lock = asyncio.Lock()
+
+        # Request throttling - limit concurrent requests
+        self._request_semaphore = asyncio.Semaphore(3)  # Max 3 concurrent requests
+
+        # Deduplication - track in-flight requests
+        self._in_flight: Dict[str, asyncio.Task] = {}
+        self._in_flight_lock = asyncio.Lock()
+
+    @property
+    def session(self) -> aiohttp.ClientSession:
+        """Get or create aiohttp session."""
+        if self._session is None:
+            self._session = async_get_clientsession(self.hass)
+        return self._session
+
+    async def _load_translations(self):
+        """Load translations for error messages."""
+        if not self._translations_loaded:
+            try:
+                self._translations = await async_get_translations(
+                    self.hass, self.hass.config.language, DOMAIN, ["debug"]
+                )
+                self._translations_loaded = True
+                # Debug: log sample keys to understand the format
+                if self._translations:
+                    sample_keys = list(self._translations.keys())[:3]
+                    _LOGGER.debug("Loaded %d translation keys, samples: %s",
+                                  len(self._translations), sample_keys)
+            except Exception as ex:
+                _LOGGER.warning("Failed to load translations for API client: %s", ex)
+                self._translations = {}
+                self._translations_loaded = True
+
+    def _t(self, key: str, **kwargs) -> str:
+        """Get translated string with fallback."""
+        # Try different key formats as async_get_translations may return different formats
+        possible_keys = [
+            f"component.{DOMAIN}.debug.{key}",  # Full format: component.pstryk.debug.key
+            f"{DOMAIN}.debug.{key}",  # Domain format: pstryk.debug.key
+            f"debug.{key}",  # Short format: debug.key
+            key  # Just the key
+        ]
+
+        template = None
+        for possible_key in possible_keys:
+            template = self._translations.get(possible_key)
+            if template:
+                break
+
+        # If translation not found, create a fallback message
+        if not template:
+            # Fallback patterns for common error types
+            if key == "api_error_html":
+                template = "API error {status} for {endpoint} (HTML error page received)"
+            elif key == "rate_limited":
+                template = "Endpoint '{endpoint}' is rate limited. Will retry after {seconds} seconds"
+            elif key == "waiting_rate_limit":
+                template = "Waiting {seconds} seconds for rate limit to clear"
+            else:
+                _LOGGER.debug("Translation key not found: %s (tried formats: %s)", key, possible_keys)
+                template = key
+
+        try:
+            return template.format(**kwargs)
+        except (KeyError, ValueError) as e:
+            _LOGGER.warning("Failed to format translation template '%s': %s", template, e)
+            return template
+
+    def _get_endpoint_key(self, url: str) -> str:
+        """Extract endpoint key from URL for rate limiting."""
+        # Extract the main endpoint (e.g., "pricing", "prosumer-pricing", "energy-cost").
+        # Check "prosumer-pricing" before "pricing": the plain "pricing" pattern is a
+        # substring of the prosumer one and would otherwise match both URLs.
+        if "prosumer-pricing/?resolution" in url:
+            return "prosumer-pricing"
+        elif "pricing/?resolution" in url:
+            return "pricing"
+        elif "meter-data/energy-cost" in url:
+            return "energy-cost"
+        elif "meter-data/energy-usage" in url:
+            return "energy-usage"
+        return "unknown"
+
+    async def _check_rate_limit(self, endpoint_key: str) -> Optional[float]:
+        """Check if we're rate limited and return wait time if needed."""
+        async with self._rate_limit_lock:
+            if endpoint_key in self._rate_limits:
+                limit_info = self._rate_limits[endpoint_key]
+                retry_after = limit_info.get("retry_after")
+
+                if retry_after and datetime.now() < retry_after:
+                    wait_time = (retry_after - datetime.now()).total_seconds()
+                    return wait_time
+                elif retry_after and datetime.now() >= retry_after:
+                    # Rate limit expired, clear it
+                    del self._rate_limits[endpoint_key]
+
+        return None
+
+    def _calculate_backoff(self, attempt: int, base_delay: float = 20.0) -> float:
+        """Calculate exponential backoff with jitter."""
+        # Exponential backoff: base_delay * (2 ^ attempt)
+        backoff = base_delay * (2 ** attempt)
+        # Add jitter: ±20% randomization
+        jitter = backoff * 0.2 * (2 * random.random() - 1)
+        return max(1.0, backoff + jitter)
+
+    async def _handle_rate_limit(self, response: aiohttp.ClientResponse, endpoint_key: str):
+        """Handle 429 rate limit response."""
+        # Ensure translations are loaded
+        await self._load_translations()
+
+        retry_after_header = response.headers.get("Retry-After")
+        wait_time = None
+
+        if retry_after_header:
+            try:
+                # Try parsing as seconds
+                wait_time = int(retry_after_header)
+            except ValueError:
+                # Try parsing as HTTP date; use the header's timezone so we don't
+                # subtract a naive "now" from an aware datetime
+                try:
+                    retry_date = parsedate_to_datetime(retry_after_header)
+                    wait_time = (retry_date - datetime.now(retry_date.tzinfo)).total_seconds()
+                except Exception:
+                    pass
+
+        # Fallback to 3600 seconds (1 hour) if not specified
+        if wait_time is None:
+            wait_time = 3600
+
+        retry_after_dt = datetime.now() + timedelta(seconds=wait_time)
+
+        async with self._rate_limit_lock:
+            self._rate_limits[endpoint_key] = {
+                "retry_after": retry_after_dt,
+                "backoff": wait_time
+            }
+
+        _LOGGER.warning(
+            self._t("rate_limited", endpoint=endpoint_key, seconds=int(wait_time))
+        )
+
+    async def _make_request(
+        self,
+        url: str,
+        max_retries: int = 3,
+        base_delay: float = 20.0
+    ) -> Dict[str, Any]:
+        """Make API request with retries, rate limiting, and deduplication."""
+        # Load translations if not already loaded
+        await self._load_translations()
+
+        endpoint_key = self._get_endpoint_key(url)
+
+        # Check if we're rate limited
+        wait_time = await self._check_rate_limit(endpoint_key)
+        if wait_time and wait_time > 0:
+            # If the wait time is reasonable, wait it out; otherwise fail fast
+            if wait_time <= 60:
+                _LOGGER.info(
self._t("waiting_rate_limit", seconds=int(wait_time)) + ) + await asyncio.sleep(wait_time) + else: + raise UpdateFailed( + f"API rate limited for {endpoint_key}. Please try again in {int(wait_time/60)} minutes." + ) + + headers = { + "Authorization": self.api_key, + "Accept": "application/json" + } + + last_exception = None + + for attempt in range(max_retries): + try: + # Use semaphore to limit concurrent requests + async with self._request_semaphore: + async with asyncio.timeout(API_TIMEOUT): + async with self.session.get(url, headers=headers) as response: + # Handle different status codes + if response.status == 200: + data = await response.json() + return data + + elif response.status == 429: + # Handle rate limiting + await self._handle_rate_limit(response, endpoint_key) + + # Retry with exponential backoff + if attempt < max_retries - 1: + backoff = self._calculate_backoff(attempt, base_delay) + _LOGGER.debug( + "Rate limited, retrying in %.1f seconds (attempt %d/%d)", + backoff, attempt + 1, max_retries + ) + await asyncio.sleep(backoff) + continue + else: + raise UpdateFailed( + f"API rate limit exceeded after {max_retries} attempts" + ) + + elif response.status == 500: + error_text = await response.text() + # Extract plain text from HTML if present + if error_text.strip().startswith('') or error_text.strip().startswith('') or error_text.strip().startswith(' Dict[str, Any]: + """Fetch data with deduplication of concurrent requests.""" + # Check if there's already an in-flight request for this URL + async with self._in_flight_lock: + if url in self._in_flight: + _LOGGER.debug("Deduplicating request for %s", url) + # Wait for the existing request to complete + try: + return await self._in_flight[url] + except Exception: + # If the in-flight request failed, create a new one + pass + + # Create new request task + task = asyncio.create_task( + self._make_request(url, max_retries, base_delay) + ) + self._in_flight[url] = task + + try: + result = await task + return result + finally: + # Remove from in-flight requests + async with self._in_flight_lock: + self._in_flight.pop(url, None)