much simpler

commit 481a893eb7 (parent e59c13f2b8)
19 changed files with 6365 additions and 302 deletions
@@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from __future__ import annotations
from __future__ import annotations # for forward references

import hashlib
import json
import os
import sqlite3
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
@@ -28,78 +27,18 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
    from urllib.parse import urlparse

    parsed = urlparse(url)
    endpoint = parsed.path
    normalized = {"method": method.upper(), "endpoint": parsed.path, "body": body}

    # Create normalized request dict
    normalized: dict[str, Any] = {
        "method": method.upper(),
        "endpoint": endpoint,
    }

    # Normalize body parameters
    if body:
        # Handle model parameter
        if "model" in body:
            normalized["model"] = body["model"]

        # Handle messages (normalize whitespace)
        if "messages" in body:
            normalized_messages = []
            for msg in body["messages"]:
                normalized_msg = dict(msg)
                if "content" in normalized_msg and isinstance(normalized_msg["content"], str):
                    # Normalize whitespace
                    normalized_msg["content"] = " ".join(normalized_msg["content"].split())
                normalized_messages.append(normalized_msg)
            normalized["messages"] = normalized_messages

        # Handle other parameters (sort for consistency)
        other_params = {}
        for key, value in body.items():
            if key not in ["model", "messages"]:
                if isinstance(value, float):
                    # Round floats to 6 decimal places
                    other_params[key] = round(value, 6)
                else:
                    other_params[key] = value

        if other_params:
            # Sort dictionary keys for consistent hashing
            normalized["parameters"] = dict(sorted(other_params.items()))

    # Create hash
    # Create hash - sort_keys=True ensures deterministic ordering
    normalized_json = json.dumps(normalized, sort_keys=True)
    return hashlib.sha256(normalized_json.encode()).hexdigest()


def get_current_test_id() -> str:
    """Extract test ID from pytest context or fall back to environment/generated ID."""
    # Try to get from pytest context
    try:
        import _pytest.fixtures

        if hasattr(_pytest.fixtures, "_current_request") and _pytest.fixtures._current_request:
            request = _pytest.fixtures._current_request
            if hasattr(request, "node"):
                # Use the test node ID as our test identifier
                node_id: str = request.node.nodeid
                # Clean up the node ID to be filesystem-safe
                test_id = node_id.replace("/", "_").replace("::", "_").replace(".py", "")
                return test_id
    except AttributeError:
        pass

    # Fall back to environment-based or generated ID
    return os.environ.get("LLAMA_STACK_TEST_ID", f"test_{uuid.uuid4().hex[:8]}")


def get_inference_mode() -> str:
    """Get the inference recording mode from environment variables."""
    return os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()


def setup_inference_recording():
    """Convenience function to set up inference recording based on environment variables."""
    mode = get_inference_mode()

    if mode not in ["live", "record", "replay"]:
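As a standalone sketch (not part of this commit), the simplified normalize_request above reduces a request to its method, URL path, and body, then hashes the sort_keys=True JSON, so equivalent requests map to the same key no matter how the body dict was ordered and regardless of headers. Function and host names below are illustrative only:

import hashlib
import json
from urllib.parse import urlparse

def request_key(method: str, url: str, headers: dict, body: dict) -> str:
    # Same shape as the simplified normalize_request: headers are ignored,
    # only the upper-cased method, URL path, and body feed the hash.
    normalized = {"method": method.upper(), "endpoint": urlparse(url).path, "body": body}
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()

a = request_key("POST", "http://localhost:8321/v1/chat/completions", {}, {"model": "m", "stream": False})
b = request_key("post", "http://other-host/v1/chat/completions", {"Authorization": "x"}, {"stream": False, "model": "m"})
assert a == b  # key order, host, and headers do not change the hash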
@@ -113,14 +52,14 @@ def setup_inference_recording():

        return live_mode()

    test_id = get_current_test_id()
    storage_dir = os.environ.get("LLAMA_STACK_RECORDING_DIR", str(Path.home() / ".llama" / "recordings"))
    if "LLAMA_STACK_RECORDING_DIR" not in os.environ:
        raise ValueError("LLAMA_STACK_RECORDING_DIR must be set for recording or replaying")
    storage_dir = os.environ["LLAMA_STACK_RECORDING_DIR"]

    return inference_recording(mode=mode, test_id=test_id, storage_dir=storage_dir)
    return inference_recording(mode=mode, storage_dir=storage_dir)


def _serialize_response(response: Any) -> Any:
    """Serialize OpenAI response objects to JSON-compatible format."""
    if hasattr(response, "model_dump"):
        return response.model_dump()
    elif hasattr(response, "__dict__"):
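For orientation, a hedged usage sketch of the environment-driven entry point shown above, assuming (as the hunk indicates) that record/replay modes require LLAMA_STACK_RECORDING_DIR and that setup_inference_recording returns the inference_recording context manager in those modes. The paths are placeholders:

import os

# Hypothetical driver; assumes setup_inference_recording is imported from the module in this diff.
os.environ["LLAMA_STACK_INFERENCE_MODE"] = "replay"          # "live", "record", or "replay"
os.environ["LLAMA_STACK_RECORDING_DIR"] = "/tmp/recordings"  # required for record/replay per the check above

with setup_inference_recording():
    ...  # inference calls here are matched against (or written to) the recording store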
@@ -130,19 +69,14 @@ def _serialize_response(response: Any) -> Any:


def _deserialize_response(data: dict[str, Any]) -> dict[str, Any]:
    """Deserialize response data back to a dict format."""
    # For simplicity, just return the dict - this preserves all the data
    # The original response structure is sufficient for replaying
    return data


class ResponseStorage:
    """Handles SQLite index + JSON file storage/retrieval for inference recordings."""

    def __init__(self, base_dir: Path, test_id: str):
        self.base_dir = base_dir
        self.test_id = test_id
        self.test_dir = base_dir / test_id
    def __init__(self, test_dir: Path):
        self.test_dir = test_dir
        self.responses_dir = self.test_dir / "responses"
        self.db_path = self.test_dir / "index.sqlite"
@@ -234,37 +168,55 @@ class ResponseStorage:
        return cast(dict[str, Any], data)


async def _patched_create_method(original_method, self, **kwargs):
    """Patched version of OpenAI client create methods."""
async def _patched_inference_method(original_method, self, client_type, method_name=None, **kwargs):
    global _current_mode, _current_storage

    if _current_mode == "live" or _current_storage is None:
        # Normal operation
        return await original_method(self, **kwargs)

    # Get base URL from the client
    base_url = str(self._client.base_url)
    # Get base URL and endpoint based on client type
    if client_type == "openai":
        base_url = str(self._client.base_url)

    # Determine endpoint based on the method's module/class path
    method_str = str(original_method)
    if "chat.completions" in method_str:
        endpoint = "/v1/chat/completions"
    elif "embeddings" in method_str:
        endpoint = "/v1/embeddings"
    elif "completions" in method_str:
        endpoint = "/v1/completions"
    else:
        # Fallback - try to guess from the self object
        if hasattr(self, "_resource") and hasattr(self._resource, "_resource"):
            resource_name = getattr(self._resource._resource, "_resource", "unknown")
            if "chat" in str(resource_name):
                endpoint = "/v1/chat/completions"
            elif "embeddings" in str(resource_name):
                endpoint = "/v1/embeddings"
        # Determine endpoint based on the method's module/class path
        method_str = str(original_method)
        if "chat.completions" in method_str:
            endpoint = "/v1/chat/completions"
        elif "embeddings" in method_str:
            endpoint = "/v1/embeddings"
        elif "completions" in method_str:
            endpoint = "/v1/completions"
        else:
            # Fallback - try to guess from the self object
            if hasattr(self, "_resource") and hasattr(self._resource, "_resource"):
                resource_name = getattr(self._resource._resource, "_resource", "unknown")
                if "chat" in str(resource_name):
                    endpoint = "/v1/chat/completions"
                elif "embeddings" in str(resource_name):
                    endpoint = "/v1/embeddings"
                else:
                    endpoint = "/v1/completions"
            else:
                endpoint = "/v1/completions"

    elif client_type == "ollama":
        # Get base URL from the client (Ollama client uses host attribute)
        base_url = getattr(self, "host", "http://localhost:11434")
        if not base_url.startswith("http"):
            base_url = f"http://{base_url}"

        # Determine endpoint based on method name
        if method_name == "generate":
            endpoint = "/api/generate"
        elif method_name == "chat":
            endpoint = "/api/chat"
        elif method_name == "embed":
            endpoint = "/api/embeddings"
        else:
            endpoint = "/v1/completions"
            endpoint = f"/api/{method_name}"
    else:
        raise ValueError(f"Unknown client type: {client_type}")

    url = base_url.rstrip("/") + endpoint
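In summary, the unified helper above resolves a logical endpoint per client type before hashing the request. A compact sketch of the mapping implemented by the branches shown (values taken directly from the diff; the dict itself is only illustrative, not code from the module):

# (client_type, hint) -> endpoint, as resolved by the branches above
ENDPOINT_HINTS = {
    ("openai", "chat.completions"): "/v1/chat/completions",
    ("openai", "embeddings"): "/v1/embeddings",
    ("openai", "completions"): "/v1/completions",
    ("ollama", "generate"): "/api/generate",
    ("ollama", "chat"): "/api/chat",
    ("ollama", "embed"): "/api/embeddings",
}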
@@ -276,15 +228,12 @@ async def _patched_create_method(original_method, self, **kwargs):
    request_hash = normalize_request(method, url, headers, body)

    if _current_mode == "replay":
        # Try to find recorded response
        recording = _current_storage.find_recording(request_hash)
        if recording:
            # Return recorded response
            response_body = recording["response"]["body"]

            # Handle streaming responses
            if recording["response"].get("is_streaming", False):
                # For streaming, we need to return an async iterator

                async def replay_stream():
                    for chunk in response_body:
                        yield chunk
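The replay branch above hands recorded chunks back through an async generator. A minimal self-contained sketch of that pattern, with illustrative names only:

import asyncio

async def replay_stream(recorded_chunks):
    # Re-emit stored chunks as if the provider were streaming them live.
    for chunk in recorded_chunks:
        yield chunk

async def demo():
    chunks = ["Hel", "lo", "!"]
    async for piece in replay_stream(chunks):
        print(piece, end="")

asyncio.run(demo())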
@@ -301,110 +250,8 @@ async def _patched_create_method(original_method, self, **kwargs):
            )

    elif _current_mode == "record":
        # Make real request and record it
        response = await original_method(self, **kwargs)

        # Store the recording
        request_data = {
            "method": method,
            "url": url,
            "headers": headers,
            "body": body,
            "endpoint": endpoint,
            "model": body.get("model", ""),
        }

        # Determine if this is a streaming request based on request parameters
        is_streaming = body.get("stream", False)

        if is_streaming:
            # For streaming responses, we need to collect all chunks immediately before yielding
            # This ensures the recording is saved even if the generator isn't fully consumed
            chunks = []
            async for chunk in response:
                chunks.append(chunk)

            # Store the recording immediately
            response_data = {"body": chunks, "is_streaming": True}
            _current_storage.store_recording(request_hash, request_data, response_data)

            # Return a generator that replays the stored chunks
            async def replay_recorded_stream():
                for chunk in chunks:
                    yield chunk

            return replay_recorded_stream()
        else:
            response_data = {"body": response, "is_streaming": False}
            _current_storage.store_recording(request_hash, request_data, response_data)
            return response

    else:
        return await original_method(self, **kwargs)


async def _patched_ollama_method(original_method, self, method_name, **kwargs):
    """Patched version of Ollama AsyncClient methods."""
    global _current_mode, _current_storage

    if _current_mode == "live" or _current_storage is None:
        # Normal operation
        return await original_method(self, **kwargs)

    # Get base URL from the client (Ollama client uses host attribute)
    base_url = getattr(self, "host", "http://localhost:11434")
    if not base_url.startswith("http"):
        base_url = f"http://{base_url}"

    # Determine endpoint based on method name
    if method_name == "generate":
        endpoint = "/api/generate"
    elif method_name == "chat":
        endpoint = "/api/chat"
    elif method_name == "embed":
        endpoint = "/api/embeddings"
    else:
        endpoint = f"/api/{method_name}"

    url = base_url.rstrip("/") + endpoint

    # Normalize request for matching
    method = "POST"
    headers = {}
    body = kwargs

    request_hash = normalize_request(method, url, headers, body)

    if _current_mode == "replay":
        # Try to find recorded response
        recording = _current_storage.find_recording(request_hash)
        if recording:
            # Return recorded response
            response_body = recording["response"]["body"]

            # Handle streaming responses for Ollama
            if recording["response"].get("is_streaming", False):
                # For streaming, we need to return an async iterator
                async def replay_ollama_stream():
                    for chunk in response_body:
                        yield chunk

                return replay_ollama_stream()
            else:
                return response_body
        else:
            raise RuntimeError(
                f"No recorded response found for request hash: {request_hash}\n"
                f"Endpoint: {endpoint}\n"
                f"Model: {body.get('model', 'unknown')}\n"
                f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
            )

    elif _current_mode == "record":
        # Make real request and record it
        response = await original_method(self, **kwargs)

        # Store the recording
        request_data = {
            "method": method,
            "url": url,
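Record mode above drains the provider's stream into a list before returning, so the recording is persisted even if the caller stops iterating early. A sketch of that collect-then-replay pattern, with hypothetical helper names (live_stream, save_recording) standing in for the real storage calls:

async def record_stream(live_stream, save_recording):
    # Drain the real stream fully so the recording exists even if the caller
    # never finishes iterating, then hand back an equivalent replay generator.
    chunks = [chunk async for chunk in live_stream]
    save_recording({"body": chunks, "is_streaming": True})

    async def replay():
        for chunk in chunks:
            yield chunk

    return replay()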
@@ -448,45 +295,31 @@ def patch_inference_clients():
    global _original_methods

    # Import here to avoid circular imports
    from openai import AsyncOpenAI
    # Also import Ollama AsyncClient
    from ollama import AsyncClient as OllamaAsyncClient
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.completions import AsyncCompletions
    from openai.resources.embeddings import AsyncEmbeddings

    # Also import Ollama AsyncClient
    try:
        from ollama import AsyncClient as OllamaAsyncClient
    except ImportError:
        ollama_async_client = None
    else:
        ollama_async_client = OllamaAsyncClient

    # Store original methods for both OpenAI and Ollama clients
    _original_methods = {
        "chat_completions_create": AsyncChatCompletions.create,
        "completions_create": AsyncCompletions.create,
        "embeddings_create": AsyncEmbeddings.create,
        "ollama_generate": OllamaAsyncClient.generate,
        "ollama_chat": OllamaAsyncClient.chat,
        "ollama_embed": OllamaAsyncClient.embed,
    }

    # Add Ollama client methods if available
    if ollama_async_client:
        _original_methods.update(
            {
                "ollama_generate": ollama_async_client.generate,
                "ollama_chat": ollama_async_client.chat,
                "ollama_embed": ollama_async_client.embed,
            }
        )

    # Create patched methods for OpenAI client
    async def patched_chat_completions_create(self, **kwargs):
        return await _patched_create_method(_original_methods["chat_completions_create"], self, **kwargs)
        return await _patched_inference_method(_original_methods["chat_completions_create"], self, "openai", **kwargs)

    async def patched_completions_create(self, **kwargs):
        return await _patched_create_method(_original_methods["completions_create"], self, **kwargs)
        return await _patched_inference_method(_original_methods["completions_create"], self, "openai", **kwargs)

    async def patched_embeddings_create(self, **kwargs):
        return await _patched_create_method(_original_methods["embeddings_create"], self, **kwargs)
        return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", **kwargs)

    # Apply OpenAI patches
    AsyncChatCompletions.create = patched_chat_completions_create
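The scheme above is plain class-attribute monkey-patching: keep the original function objects in _original_methods, assign async wrappers onto the client classes, and reassign the originals later in unpatch_inference_clients. A generic sketch of the same idea, with hypothetical registry and function names:

_saved: dict[str, object] = {}

def patch_method(cls, name, make_wrapper):
    # Remember the original function object, then install a wrapper on the class.
    original = getattr(cls, name)
    _saved[f"{cls.__name__}.{name}"] = original
    setattr(cls, name, make_wrapper(original))

def unpatch_method(cls, name):
    # Put the saved original back, mirroring the restore step later in this diff.
    setattr(cls, name, _saved.pop(f"{cls.__name__}.{name}"))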
@@ -494,40 +327,21 @@ def patch_inference_clients():
    AsyncEmbeddings.create = patched_embeddings_create

    # Create patched methods for Ollama client
    if ollama_async_client:
    async def patched_ollama_generate(self, **kwargs):
        return await _patched_inference_method(
            _original_methods["ollama_generate"], self, "ollama", "generate", **kwargs
        )

        async def patched_ollama_generate(self, **kwargs):
            return await _patched_ollama_method(_original_methods["ollama_generate"], self, "generate", **kwargs)
    async def patched_ollama_chat(self, **kwargs):
        return await _patched_inference_method(_original_methods["ollama_chat"], self, "ollama", "chat", **kwargs)

        async def patched_ollama_chat(self, **kwargs):
            return await _patched_ollama_method(_original_methods["ollama_chat"], self, "chat", **kwargs)
    async def patched_ollama_embed(self, **kwargs):
        return await _patched_inference_method(_original_methods["ollama_embed"], self, "ollama", "embed", **kwargs)

        async def patched_ollama_embed(self, **kwargs):
            return await _patched_ollama_method(_original_methods["ollama_embed"], self, "embed", **kwargs)

        # Apply Ollama patches
        ollama_async_client.generate = patched_ollama_generate
        ollama_async_client.chat = patched_ollama_chat
        ollama_async_client.embed = patched_ollama_embed

    # Also try to patch the AsyncOpenAI __init__ to trace client creation
    original_openai_init = AsyncOpenAI.__init__

    def patched_openai_init(self, *args, **kwargs):
        result = original_openai_init(self, *args, **kwargs)

        # After client is created, try to re-patch its methods
        if hasattr(self, "chat") and hasattr(self.chat, "completions"):
            original_chat_create = self.chat.completions.create

            async def instance_patched_chat_create(**kwargs):
                return await _patched_create_method(original_chat_create, self.chat.completions, **kwargs)

            self.chat.completions.create = instance_patched_chat_create

        return result

    AsyncOpenAI.__init__ = patched_openai_init
    # Apply Ollama patches
    OllamaAsyncClient.generate = patched_ollama_generate
    OllamaAsyncClient.chat = patched_ollama_chat
    OllamaAsyncClient.embed = patched_ollama_embed


def unpatch_inference_clients():
@@ -538,43 +352,26 @@ def unpatch_inference_clients():
        return

    # Import here to avoid circular imports
    from ollama import AsyncClient as OllamaAsyncClient
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.completions import AsyncCompletions
    from openai.resources.embeddings import AsyncEmbeddings

    # Restore OpenAI client methods
    if "chat_completions_create" in _original_methods:
        AsyncChatCompletions.create = _original_methods["chat_completions_create"]

    if "completions_create" in _original_methods:
        AsyncCompletions.create = _original_methods["completions_create"]

    if "embeddings_create" in _original_methods:
        AsyncEmbeddings.create = _original_methods["embeddings_create"]
    AsyncChatCompletions.create = _original_methods["chat_completions_create"]
    AsyncCompletions.create = _original_methods["completions_create"]
    AsyncEmbeddings.create = _original_methods["embeddings_create"]

    # Restore Ollama client methods if they were patched
    try:
        from ollama import AsyncClient as OllamaAsyncClient

        if "ollama_generate" in _original_methods:
            OllamaAsyncClient.generate = _original_methods["ollama_generate"]

        if "ollama_chat" in _original_methods:
            OllamaAsyncClient.chat = _original_methods["ollama_chat"]

        if "ollama_embed" in _original_methods:
            OllamaAsyncClient.embed = _original_methods["ollama_embed"]

    except ImportError:
        pass
    OllamaAsyncClient.generate = _original_methods["ollama_generate"]
    OllamaAsyncClient.chat = _original_methods["ollama_chat"]
    OllamaAsyncClient.embed = _original_methods["ollama_embed"]

    _original_methods.clear()


@contextmanager
def inference_recording(
    mode: str = "live", test_id: str | None = None, storage_dir: str | Path | None = None
) -> Generator[None, None, None]:
def inference_recording(mode: str = "live", storage_dir: str | Path | None = None) -> Generator[None, None, None]:
    """Context manager for inference recording/replaying."""
    global _current_mode, _current_storage
@@ -584,9 +381,6 @@ def inference_recording(
    else:
        storage_dir_path = Path(storage_dir)

    if test_id is None:
        test_id = f"test_{uuid.uuid4().hex[:8]}"

    # Store previous state
    prev_mode = _current_mode
    prev_storage = _current_storage
@@ -595,7 +389,7 @@
    _current_mode = mode

    if mode in ["record", "replay"]:
        _current_storage = ResponseStorage(storage_dir_path, test_id)
        _current_storage = ResponseStorage(storage_dir_path)
        patch_inference_clients()

    yield
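Finally, a hedged usage sketch of the simplified context-manager signature shown in this diff, inference_recording(mode, storage_dir); the directory is a placeholder:

from pathlib import Path

# Record once against a live server, then replay without network access.
with inference_recording(mode="record", storage_dir=Path("/tmp/recordings")):
    ...  # live calls are made and stored under /tmp/recordings

with inference_recording(mode="replay", storage_dir=Path("/tmp/recordings")):
    ...  # the same calls are now answered from the stored recordings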