Create api_client.py
This commit is contained in:
370
custom_components/pstryk/api_client.py
Normal file
370
custom_components/pstryk/api_client.py
Normal file
@ -0,0 +1,370 @@
|
||||
"""Shared API client for Pstryk Energy integration with caching and rate limiting."""
|
||||
import logging
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
from email.utils import parsedate_to_datetime
|
||||
|
||||
import aiohttp
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.update_coordinator import UpdateFailed
|
||||
from homeassistant.helpers.translation import async_get_translations
|
||||
|
||||
from .const import API_URL, API_TIMEOUT, DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PstrykAPIClient:
    """Shared API client with caching, rate limiting, and proper error handling."""

    def __init__(self, hass: HomeAssistant, api_key: str):
        """Initialize the API client.

        Args:
            hass: Home Assistant instance, used for the shared aiohttp
                session and for loading translations.
            api_key: Pstryk API key sent verbatim in the Authorization header.
        """
        self.hass = hass
        self.api_key = api_key

        # aiohttp session is created lazily via the `session` property.
        self._session: Optional[aiohttp.ClientSession] = None

        # Translated error-message templates, loaded on first use.
        self._translations: Dict[str, str] = {}
        self._translations_loaded = False

        # Per-endpoint rate-limit state:
        # {endpoint_key: {"retry_after": datetime, "backoff": float}}
        self._rate_limits: Dict[str, Dict[str, Any]] = {}
        self._rate_limit_lock = asyncio.Lock()

        # Throttle: never more than 3 API requests in flight at once.
        self._request_semaphore = asyncio.Semaphore(3)

        # Deduplication: URL -> task for requests currently in flight.
        self._in_flight: Dict[str, asyncio.Task] = {}
        self._in_flight_lock = asyncio.Lock()
@property
def session(self) -> aiohttp.ClientSession:
    """Return the shared aiohttp session, creating it on first access."""
    existing = self._session
    if existing is not None:
        return existing
    self._session = async_get_clientsession(self.hass)
    return self._session
async def _load_translations(self):
    """Fetch and cache debug translations; safe to call repeatedly.

    Best-effort: on any failure an empty mapping is cached so that
    ``_t`` falls back to its built-in English templates.
    """
    if self._translations_loaded:
        return
    try:
        self._translations = await async_get_translations(
            self.hass, self.hass.config.language, DOMAIN, ["debug"]
        )
    except Exception as ex:  # deliberate catch-all: translations are optional
        _LOGGER.warning("Failed to load translations for API client: %s", ex)
        self._translations = {}
    else:
        # Debug: log sample keys to understand the format
        if self._translations:
            sample_keys = list(self._translations)[:3]
            _LOGGER.debug("Loaded %d translation keys, samples: %s",
                          len(self._translations), sample_keys)
    self._translations_loaded = True
def _t(self, key: str, **kwargs) -> str:
    """Get translated string with fallback."""
    # async_get_translations may key entries in several formats; probe each.
    possible_keys = [
        f"component.{DOMAIN}.debug.{key}",  # Full format: component.pstryk.debug.key
        f"{DOMAIN}.debug.{key}",            # Domain format: pstryk.debug.key
        f"debug.{key}",                     # Short format: debug.key
        key,                                # Just the key
    ]
    template = next(
        (hit for k in possible_keys if (hit := self._translations.get(k))),
        None,
    )

    # No translation available: use a built-in English template.
    if not template:
        builtin = {
            "api_error_html": "API error {status} for {endpoint} (HTML error page received)",
            "rate_limited": "Endpoint '{endpoint}' is rate limited. Will retry after {seconds} seconds",
            "waiting_rate_limit": "Waiting {seconds} seconds for rate limit to clear",
        }
        template = builtin.get(key)
        if template is None:
            _LOGGER.debug("Translation key not found: %s (tried formats: %s)", key, possible_keys)
            template = key

    try:
        return template.format(**kwargs)
    except (KeyError, ValueError) as e:
        _LOGGER.warning("Failed to format translation template '%s': %s", template, e)
        return template
