Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
feat(testing): remove SQLite dependency from inference recorder (#3254)
Recording files use a predictable naming format, making the SQLite index redundant. The binary SQLite file was causing frequent git conflicts. Simplify by calculating file paths directly from request hashes.

Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
1eb1ac0f41
commit
7ca8233889
3 changed files with 2 additions and 57 deletions
|
@ -9,7 +9,6 @@ from __future__ import annotations # for forward references
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
@ -125,28 +124,13 @@ class ResponseStorage:
|
||||||
def __init__(self, test_dir: Path):
|
def __init__(self, test_dir: Path):
|
||||||
self.test_dir = test_dir
|
self.test_dir = test_dir
|
||||||
self.responses_dir = self.test_dir / "responses"
|
self.responses_dir = self.test_dir / "responses"
|
||||||
self.db_path = self.test_dir / "index.sqlite"
|
|
||||||
|
|
||||||
self._ensure_directories()
|
self._ensure_directories()
|
||||||
self._init_database()
|
|
||||||
|
|
||||||
def _ensure_directories(self):
|
def _ensure_directories(self):
|
||||||
self.test_dir.mkdir(parents=True, exist_ok=True)
|
self.test_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.responses_dir.mkdir(exist_ok=True)
|
self.responses_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
def _init_database(self):
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS recordings (
|
|
||||||
request_hash TEXT PRIMARY KEY,
|
|
||||||
response_file TEXT,
|
|
||||||
endpoint TEXT,
|
|
||||||
model TEXT,
|
|
||||||
timestamp TEXT,
|
|
||||||
is_streaming BOOLEAN
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
||||||
"""Store a request/response pair."""
|
"""Store a request/response pair."""
|
||||||
# Generate unique response filename
|
# Generate unique response filename
|
||||||
|
@ -169,34 +153,9 @@ class ResponseStorage:
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
# Update SQLite index
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
"""
|
|
||||||
INSERT OR REPLACE INTO recordings
|
|
||||||
(request_hash, response_file, endpoint, model, timestamp, is_streaming)
|
|
||||||
VALUES (?, ?, ?, ?, datetime('now'), ?)
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
request_hash,
|
|
||||||
response_file,
|
|
||||||
request.get("endpoint", ""),
|
|
||||||
request.get("model", ""),
|
|
||||||
response.get("is_streaming", False),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
|
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
|
||||||
"""Find a recorded response by request hash."""
|
"""Find a recorded response by request hash."""
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
response_file = f"{request_hash[:12]}.json"
|
||||||
result = conn.execute(
|
|
||||||
"SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
|
|
||||||
response_file = result[0]
|
|
||||||
response_path = self.responses_dir / response_file
|
response_path = self.responses_dir / response_file
|
||||||
|
|
||||||
if not response_path.exists():
|
if not response_path.exists():
|
||||||
|
|
Binary file not shown.
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
@ -133,7 +132,6 @@ class TestInferenceRecording:
|
||||||
# Test directory creation
|
# Test directory creation
|
||||||
assert storage.test_dir.exists()
|
assert storage.test_dir.exists()
|
||||||
assert storage.responses_dir.exists()
|
assert storage.responses_dir.exists()
|
||||||
assert storage.db_path.exists()
|
|
||||||
|
|
||||||
# Test storing and retrieving a recording
|
# Test storing and retrieving a recording
|
||||||
request_hash = "test_hash_123"
|
request_hash = "test_hash_123"
|
||||||
|
@ -147,15 +145,6 @@ class TestInferenceRecording:
|
||||||
|
|
||||||
storage.store_recording(request_hash, request_data, response_data)
|
storage.store_recording(request_hash, request_data, response_data)
|
||||||
|
|
||||||
# Verify SQLite record
|
|
||||||
with sqlite3.connect(storage.db_path) as conn:
|
|
||||||
result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone()
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result[0] == request_hash # request_hash
|
|
||||||
assert result[2] == "/v1/chat/completions" # endpoint
|
|
||||||
assert result[3] == "llama3.2:3b" # model
|
|
||||||
|
|
||||||
# Verify file storage and retrieval
|
# Verify file storage and retrieval
|
||||||
retrieved = storage.find_recording(request_hash)
|
retrieved = storage.find_recording(request_hash)
|
||||||
assert retrieved is not None
|
assert retrieved is not None
|
||||||
|
@ -185,10 +174,7 @@ class TestInferenceRecording:
|
||||||
|
|
||||||
# Verify recording was stored
|
# Verify recording was stored
|
||||||
storage = ResponseStorage(temp_storage_dir)
|
storage = ResponseStorage(temp_storage_dir)
|
||||||
with sqlite3.connect(storage.db_path) as conn:
|
assert storage.responses_dir.exists()
|
||||||
recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0]
|
|
||||||
|
|
||||||
assert recordings == 1
|
|
||||||
|
|
||||||
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
|
||||||
"""Test that replay mode returns stored responses without making real calls."""
|
"""Test that replay mode returns stored responses without making real calls."""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue