You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
191 lines
6.8 KiB
191 lines
6.8 KiB
"""base_agent.py — Abstract base class for all AI agents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import traceback
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Optional
|
|
|
|
import asyncpg
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseAgent(ABC):
|
|
"""
|
|
All agents inherit from this class.
|
|
|
|
Responsibilities:
|
|
- Poll agent_jobs table for work
|
|
- Claim jobs atomically
|
|
- Execute with exponential-backoff retries
|
|
- Log results / errors to agent_logs
|
|
"""
|
|
|
|
agent_type: str # Must be set by subclass
|
|
|
|
def __init__(self, pool: asyncpg.Pool, settings: Any) -> None:
|
|
self.pool = pool
|
|
self.settings = settings
|
|
self._log = logging.getLogger(f'agent.{self.agent_type}')
|
|
|
|
# ------------------------------------------------------------------
|
|
# Public interface
|
|
# ------------------------------------------------------------------
|
|
|
|
async def run_forever(self, poll_interval: int = 10) -> None:
|
|
"""Poll for jobs indefinitely."""
|
|
self._log.info('Agent started (poll_interval=%ds)', poll_interval)
|
|
while True:
|
|
try:
|
|
job = await self._claim_job()
|
|
if job:
|
|
await self._execute(job)
|
|
else:
|
|
await asyncio.sleep(poll_interval)
|
|
except asyncio.CancelledError:
|
|
self._log.info('Agent shutting down')
|
|
return
|
|
except Exception as exc:
|
|
self._log.error('Unexpected error in agent loop: %s', exc, exc_info=True)
|
|
await asyncio.sleep(poll_interval)
|
|
|
|
async def enqueue(self, payload: dict, priority: int = 5, delay_seconds: int = 0) -> str:
|
|
"""Create a new job for this agent."""
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
job_id = str(uuid.uuid4())
|
|
scheduled = datetime.now(timezone.utc)
|
|
if delay_seconds:
|
|
scheduled += timedelta(seconds=delay_seconds)
|
|
|
|
async with self.pool.acquire() as conn:
|
|
await conn.execute(
|
|
"""
|
|
INSERT INTO agent_jobs (id, agent_type, priority, payload, scheduled_for)
|
|
VALUES ($1::uuid, $2, $3, $4::jsonb, $5)
|
|
""",
|
|
job_id, self.agent_type, priority, json.dumps(payload), scheduled,
|
|
)
|
|
return job_id
|
|
|
|
# ------------------------------------------------------------------
|
|
# Abstract
|
|
# ------------------------------------------------------------------
|
|
|
|
    @abstractmethod
    async def process(self, job_id: str, payload: dict) -> dict:
        """Process a single job. Return result dict.

        Implemented by each concrete agent.

        Args:
            job_id: UUID string of the claimed agent_jobs row.
            payload: deserialized job payload from that row.

        Returns:
            A JSON-serializable dict, stored in agent_jobs.result on success.
            Raising any exception triggers the retry/backoff path in _execute.
        """
        ...
|
|
|
|
# ------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
async def _claim_job(self) -> Optional[asyncpg.Record]:
|
|
"""Atomically claim the next pending job for this agent type."""
|
|
async with self.pool.acquire() as conn:
|
|
row = await conn.fetchrow(
|
|
"""
|
|
UPDATE agent_jobs
|
|
SET status = 'running', started_at = now()
|
|
WHERE id = (
|
|
SELECT id FROM agent_jobs
|
|
WHERE agent_type = $1
|
|
AND status = 'pending'
|
|
AND scheduled_for <= now()
|
|
AND retry_count < max_retries
|
|
ORDER BY priority ASC, scheduled_for ASC
|
|
LIMIT 1
|
|
FOR UPDATE SKIP LOCKED
|
|
)
|
|
RETURNING *
|
|
""",
|
|
self.agent_type,
|
|
)
|
|
return row
|
|
|
|
async def _execute(self, job: asyncpg.Record) -> None:
|
|
job_id = str(job['id'])
|
|
payload = dict(job['payload'] or {})
|
|
self._log.info('Processing job %s', job_id)
|
|
start = time.monotonic()
|
|
|
|
try:
|
|
result = await self.process(job_id, payload)
|
|
elapsed = time.monotonic() - start
|
|
async with self.pool.acquire() as conn:
|
|
await conn.execute(
|
|
"""
|
|
UPDATE agent_jobs
|
|
SET status = 'done', result = $2::jsonb, completed_at = now()
|
|
WHERE id = $1::uuid
|
|
""",
|
|
job_id, json.dumps(result or {}),
|
|
)
|
|
await self._log_event(job_id, 'info', f'Job done in {elapsed:.2f}s', result or {})
|
|
|
|
except Exception as exc:
|
|
elapsed = time.monotonic() - start
|
|
err_msg = str(exc)
|
|
self._log.error('Job %s failed: %s', job_id, err_msg, exc_info=True)
|
|
|
|
async with self.pool.acquire() as conn:
|
|
row = await conn.fetchrow(
|
|
'SELECT retry_count, max_retries FROM agent_jobs WHERE id = $1::uuid', job_id
|
|
)
|
|
retries = (row['retry_count'] or 0) + 1
|
|
max_retries = row['max_retries'] or 3
|
|
|
|
if retries < max_retries:
|
|
# Re-queue with exponential backoff
|
|
backoff = 2 ** retries
|
|
await conn.execute(
|
|
"""
|
|
UPDATE agent_jobs
|
|
SET status = 'pending',
|
|
retry_count = $2,
|
|
error = $3,
|
|
scheduled_for = now() + ($4 || ' seconds')::interval
|
|
WHERE id = $1::uuid
|
|
""",
|
|
job_id, retries, err_msg, str(backoff),
|
|
)
|
|
await self._log_event(job_id, 'warning',
|
|
f'Retry {retries}/{max_retries} in {backoff}s', {})
|
|
else:
|
|
await conn.execute(
|
|
"""
|
|
UPDATE agent_jobs
|
|
SET status = 'failed', error = $2, completed_at = now()
|
|
WHERE id = $1::uuid
|
|
""",
|
|
job_id, err_msg,
|
|
)
|
|
await self._log_event(job_id, 'error', f'Job permanently failed: {err_msg}', {})
|
|
|
|
async def _log_event(
|
|
self,
|
|
job_id: Optional[str],
|
|
level: str,
|
|
message: str,
|
|
metadata: dict,
|
|
) -> None:
|
|
try:
|
|
async with self.pool.acquire() as conn:
|
|
await conn.execute(
|
|
"""
|
|
INSERT INTO agent_logs (job_id, agent_type, level, message, metadata)
|
|
VALUES ($1::uuid, $2, $3, $4, $5::jsonb)
|
|
""",
|
|
job_id, self.agent_type, level, message, json.dumps(metadata),
|
|
)
|
|
except Exception as log_exc:
|
|
self._log.warning('Failed to write agent log: %s', log_exc)
|