def _get_endpoint_key(self, url: str) -> str:
|
||||
"""Extract endpoint key from URL for rate limiting."""
|
||||
# Extract the main endpoint (e.g., "pricing", "prosumer-pricing", "energy-cost")
|
||||
if "pricing/?resolution" in url:
|
||||
return "pricing"
|
||||
elif "prosumer-pricing/?resolution" in url:
|
||||
return "prosumer-pricing"
|
||||
elif "meter-data/energy-cost" in url:
|
||||
return "energy-cost"
|
||||
elif "meter-data/energy-usage" in url:
|
||||
return "energy-usage"
|
||||
return "unknown"
|
||||
|
||||
async def _check_rate_limit(self, endpoint_key: str) -> Optional[float]:
|
||||
"""Check if we're rate limited and return wait time if needed."""
|
||||
async with self._rate_limit_lock:
|
||||
if endpoint_key in self._rate_limits:
|
||||
limit_info = self._rate_limits[endpoint_key]
|
||||
retry_after = limit_info.get("retry_after")
|
||||
|
||||
if retry_after and datetime.now() < retry_after:
|
||||
wait_time = (retry_after - datetime.now()).total_seconds()
|
||||
return wait_time
|
||||
elif retry_after and datetime.now() >= retry_after:
|
||||
# Rate limit expired, clear it
|
||||
del self._rate_limits[endpoint_key]
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_backoff(self, attempt: int, base_delay: float = 20.0) -> float:
|
||||
"""Calculate exponential backoff with jitter."""
|
||||
# Exponential backoff: base_delay * (2 ^ attempt)
|
||||
backoff = base_delay * (2 ** attempt)
|
||||
# Add jitter: ±20% randomization
|
||||
jitter = backoff * 0.2 * (2 * random.random() - 1)
|
||||
return max(1.0, backoff + jitter)
|
||||
|
||||
async def _handle_rate_limit(self, response: aiohttp.ClientResponse, endpoint_key: str):
    """Record a 429 response: parse Retry-After and store the retry deadline.

    Supports both Retry-After forms (delta-seconds and HTTP-date); falls
    back to 3600 seconds when the header is absent or unparsable.
    """
    # Ensure translations are loaded for the warning message below.
    await self._load_translations()

    retry_after_header = response.headers.get("Retry-After")
    wait_time = None

    if retry_after_header:
        try:
            # Delta-seconds form, e.g. "Retry-After: 120".
            wait_time = int(retry_after_header)
        except ValueError:
            # HTTP-date form, e.g. "Retry-After: Wed, 21 Oct 2025 07:28:00 GMT".
            try:
                retry_date = parsedate_to_datetime(retry_after_header)
                # BUGFIX: HTTP-date values are normally timezone-aware, and
                # subtracting a naive datetime.now() raised TypeError, which
                # silently forced the 1-hour fallback for every valid
                # HTTP-date header. Use an equally-aware "now" instead.
                if retry_date.tzinfo is not None:
                    now = datetime.now(tz=retry_date.tzinfo)
                else:
                    now = datetime.now()
                wait_time = (retry_date - now).total_seconds()
            except (TypeError, ValueError):
                wait_time = None

    # Fallback to 3600 seconds (1 hour) if not specified or unparsable.
    if wait_time is None:
        wait_time = 3600
    elif wait_time < 0:
        # Retry-After date already in the past: no wait needed.
        wait_time = 0

    retry_after_dt = datetime.now() + timedelta(seconds=wait_time)

    async with self._rate_limit_lock:
        self._rate_limits[endpoint_key] = {
            "retry_after": retry_after_dt,
            "backoff": wait_time,
        }

    _LOGGER.warning(
        self._t("rate_limited", endpoint=endpoint_key, seconds=int(wait_time))
    )
