Create api_client.py

This commit is contained in:
balgerion
2025-09-30 16:13:31 +02:00
committed by GitHub
parent 85bad099e8
commit d9e03384ff

View File

@ -0,0 +1,370 @@
"""Shared API client for Pstryk Energy integration with caching and rate limiting."""
import logging
import asyncio
import random
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from email.utils import parsedate_to_datetime
import aiohttp
from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.update_coordinator import UpdateFailed
from homeassistant.helpers.translation import async_get_translations
from .const import API_URL, API_TIMEOUT, DOMAIN
_LOGGER = logging.getLogger(__name__)
class PstrykAPIClient:
"""Shared API client with caching, rate limiting, and proper error handling."""
def __init__(self, hass: HomeAssistant, api_key: str):
"""Initialize the API client."""
self.hass = hass
self.api_key = api_key
self._session: Optional[aiohttp.ClientSession] = None
self._translations: Dict[str, str] = {}
self._translations_loaded = False
# Rate limiting: {endpoint_key: {"retry_after": datetime, "backoff": float}}
self._rate_limits: Dict[str, Dict[str, Any]] = {}
self._rate_limit_lock = asyncio.Lock()
# Request throttling - limit concurrent requests
self._request_semaphore = asyncio.Semaphore(3) # Max 3 concurrent requests
# Deduplication - track in-flight requests
self._in_flight: Dict[str, asyncio.Task] = {}
self._in_flight_lock = asyncio.Lock()
@property
def session(self) -> aiohttp.ClientSession:
"""Get or create aiohttp session."""
if self._session is None:
self._session = async_get_clientsession(self.hass)
return self._session
async def _load_translations(self):
"""Load translations for error messages."""
if not self._translations_loaded:
try:
self._translations = await async_get_translations(
self.hass, self.hass.config.language, DOMAIN, ["debug"]
)
self._translations_loaded = True
# Debug: log sample keys to understand the format
if self._translations:
sample_keys = list(self._translations.keys())[:3]
_LOGGER.debug("Loaded %d translation keys, samples: %s",
len(self._translations), sample_keys)
except Exception as ex:
_LOGGER.warning("Failed to load translations for API client: %s", ex)
self._translations = {}
self._translations_loaded = True
def _t(self, key: str, **kwargs) -> str:
"""Get translated string with fallback."""
# Try different key formats as async_get_translations may return different formats
possible_keys = [
f"component.{DOMAIN}.debug.{key}", # Full format: component.pstryk.debug.key
f"{DOMAIN}.debug.{key}", # Domain format: pstryk.debug.key
f"debug.{key}", # Short format: debug.key
key # Just the key
]
template = None
for possible_key in possible_keys:
template = self._translations.get(possible_key)
if template:
break
# If translation not found, create a fallback message
if not template:
# Fallback patterns for common error types
if key == "api_error_html":
template = "API error {status} for {endpoint} (HTML error page received)"
elif key == "rate_limited":
template = "Endpoint '{endpoint}' is rate limited. Will retry after {seconds} seconds"
elif key == "waiting_rate_limit":
template = "Waiting {seconds} seconds for rate limit to clear"
else:
_LOGGER.debug("Translation key not found: %s (tried formats: %s)", key, possible_keys)
template = key
try:
return template.format(**kwargs)
except (KeyError, ValueError) as e:
_LOGGER.warning("Failed to format translation template '%s': %s", template, e)
return template
def _get_endpoint_key(self, url: str) -> str:
"""Extract endpoint key from URL for rate limiting."""
# Extract the main endpoint (e.g., "pricing", "prosumer-pricing", "energy-cost")
if "pricing/?resolution" in url:
return "pricing"
elif "prosumer-pricing/?resolution" in url:
return "prosumer-pricing"
elif "meter-data/energy-cost" in url:
return "energy-cost"
elif "meter-data/energy-usage" in url:
return "energy-usage"
return "unknown"
async def _check_rate_limit(self, endpoint_key: str) -> Optional[float]:
"""Check if we're rate limited and return wait time if needed."""
async with self._rate_limit_lock:
if endpoint_key in self._rate_limits:
limit_info = self._rate_limits[endpoint_key]
retry_after = limit_info.get("retry_after")
if retry_after and datetime.now() < retry_after:
wait_time = (retry_after - datetime.now()).total_seconds()
return wait_time
elif retry_after and datetime.now() >= retry_after:
# Rate limit expired, clear it
del self._rate_limits[endpoint_key]
return None
def _calculate_backoff(self, attempt: int, base_delay: float = 20.0) -> float:
"""Calculate exponential backoff with jitter."""
# Exponential backoff: base_delay * (2 ^ attempt)
backoff = base_delay * (2 ** attempt)
# Add jitter: ±20% randomization
jitter = backoff * 0.2 * (2 * random.random() - 1)
return max(1.0, backoff + jitter)
async def _handle_rate_limit(self, response: aiohttp.ClientResponse, endpoint_key: str):
"""Handle 429 rate limit response."""
# Ensure translations are loaded
await self._load_translations()
retry_after_header = response.headers.get("Retry-After")
wait_time = None
if retry_after_header:
try:
# Try parsing as seconds
wait_time = int(retry_after_header)
except ValueError:
# Try parsing as HTTP date
try:
retry_date = parsedate_to_datetime(retry_after_header)
wait_time = (retry_date - datetime.now()).total_seconds()
except Exception:
pass
# Fallback to 3600 seconds (1 hour) if not specified
if wait_time is None:
wait_time = 3600
retry_after_dt = datetime.now() + timedelta(seconds=wait_time)
async with self._rate_limit_lock:
self._rate_limits[endpoint_key] = {
"retry_after": retry_after_dt,
"backoff": wait_time
}
_LOGGER.warning(
self._t("rate_limited", endpoint=endpoint_key, seconds=int(wait_time))
)
async def _make_request(
self,
url: str,
max_retries: int = 3,
base_delay: float = 20.0
) -> Dict[str, Any]:
"""Make API request with retries, rate limiting, and deduplication."""
# Load translations if not already loaded
await self._load_translations()
endpoint_key = self._get_endpoint_key(url)
# Check if we're rate limited
wait_time = await self._check_rate_limit(endpoint_key)
if wait_time and wait_time > 0:
# If wait time is reasonable, wait
if wait_time <= 60:
_LOGGER.info(
self._t("waiting_rate_limit", seconds=int(wait_time))
)
await asyncio.sleep(wait_time)
else:
raise UpdateFailed(
f"API rate limited for {endpoint_key}. Please try again in {int(wait_time/60)} minutes."
)
headers = {
"Authorization": self.api_key,
"Accept": "application/json"
}
last_exception = None
for attempt in range(max_retries):
try:
# Use semaphore to limit concurrent requests
async with self._request_semaphore:
async with asyncio.timeout(API_TIMEOUT):
async with self.session.get(url, headers=headers) as response:
# Handle different status codes
if response.status == 200:
data = await response.json()
return data
elif response.status == 429:
# Handle rate limiting
await self._handle_rate_limit(response, endpoint_key)
# Retry with exponential backoff
if attempt < max_retries - 1:
backoff = self._calculate_backoff(attempt, base_delay)
_LOGGER.debug(
"Rate limited, retrying in %.1f seconds (attempt %d/%d)",
backoff, attempt + 1, max_retries
)
await asyncio.sleep(backoff)
continue
else:
raise UpdateFailed(
f"API rate limit exceeded after {max_retries} attempts"
)
elif response.status == 500:
error_text = await response.text()
# Extract plain text from HTML if present
if error_text.strip().startswith('<!doctype html>') or error_text.strip().startswith('<html'):
# Just log that it's HTML, not the whole HTML
_LOGGER.error(
self._t("api_error_html", status=500, endpoint=endpoint_key)
)
else:
# Log actual error text (truncated)
_LOGGER.error(
"API returned 500 for %s: %s",
endpoint_key, error_text[:100]
)
# Retry with backoff
if attempt < max_retries - 1:
backoff = self._calculate_backoff(attempt, base_delay)
_LOGGER.debug(
"Retrying after 500 error in %.1f seconds (attempt %d/%d)",
backoff, attempt + 1, max_retries
)
await asyncio.sleep(backoff)
continue
else:
raise UpdateFailed(
f"API server error (500) for {endpoint_key} after {max_retries} attempts"
)
elif response.status in (401, 403):
raise UpdateFailed(
f"Authentication failed (status {response.status}). Please check your API key."
)
elif response.status == 404:
raise UpdateFailed(
f"API endpoint not found (404): {endpoint_key}"
)
else:
error_text = await response.text()
# Clean HTML from error messages
if error_text.strip().startswith('<!doctype html>') or error_text.strip().startswith('<html'):
_LOGGER.error(
self._t("api_error_html", status=response.status, endpoint=endpoint_key)
)
else:
_LOGGER.error(
"API error %d for %s: %s",
response.status, endpoint_key, error_text[:100]
)
# For other errors, retry with backoff
if attempt < max_retries - 1:
backoff = self._calculate_backoff(attempt, base_delay)
await asyncio.sleep(backoff)
continue
else:
raise UpdateFailed(
f"API error {response.status} for {endpoint_key}"
)
except asyncio.TimeoutError as err:
last_exception = err
_LOGGER.warning(
"Timeout fetching from %s (attempt %d/%d)",
endpoint_key, attempt + 1, max_retries
)
if attempt < max_retries - 1:
backoff = self._calculate_backoff(attempt, base_delay)
await asyncio.sleep(backoff)
continue
except aiohttp.ClientError as err:
last_exception = err
_LOGGER.warning(
"Network error fetching from %s: %s (attempt %d/%d)",
endpoint_key, err, attempt + 1, max_retries
)
if attempt < max_retries - 1:
backoff = self._calculate_backoff(attempt, base_delay)
await asyncio.sleep(backoff)
continue
except Exception as err:
last_exception = err
_LOGGER.exception(
"Unexpected error fetching from %s: %s",
endpoint_key, err
)
break
# All retries exhausted, raise the error
if last_exception:
raise UpdateFailed(
f"Failed to fetch data from {endpoint_key} after {max_retries} attempts"
) from last_exception
raise UpdateFailed(f"Failed to fetch data from {endpoint_key}")
async def fetch(
self,
url: str,
max_retries: int = 3,
base_delay: float = 20.0
) -> Dict[str, Any]:
"""Fetch data with deduplication of concurrent requests."""
# Check if there's already an in-flight request for this URL
async with self._in_flight_lock:
if url in self._in_flight:
_LOGGER.debug("Deduplicating request for %s", url)
# Wait for the existing request to complete
try:
return await self._in_flight[url]
except Exception:
# If the in-flight request failed, create a new one
pass
# Create new request task
task = asyncio.create_task(
self._make_request(url, max_retries, base_delay)
)
self._in_flight[url] = task
try:
result = await task
return result
finally:
# Remove from in-flight requests
async with self._in_flight_lock:
self._in_flight.pop(url, None)