Improve the inference recording system to cover other Ollama methods: ps, list, and pull

Ashwin Bharambe 2025-07-29 10:56:57 -07:00
parent 1a21c4b695
commit b578f9aec1
23 changed files with 6748 additions and 5711 deletions


@@ -13,13 +13,23 @@ import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Literal, cast
 
 from llama_stack.log import get_logger
 
 logger = get_logger(__name__, category="testing")
 
 # Global state for the recording system
 _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
 
+from openai.types.completion_choice import CompletionChoice
+
+# update the "finish_reason" field, since its type definition is wrong (None is not accepted)
+CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
+CompletionChoice.model_rebuild()
 
 
 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
@@ -35,14 +45,14 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
 def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
+    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()
 
 
 def setup_inference_recording():
     mode = get_inference_mode()
     if mode not in ["live", "record", "replay"]:
-        raise ValueError(f"Invalid LLAMA_STACK_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
+        raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
 
     if mode == "live":
         # Return a no-op context manager for live mode
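With the rename, the mode switch is read (case-insensitively) from `LLAMA_STACK_TEST_INFERENCE_MODE`, and anything outside live/record/replay now fails fast. A quick sketch of the lookup behavior, assuming `get_inference_mode` is imported from this module:

```python
import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "REPLAY"  # value is lowercased
assert get_inference_mode() == "replay"

del os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"]  # unset falls back to "live"
assert get_inference_mode() == "live"
```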
@@ -52,23 +62,43 @@ def setup_inference_recording():
         return live_mode()
 
-    if "LLAMA_STACK_RECORDING_DIR" not in os.environ:
-        raise ValueError("LLAMA_STACK_RECORDING_DIR must be set for recording or replaying")
-    storage_dir = os.environ["LLAMA_STACK_RECORDING_DIR"]
+    if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
+        raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
+    storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
 
     return inference_recording(mode=mode, storage_dir=storage_dir)
 
 
 def _serialize_response(response: Any) -> Any:
     if hasattr(response, "model_dump"):
-        return response.model_dump()
+        data = response.model_dump(mode="json")
+        return {
+            "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
+            "__data__": data,
+        }
     elif hasattr(response, "__dict__"):
         return dict(response.__dict__)
     else:
         return response
 
 
-def _deserialize_response(data: dict[str, Any]) -> dict[str, Any]:
+def _deserialize_response(data: dict[str, Any]) -> Any:
+    # Check if this is a serialized Pydantic model with type information
+    if isinstance(data, dict) and "__type__" in data and "__data__" in data:
+        try:
+            # Import the original class and reconstruct the object
+            module_path, class_name = data["__type__"].rsplit(".", 1)
+            module = __import__(module_path, fromlist=[class_name])
+            cls = getattr(module, class_name)
+            if not hasattr(cls, "model_validate"):
+                raise ValueError(f"Pydantic class {cls} does not support model_validate?")
+            return cls.model_validate(data["__data__"])
+        except (ImportError, AttributeError, TypeError, ValueError) as e:
+            logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
+            return data["__data__"]
+
     return data
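The new `__type__`/`__data__` envelope makes recordings round-trippable: responses are stored as plain JSON but come back as the original Pydantic class. A minimal round-trip sketch, assuming the two helpers are imported from this module (the `ChatCompletion` payload is a made-up example):

```python
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

completion = ChatCompletion(
    id="chatcmpl-test",
    choices=[Choice(finish_reason="stop", index=0,
                    message=ChatCompletionMessage(role="assistant", content="hi"))],
    created=0,
    model="llama3.2:3b",
    object="chat.completion",
)

envelope = _serialize_response(completion)
# e.g. {"__type__": "openai.types.chat.chat_completion.ChatCompletion", "__data__": {...}}

restored = _deserialize_response(envelope)
assert isinstance(restored, ChatCompletion)
assert restored.choices[0].message.content == "hi"
```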
@@ -120,6 +150,7 @@ class ResponseStorage:
         with open(response_path, "w") as f:
             json.dump({"request": request, "response": serialized_response}, f, indent=2)
             f.write("\n")
+            f.flush()
 
         # Update SQLite index
         with sqlite3.connect(self.db_path) as conn:
@@ -214,6 +245,8 @@ async def _patched_inference_method(original_method, self, client_type, method_n
             endpoint = "/api/chat"
         elif method_name == "embed":
             endpoint = "/api/embeddings"
+        elif method_name == "list":
+            endpoint = "/api/tags"
         else:
             endpoint = f"/api/{method_name}"
     else:
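For the recording key, Ollama method names are normalized to their native HTTP endpoints; the new `list` branch maps to `/api/tags`, the endpoint Ollama serves its local model list from. Restated as a standalone helper (hypothetical name, mirroring the branch above):

```python
def _ollama_endpoint(method_name: str) -> str:
    # chat/embed/list have dedicated endpoints; everything else
    # falls through to /api/<method_name> (e.g. generate -> /api/generate)
    special = {"chat": "/api/chat", "embed": "/api/embeddings", "list": "/api/tags"}
    return special.get(method_name, f"/api/{method_name}")

assert _ollama_endpoint("list") == "/api/tags"
assert _ollama_endpoint("generate") == "/api/generate"
```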
@@ -295,8 +328,6 @@ def patch_inference_clients():
     """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
     global _original_methods
 
-    # Import here to avoid circular imports
-    # Also import Ollama AsyncClient
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
@@ -310,6 +341,9 @@ def patch_inference_clients():
         "ollama_generate": OllamaAsyncClient.generate,
         "ollama_chat": OllamaAsyncClient.chat,
         "ollama_embed": OllamaAsyncClient.embed,
+        "ollama_ps": OllamaAsyncClient.ps,
+        "ollama_pull": OllamaAsyncClient.pull,
+        "ollama_list": OllamaAsyncClient.list,
     }
 
     # Create patched methods for OpenAI client
@@ -339,10 +373,24 @@ def patch_inference_clients():
     async def patched_ollama_embed(self, **kwargs):
         return await _patched_inference_method(_original_methods["ollama_embed"], self, "ollama", "embed", **kwargs)
 
+    async def patched_ollama_ps(self, **kwargs):
+        logger.info("replay mode: ollama.ps() reporting success")
+        return []
+
+    async def patched_ollama_pull(self, *args, **kwargs):
+        logger.info("replay mode: ollama.pull() not actually pulling the model")
+        return None
+
+    async def patched_ollama_list(self, **kwargs):
+        return await _patched_inference_method(_original_methods["ollama_list"], self, "ollama", "list", **kwargs)
+
     # Apply Ollama patches
     OllamaAsyncClient.generate = patched_ollama_generate
     OllamaAsyncClient.chat = patched_ollama_chat
     OllamaAsyncClient.embed = patched_ollama_embed
+    OllamaAsyncClient.ps = patched_ollama_ps
+    OllamaAsyncClient.pull = patched_ollama_pull
+    OllamaAsyncClient.list = patched_ollama_list
def unpatch_inference_clients():
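Once the clients are patched, the three new methods behave differently: `ps` and `pull` become logged no-ops (no daemon query, no download), while `list` goes through the same record/replay path as the inference calls. A sketch of what a patched client then does (the model name is a placeholder):

```python
import asyncio

from ollama import AsyncClient

async def demo() -> None:
    client = AsyncClient()
    await client.pull("llama3.2:3b")   # logs, returns None instead of downloading
    print(await client.ps())           # logs, returns [] instead of querying the daemon
    print(await client.list())         # recorded or replayed via the recording store

asyncio.run(demo())
```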
@@ -367,6 +415,8 @@ def unpatch_inference_clients():
     OllamaAsyncClient.generate = _original_methods["ollama_generate"]
     OllamaAsyncClient.chat = _original_methods["ollama_chat"]
     OllamaAsyncClient.embed = _original_methods["ollama_embed"]
+    OllamaAsyncClient.ps = _original_methods["ollama_ps"]
+    OllamaAsyncClient.list = _original_methods["ollama_list"]
 
     _original_methods.clear()
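An end-to-end sketch of how the renamed knobs compose, assuming `setup_inference_recording` is imported from this module (the directory path is a placeholder): run once in record mode to capture real responses, then flip to replay.

```python
import os

os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "tests/recordings"

# First pass: hit the real backends and write recordings to disk.
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record"
with setup_inference_recording():
    ...  # run the test suite's inference calls

# Later passes: the same calls resolve from the SQLite-indexed recordings.
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"
with setup_inference_recording():
    ...  # no network, no Ollama daemon needed
```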