minimize change

This commit is contained in:
Raghotham Murthy 2025-10-11 01:32:38 -07:00
parent ed78090b8e
commit 06d02bf3de
23 changed files with 161 additions and 197 deletions

View file

@ -314,8 +314,7 @@ class MetaReferenceAgentsImpl(Agents):
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
await self.persistence_store.close()
await self.responses_store.shutdown()
pass
# OpenAI responses
async def get_openai_response(

View file

@ -129,7 +129,6 @@ class ReferenceBatchesImpl(Batches):
# don't cancel tasks - just let them stop naturally on shutdown
# cancelling would mark batches as "cancelled" in the database
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
await self.kvstore.close()
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(

View file

@ -64,8 +64,7 @@ class MetaReferenceEvalImpl(
benchmark = Benchmark.model_validate_json(benchmark)
self.benchmarks[benchmark.identifier] = benchmark
async def shutdown(self) -> None:
await self.kvstore.close()
async def shutdown(self) -> None: ...
async def register_benchmark(self, task_def: Benchmark) -> None:
# Store in kvstore

View file

@ -62,8 +62,7 @@ class LocalfsFilesImpl(Files):
)
async def shutdown(self) -> None:
if self.sql_store:
await self.sql_store.close()
pass
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""

View file

@ -181,8 +181,7 @@ class S3FilesImpl(Files):
)
async def shutdown(self) -> None:
if self._sql_store:
await self._sql_store.close()
pass
@property
def client(self) -> boto3.client:

View file

@ -74,21 +74,19 @@ class InferenceStore:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if self._worker_tasks:
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
if self.sql_store:
await self.sql_store.close()
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""

View file

@ -19,7 +19,3 @@ class KVStore(Protocol):
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
async def close(self) -> None:
"""Close any persistent connections. Optional method for cleanup."""
...

View file

@ -43,10 +43,6 @@ class InmemoryKVStoreImpl(KVStore):
async def delete(self, key: str) -> None:
del self._store[key]
async def close(self) -> None:
"""No-op for in-memory store."""
pass
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:

View file

@ -37,72 +37,138 @@ class SqliteKVStoreImpl(KVStore):
if db_dir: # Only create if there's a directory component
os.makedirs(db_dir, exist_ok=True)
# Create persistent connection for all databases
self._conn = await aiosqlite.connect(self.db_path)
await self._conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
# Only use persistent connection for in-memory databases
# File-based databases use connection-per-operation to avoid hangs
if self._is_memory_db():
self._conn = await aiosqlite.connect(self.db_path)
await self._conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
"""
)
await self._conn.commit()
await self._conn.commit()
else:
# For file-based databases, just create the table
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
await db.commit()
async def close(self):
"""Close the persistent connection."""
"""Close the persistent connection (only for in-memory databases)."""
if self._conn:
await self._conn.close()
self._conn = None
@property
def conn(self) -> aiosqlite.Connection:
"""Get the connection, raising an error if not initialized."""
if self._conn is None:
raise RuntimeError("Connection not initialized. Call initialize() first.")
return self._conn
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
await self.conn.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, value, expiration),
)
await self.conn.commit()
if self._conn:
# In-memory database with persistent connection
await self._conn.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, value, expiration),
)
await self._conn.commit()
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, value, expiration),
)
await db.commit()
async def get(self, key: str) -> str | None:
async with self.conn.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
if not isinstance(value, str):
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
return None
return value
if self._conn:
# In-memory database with persistent connection
async with self._conn.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
if not isinstance(value, str):
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
return None
return value
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return None
value, expiration = row
if not isinstance(value, str):
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
return None
return value
async def delete(self, key: str) -> None:
await self.conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await self.conn.commit()
if self._conn:
# In-memory database with persistent connection
await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await self._conn.commit()
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
async with self.conn.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
_, value, _ = row
result.append(value)
return result
if self._conn:
# In-memory database with persistent connection
async with self._conn.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
_, value, _ = row
result.append(value)
return result
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
) as cursor:
result = []
async for row in cursor:
_, value, _ = row
result.append(value)
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
"""Get all keys in the given range."""
cursor = await self.conn.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]
if self._conn:
# In-memory database with persistent connection
cursor = await self._conn.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]
else:
# File-based database with connection per operation
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
(start_key, end_key),
)
rows = await cursor.fetchall()
return [row[0] for row in rows]

View file

@ -96,21 +96,19 @@ class ResponsesStore:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if self._worker_tasks:
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
if self.sql_store:
await self.sql_store.close()
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""

View file

@ -126,9 +126,3 @@ class SqlStore(Protocol):
:param nullable: Whether the column should be nullable (default: True)
"""
pass
async def close(self) -> None:
"""
Close any persistent database connections.
"""
pass

View file

@ -197,10 +197,6 @@ class AuthorizedSqlStore:
"""Delete rows with automatic access control filtering."""
await self.sql_store.delete(table, where)
async def close(self) -> None:
"""Close the underlying SQL store connection."""
await self.sql_store.close()
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
"""Build SQL WHERE clause for access control filtering.

View file

@ -311,11 +311,3 @@ class SqlAlchemySqlStoreImpl(SqlStore):
# The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass
async def close(self) -> None:
"""Close the database engine and all connections."""
if hasattr(self, "async_session"):
# Get the engine from the session maker
engine = self.async_session.kw.get("bind")
if engine:
await engine.dispose()