One more attempt with Claude's help to close connections

This commit is contained in:
Raghotham Murthy 2025-10-07 12:28:31 -07:00
parent fa100c77fd
commit a424815804
16 changed files with 132 additions and 38 deletions

View file

@ -315,6 +315,7 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None:
await self.persistence_store.close()
await self.responses_store.shutdown()
# OpenAI responses
async def get_openai_response(

View file

@ -62,7 +62,8 @@ class LocalfsFilesImpl(Files):
)
async def shutdown(self) -> None:
pass
if self.sql_store:
await self.sql_store.close()
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""

View file

@ -181,7 +181,8 @@ class S3FilesImpl(Files):
)
async def shutdown(self) -> None:
pass
if self._sql_store:
await self._sql_store.close()
@property
def client(self) -> boto3.client:

View file

@ -74,19 +74,21 @@ class InferenceStore:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
if self._worker_tasks:
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
if self.sql_store:
await self.sql_store.close()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""

View file

@ -96,19 +96,21 @@ class ResponsesStore:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
if self._worker_tasks:
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
if self.sql_store:
await self.sql_store.close()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""

View file

@ -126,3 +126,9 @@ class SqlStore(Protocol):
:param nullable: Whether the column should be nullable (default: True)
"""
pass
async def close(self) -> None:
"""
Close any persistent database connections.
"""
pass

View file

@ -197,6 +197,10 @@ class AuthorizedSqlStore:
"""Delete rows with automatic access control filtering."""
await self.sql_store.delete(table, where)
async def close(self) -> None:
"""Close the underlying SQL store connection."""
await self.sql_store.close()
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
"""Build SQL WHERE clause for access control filtering.

View file

@ -311,3 +311,11 @@ class SqlAlchemySqlStoreImpl(SqlStore):
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass
async def close(self) -> None:
"""Close the database engine and all connections."""
if hasattr(self, "async_session"):
# Get the engine from the session maker
engine = self.async_session.kw.get("bind")
if engine:
await engine.dispose()