async def _make_request(
    self,
    url: str,
    max_retries: int = 3,
    base_delay: float = 20.0
) -> Dict[str, Any]:
    """Make API request with retries, rate limiting, and deduplication.

    Args:
        url: Full API URL to GET.
        max_retries: Number of attempts before giving up.
        base_delay: Base for the exponential backoff between attempts.

    Returns:
        Parsed JSON payload on HTTP 200.

    Raises:
        UpdateFailed: on auth failure (401/403), 404, persistent rate
            limiting, server errors, timeouts, or network errors.
    """
    # Load translations if not already loaded
    await self._load_translations()

    endpoint_key = self._get_endpoint_key(url)

    # Respect a previously recorded rate limit before touching the network.
    wait_time = await self._check_rate_limit(endpoint_key)
    if wait_time and wait_time > 0:
        # If wait time is reasonable, wait
        if wait_time <= 60:
            _LOGGER.info(
                self._t("waiting_rate_limit", seconds=int(wait_time))
            )
            await asyncio.sleep(wait_time)
        else:
            raise UpdateFailed(
                f"API rate limited for {endpoint_key}. Please try again in {int(wait_time/60)} minutes."
            )

    headers = {
        "Authorization": self.api_key,
        "Accept": "application/json"
    }

    last_exception = None

    for attempt in range(max_retries):
        try:
            # Use semaphore to limit concurrent requests
            async with self._request_semaphore:
                async with asyncio.timeout(API_TIMEOUT):
                    async with self.session.get(url, headers=headers) as response:
                        if response.status == 200:
                            return await response.json()

                        if response.status == 429:
                            # Record the rate limit, then retry with backoff.
                            await self._handle_rate_limit(response, endpoint_key)
                            if attempt < max_retries - 1:
                                backoff = self._calculate_backoff(attempt, base_delay)
                                _LOGGER.debug(
                                    "Rate limited, retrying in %.1f seconds (attempt %d/%d)",
                                    backoff, attempt + 1, max_retries
                                )
                                await asyncio.sleep(backoff)
                                continue
                            raise UpdateFailed(
                                f"API rate limit exceeded after {max_retries} attempts"
                            )

                        if response.status == 500:
                            error_text = await response.text()
                            stripped = error_text.strip()
                            if stripped.startswith('<!doctype html>') or stripped.startswith('<html'):
                                # Just log that it's HTML, not the whole page.
                                _LOGGER.error(
                                    self._t("api_error_html", status=500, endpoint=endpoint_key)
                                )
                            else:
                                # Log actual error text (truncated)
                                _LOGGER.error(
                                    "API returned 500 for %s: %s",
                                    endpoint_key, error_text[:100]
                                )
                            if attempt < max_retries - 1:
                                backoff = self._calculate_backoff(attempt, base_delay)
                                _LOGGER.debug(
                                    "Retrying after 500 error in %.1f seconds (attempt %d/%d)",
                                    backoff, attempt + 1, max_retries
                                )
                                await asyncio.sleep(backoff)
                                continue
                            raise UpdateFailed(
                                f"API server error (500) for {endpoint_key} after {max_retries} attempts"
                            )

                        if response.status in (401, 403):
                            raise UpdateFailed(
                                f"Authentication failed (status {response.status}). Please check your API key."
                            )

                        if response.status == 404:
                            raise UpdateFailed(
                                f"API endpoint not found (404): {endpoint_key}"
                            )

                        # Any other status: log (without dumping HTML), retry.
                        error_text = await response.text()
                        stripped = error_text.strip()
                        if stripped.startswith('<!doctype html>') or stripped.startswith('<html'):
                            _LOGGER.error(
                                self._t("api_error_html", status=response.status, endpoint=endpoint_key)
                            )
                        else:
                            _LOGGER.error(
                                "API error %d for %s: %s",
                                response.status, endpoint_key, error_text[:100]
                            )
                        if attempt < max_retries - 1:
                            backoff = self._calculate_backoff(attempt, base_delay)
                            await asyncio.sleep(backoff)
                            continue
                        raise UpdateFailed(
                            f"API error {response.status} for {endpoint_key}"
                        )

        except UpdateFailed:
            # BUGFIX: UpdateFailed raised in the status handling above was
            # previously swallowed by the generic `except Exception` handler,
            # logged as "Unexpected error", and re-wrapped in a vague message
            # — masking the real auth/404/rate-limit error. Propagate it.
            raise

        except asyncio.TimeoutError as err:
            last_exception = err
            _LOGGER.warning(
                "Timeout fetching from %s (attempt %d/%d)",
                endpoint_key, attempt + 1, max_retries
            )
            if attempt < max_retries - 1:
                backoff = self._calculate_backoff(attempt, base_delay)
                await asyncio.sleep(backoff)
                continue

        except aiohttp.ClientError as err:
            last_exception = err
            _LOGGER.warning(
                "Network error fetching from %s: %s (attempt %d/%d)",
                endpoint_key, err, attempt + 1, max_retries
            )
            if attempt < max_retries - 1:
                backoff = self._calculate_backoff(attempt, base_delay)
                await asyncio.sleep(backoff)
                continue

        except Exception as err:
            last_exception = err
            _LOGGER.exception(
                "Unexpected error fetching from %s: %s",
                endpoint_key, err
            )
            break

    # All retries exhausted; surface the last underlying error.
    if last_exception:
        raise UpdateFailed(
            f"Failed to fetch data from {endpoint_key} after {max_retries} attempts"
        ) from last_exception

    raise UpdateFailed(f"Failed to fetch data from {endpoint_key}")
async def fetch(
    self,
    url: str,
    max_retries: int = 3,
    base_delay: float = 20.0
) -> Dict[str, Any]:
    """Fetch data with deduplication of concurrent requests.

    At most one request per URL is in flight at a time; concurrent callers
    await the owner's task. If a piggybacked request fails, the caller
    issues its own fresh attempt.

    BUGFIX: the old code awaited the in-flight task while still holding
    ``self._in_flight_lock``, which serialized *all* fetch() calls — for
    any URL — behind the slowest request. The lock is now held only for
    dictionary bookkeeping, never across network I/O.
    """
    while True:
        is_owner = False
        async with self._in_flight_lock:
            task = self._in_flight.get(url)
            if task is None:
                # No request in flight for this URL: start one and own it.
                task = asyncio.create_task(
                    self._make_request(url, max_retries, base_delay)
                )
                self._in_flight[url] = task
                is_owner = True

        if is_owner:
            try:
                return await task
            finally:
                # Remove from in-flight requests
                async with self._in_flight_lock:
                    self._in_flight.pop(url, None)

        _LOGGER.debug("Deduplicating request for %s", url)
        try:
            return await task
        except Exception:
            # The request we piggybacked on failed; loop and issue our own.
            continue
|
||||
Reference in New Issue
Block a user