mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
minimize change
This commit is contained in:
parent
ed78090b8e
commit
06d02bf3de
23 changed files with 161 additions and 197 deletions
|
|
@ -108,8 +108,3 @@ class QuotaMiddleware:
|
|||
)
|
||||
body = json.dumps({"error": {"message": message}}).encode()
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
|
||||
async def close(self):
|
||||
"""Close the KV store connection."""
|
||||
if self.kv is not None:
|
||||
await self.kv.close()
|
||||
|
|
|
|||
|
|
@ -314,8 +314,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
return paginate_records(session_dicts, start_index, limit)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.persistence_store.close()
|
||||
await self.responses_store.shutdown()
|
||||
pass
|
||||
|
||||
# OpenAI responses
|
||||
async def get_openai_response(
|
||||
|
|
|
|||
|
|
@ -129,7 +129,6 @@ class ReferenceBatchesImpl(Batches):
|
|||
# don't cancel tasks - just let them stop naturally on shutdown
|
||||
# cancelling would mark batches as "cancelled" in the database
|
||||
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||
await self.kvstore.close()
|
||||
|
||||
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||
async def create_batch(
|
||||
|
|
|
|||
|
|
@ -64,8 +64,7 @@ class MetaReferenceEvalImpl(
|
|||
benchmark = Benchmark.model_validate_json(benchmark)
|
||||
self.benchmarks[benchmark.identifier] = benchmark
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.kvstore.close()
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
# Store in kvstore
|
||||
|
|
|
|||
|
|
@ -62,8 +62,7 @@ class LocalfsFilesImpl(Files):
|
|||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.sql_store:
|
||||
await self.sql_store.close()
|
||||
pass
|
||||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
|
|
|
|||
|
|
@ -181,8 +181,7 @@ class S3FilesImpl(Files):
|
|||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._sql_store:
|
||||
await self._sql_store.close()
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> boto3.client:
|
||||
|
|
|
|||
|
|
@ -74,21 +74,19 @@ class InferenceStore:
|
|||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._worker_tasks:
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
if self.sql_store:
|
||||
await self.sql_store.close()
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait for all queued writes to complete. Useful for testing."""
|
||||
|
|
|
|||
|
|
@ -19,7 +19,3 @@ class KVStore(Protocol):
|
|||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close any persistent connections. Optional method for cleanup."""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -43,10 +43,6 @@ class InmemoryKVStoreImpl(KVStore):
|
|||
async def delete(self, key: str) -> None:
|
||||
del self._store[key]
|
||||
|
||||
async def close(self) -> None:
|
||||
"""No-op for in-memory store."""
|
||||
pass
|
||||
|
||||
|
||||
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
||||
if config.type == KVStoreType.redis.value:
|
||||
|
|
|
|||
|
|
@ -37,72 +37,138 @@ class SqliteKVStoreImpl(KVStore):
|
|||
if db_dir: # Only create if there's a directory component
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
|
||||
# Create persistent connection for all databases
|
||||
self._conn = await aiosqlite.connect(self.db_path)
|
||||
await self._conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
# Only use persistent connection for in-memory databases
|
||||
# File-based databases use connection-per-operation to avoid hangs
|
||||
if self._is_memory_db():
|
||||
self._conn = await aiosqlite.connect(self.db_path)
|
||||
await self._conn.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
await self._conn.commit()
|
||||
await self._conn.commit()
|
||||
else:
|
||||
# For file-based databases, just create the table
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def close(self):
|
||||
"""Close the persistent connection."""
|
||||
"""Close the persistent connection (only for in-memory databases)."""
|
||||
if self._conn:
|
||||
await self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
@property
|
||||
def conn(self) -> aiosqlite.Connection:
|
||||
"""Get the connection, raising an error if not initialized."""
|
||||
if self._conn is None:
|
||||
raise RuntimeError("Connection not initialized. Call initialize() first.")
|
||||
return self._conn
|
||||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
await self.conn.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, value, expiration),
|
||||
)
|
||||
await self.conn.commit()
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
await self._conn.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, value, expiration),
|
||||
)
|
||||
await self._conn.commit()
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, value, expiration),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
async with self.conn.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
if not isinstance(value, str):
|
||||
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||
return None
|
||||
return value
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
async with self._conn.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
if not isinstance(value, str):
|
||||
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||
return None
|
||||
return value
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
if not isinstance(value, str):
|
||||
logger.warning(f"Expected string value for key {key}, got {type(value)}, returning None")
|
||||
return None
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
await self.conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await self.conn.commit()
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
await self._conn.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await self._conn.commit()
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
async with self.conn.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
async with self._conn.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
"""Get all keys in the given range."""
|
||||
cursor = await self.conn.execute(
|
||||
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
if self._conn:
|
||||
# In-memory database with persistent connection
|
||||
cursor = await self._conn.execute(
|
||||
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
else:
|
||||
# File-based database with connection per operation
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
cursor = await db.execute(
|
||||
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
|
|
|||
|
|
@ -96,21 +96,19 @@ class ResponsesStore:
|
|||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._worker_tasks:
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
if self.sql_store:
|
||||
await self.sql_store.close()
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait for all queued writes to complete. Useful for testing."""
|
||||
|
|
|
|||
|
|
@ -126,9 +126,3 @@ class SqlStore(Protocol):
|
|||
:param nullable: Whether the column should be nullable (default: True)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close any persistent database connections.
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -197,10 +197,6 @@ class AuthorizedSqlStore:
|
|||
"""Delete rows with automatic access control filtering."""
|
||||
await self.sql_store.delete(table, where)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying SQL store connection."""
|
||||
await self.sql_store.close()
|
||||
|
||||
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
|
||||
"""Build SQL WHERE clause for access control filtering.
|
||||
|
||||
|
|
|
|||
|
|
@ -311,11 +311,3 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
# The table creation will handle adding the column
|
||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the database engine and all connections."""
|
||||
if hasattr(self, "async_session"):
|
||||
# Get the engine from the session maker
|
||||
engine = self.async_session.kw.get("bind")
|
||||
if engine:
|
||||
await engine.dispose()
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ async def files_provider(tmp_path):
|
|||
provider = LocalfsFilesImpl(config, default_policy())
|
||||
await provider.initialize()
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ async def sqlite_kvstore(tmp_path):
|
|||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
yield kvstore
|
||||
await kvstore.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
|
|
|||
|
|
@ -28,5 +28,3 @@ async def temp_prompt_store(tmp_path_factory):
|
|||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||
|
||||
yield store
|
||||
|
||||
await store.kvstore.close()
|
||||
|
|
|
|||
|
|
@ -46,9 +46,12 @@ async def test_initialize_index(vector_index):
|
|||
|
||||
|
||||
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
||||
vector_index.delete()
|
||||
vector_index.initialize()
|
||||
await vector_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
|
||||
assert resp.chunks[0].content == sample_chunks[0].content
|
||||
vector_index.delete()
|
||||
|
||||
|
||||
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
||||
|
|
|
|||
|
|
@ -52,21 +52,17 @@ def auth_app(tmp_path, request):
|
|||
db_path = tmp_path / f"quota_{request.node.name}.db"
|
||||
quota = build_quota_config(db_path)
|
||||
|
||||
quota_middleware = QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
app = InjectClientIDMiddleware(
|
||||
QuotaMiddleware(
|
||||
inner_app,
|
||||
kv_config=quota.kvstore,
|
||||
anonymous_max_requests=quota.anonymous_max_requests,
|
||||
authenticated_max_requests=quota.authenticated_max_requests,
|
||||
window_seconds=86400,
|
||||
),
|
||||
client_id=f"client_{request.node.name}",
|
||||
)
|
||||
app = InjectClientIDMiddleware(quota_middleware, client_id=f"client_{request.node.name}")
|
||||
|
||||
yield app
|
||||
|
||||
# Cleanup
|
||||
import asyncio
|
||||
|
||||
asyncio.run(quota_middleware.close())
|
||||
return app
|
||||
|
||||
|
||||
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||
|
|
@ -85,8 +81,6 @@ def test_authenticated_quota_blocks_after_limit(auth_app):
|
|||
|
||||
|
||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||
import asyncio
|
||||
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
|
|
@ -107,12 +101,8 @@ def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
|||
client = TestClient(app)
|
||||
assert client.get("/test").status_code == 200
|
||||
|
||||
asyncio.run(app.close())
|
||||
|
||||
|
||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||
import asyncio
|
||||
|
||||
inner_app = FastAPI()
|
||||
|
||||
@inner_app.get("/test")
|
||||
|
|
@ -135,5 +125,3 @@ def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
|||
resp = client.get("/test")
|
||||
assert resp.status_code == 429
|
||||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||
|
||||
asyncio.run(app.close())
|
||||
|
|
|
|||
|
|
@ -89,8 +89,6 @@ async def test_inference_store_pagination_basic():
|
|||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_inference_store_pagination_ascending():
|
||||
"""Test pagination with ascending order."""
|
||||
|
|
@ -128,8 +126,6 @@ async def test_inference_store_pagination_ascending():
|
|||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_inference_store_pagination_with_model_filter():
|
||||
"""Test pagination combined with model filtering."""
|
||||
|
|
@ -170,8 +166,6 @@ async def test_inference_store_pagination_with_model_filter():
|
|||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_inference_store_pagination_invalid_after():
|
||||
"""Test error handling for invalid 'after' parameter."""
|
||||
|
|
@ -184,8 +178,6 @@ async def test_inference_store_pagination_invalid_after():
|
|||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_inference_store_pagination_no_limit():
|
||||
"""Test pagination behavior when no limit is specified."""
|
||||
|
|
@ -216,5 +208,3 @@ async def test_inference_store_pagination_no_limit():
|
|||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
||||
|
||||
await store.sql_store.close()
|
||||
|
|
|
|||
|
|
@ -98,8 +98,6 @@ async def test_responses_store_pagination_basic():
|
|||
assert result3.data[0].id == "zebra-resp"
|
||||
assert result3.has_more is False
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_pagination_ascending():
|
||||
"""Test pagination with ascending order."""
|
||||
|
|
@ -138,8 +136,6 @@ async def test_responses_store_pagination_ascending():
|
|||
assert result2.data[0].id == "charlie-resp"
|
||||
assert result2.has_more is True
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_pagination_with_model_filter():
|
||||
"""Test pagination combined with model filtering."""
|
||||
|
|
@ -181,8 +177,6 @@ async def test_responses_store_pagination_with_model_filter():
|
|||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_pagination_invalid_after():
|
||||
"""Test error handling for invalid 'after' parameter."""
|
||||
|
|
@ -195,8 +189,6 @@ async def test_responses_store_pagination_invalid_after():
|
|||
with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"):
|
||||
await store.list_responses(after="non-existent", limit=2)
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_pagination_no_limit():
|
||||
"""Test pagination behavior when no limit is specified."""
|
||||
|
|
@ -229,8 +221,6 @@ async def test_responses_store_pagination_no_limit():
|
|||
assert result.data[1].id == "omega-resp"
|
||||
assert result.has_more is False
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_get_response_object():
|
||||
"""Test retrieving a single response object."""
|
||||
|
|
@ -259,8 +249,6 @@ async def test_responses_store_get_response_object():
|
|||
with pytest.raises(ValueError, match="Response with id non-existent not found"):
|
||||
await store.get_response_object("non-existent")
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_input_items_pagination():
|
||||
"""Test pagination functionality for input items."""
|
||||
|
|
@ -342,8 +330,6 @@ async def test_responses_store_input_items_pagination():
|
|||
with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"):
|
||||
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
|
||||
|
||||
await store.sql_store.close()
|
||||
|
||||
|
||||
async def test_responses_store_input_items_before_pagination():
|
||||
"""Test before pagination functionality for input items."""
|
||||
|
|
@ -404,5 +390,3 @@ async def test_responses_store_input_items_before_pagination():
|
|||
ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'"
|
||||
):
|
||||
await store.list_response_input_items("test-resp-before", before="non-existent")
|
||||
|
||||
await store.sql_store.close()
|
||||
|
|
|
|||
|
|
@ -64,9 +64,6 @@ async def test_sqlite_sqlstore():
|
|||
assert result.data == [{"id": 12, "name": "test12"}]
|
||||
assert result.has_more is False
|
||||
|
||||
# cleanup
|
||||
await sqlstore.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_basic():
|
||||
"""Test basic pagination functionality at the SQL store level."""
|
||||
|
|
@ -131,8 +128,6 @@ async def test_sqlstore_pagination_basic():
|
|||
assert result3.data[0]["id"] == "zebra"
|
||||
assert result3.has_more is False
|
||||
|
||||
await store.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_with_filter():
|
||||
"""Test pagination with WHERE conditions."""
|
||||
|
|
@ -185,8 +180,6 @@ async def test_sqlstore_pagination_with_filter():
|
|||
assert result2.data[0]["id"] == "xyz"
|
||||
assert result2.has_more is False
|
||||
|
||||
await store.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_ascending_order():
|
||||
"""Test pagination with ascending order."""
|
||||
|
|
@ -235,8 +228,6 @@ async def test_sqlstore_pagination_ascending_order():
|
|||
assert result2.data[0]["id"] == "alpha"
|
||||
assert result2.has_more is True
|
||||
|
||||
await store.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_multi_column_ordering_error():
|
||||
"""Test that multi-column ordering raises an error when using cursor pagination."""
|
||||
|
|
@ -274,8 +265,6 @@ async def test_sqlstore_pagination_multi_column_ordering_error():
|
|||
assert len(result.data) == 1
|
||||
assert result.data[0]["id"] == "task1"
|
||||
|
||||
await store.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_cursor_requires_order_by():
|
||||
"""Test that cursor pagination requires order_by parameter."""
|
||||
|
|
@ -293,8 +282,6 @@ async def test_sqlstore_pagination_cursor_requires_order_by():
|
|||
cursor=("id", "task1"),
|
||||
)
|
||||
|
||||
await store.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_error_handling():
|
||||
"""Test error handling for invalid columns and cursor IDs."""
|
||||
|
|
@ -427,8 +414,6 @@ async def test_where_operator_edge_cases():
|
|||
with pytest.raises(ValueError, match="Unsupported operator"):
|
||||
await store.fetch_all("events", where={"ts": {"!=": base}})
|
||||
|
||||
await store.close()
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_custom_key_column():
|
||||
"""Test pagination with custom primary key column (not 'id')."""
|
||||
|
|
@ -478,5 +463,3 @@ async def test_sqlstore_pagination_custom_key_column():
|
|||
assert len(result2.data) == 1
|
||||
assert result2.data[0]["uuid"] == "uuid-alpha"
|
||||
assert result2.has_more is False
|
||||
|
||||
await store.close()
|
||||
|
|
|
|||
|
|
@ -77,8 +77,6 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
|||
assert row is not None
|
||||
assert row["title"] == "User Document"
|
||||
|
||||
await base_sqlstore.close()
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||
|
|
@ -165,8 +163,6 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
|||
f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}"
|
||||
)
|
||||
|
||||
await base_sqlstore.close()
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
||||
|
|
@ -215,5 +211,3 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
|
|||
# Third item should have null attributes (no authenticated user)
|
||||
assert result.data[2]["id"] == "item3"
|
||||
assert result.data[2]["access_attributes"] is None
|
||||
|
||||
await base_sqlstore.close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue