Mirror of https://github.com/meta-llama/llama-stack.git
improvements to include other ollama methods like ps, list and pull
commit b578f9aec1
parent 1a21c4b695
23 changed files with 6748 additions and 5711 deletions
@@ -13,13 +13,23 @@ import sqlite3
 from collections.abc import Generator
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Literal, cast

 from llama_stack.log import get_logger

 logger = get_logger(__name__, category="testing")

 # Global state for the recording system
 _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}

+from openai.types.completion_choice import CompletionChoice
+
+# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
+CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
+CompletionChoice.model_rebuild()
+

 def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str:
     """Create a normalized hash of the request for consistent matching."""
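The finish_reason override added above relies on Pydantic v2 letting you replace a field's annotation in model_fields and then rebuild the model so validation follows the new type. A minimal sketch of the same trick on a toy model, assuming Pydantic v2 regenerates the schema from model_fields on rebuild (the Choice class below is illustrative, not part of llama-stack or the openai package):

from typing import Literal

from pydantic import BaseModel, ValidationError


class Choice(BaseModel):
    # Overly strict annotation: None is rejected.
    finish_reason: Literal["stop", "length"]


try:
    Choice(finish_reason=None)
except ValidationError:
    print("rejected before the patch")

# Widen the annotation, then force a schema rebuild so validation picks it up.
Choice.model_fields["finish_reason"].annotation = Literal["stop", "length"] | None
Choice.model_rebuild(force=True)

print(Choice(finish_reason=None))  # accepted after the rebuild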
@@ -35,14 +45,14 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict


 def get_inference_mode() -> str:
-    return os.environ.get("LLAMA_STACK_INFERENCE_MODE", "live").lower()
+    return os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()


 def setup_inference_recording():
     mode = get_inference_mode()

     if mode not in ["live", "record", "replay"]:
-        raise ValueError(f"Invalid LLAMA_STACK_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
+        raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")

     if mode == "live":
         # Return a no-op context manager for live mode
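Together with the next hunk, recording and replay appear to be driven entirely by two environment variables, and setup_inference_recording() hands back a context manager (a no-op one in live mode). A hypothetical harness snippet; the module path in the import is an assumption, since this diff does not name the file:

import os

os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "replay"          # or "record" / "live"
os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "/tmp/recordings"  # required for record/replay

# Assumed import path for the module this diff modifies.
from llama_stack.testing.inference_recorder import setup_inference_recording

with setup_inference_recording():
    ...  # run the inference calls to be recorded or replayed here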
@@ -52,23 +62,43 @@ def setup_inference_recording():

         return live_mode()

-    if "LLAMA_STACK_RECORDING_DIR" not in os.environ:
-        raise ValueError("LLAMA_STACK_RECORDING_DIR must be set for recording or replaying")
-    storage_dir = os.environ["LLAMA_STACK_RECORDING_DIR"]
+    if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
+        raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
+    storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]

     return inference_recording(mode=mode, storage_dir=storage_dir)


 def _serialize_response(response: Any) -> Any:
     if hasattr(response, "model_dump"):
-        return response.model_dump()
+        data = response.model_dump(mode="json")
+        return {
+            "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
+            "__data__": data,
+        }
     elif hasattr(response, "__dict__"):
         return dict(response.__dict__)
     else:
         return response


-def _deserialize_response(data: dict[str, Any]) -> dict[str, Any]:
+def _deserialize_response(data: dict[str, Any]) -> Any:
+    # Check if this is a serialized Pydantic model with type information
+    if isinstance(data, dict) and "__type__" in data and "__data__" in data:
+        try:
+            # Import the original class and reconstruct the object
+            module_path, class_name = data["__type__"].rsplit(".", 1)
+            module = __import__(module_path, fromlist=[class_name])
+            cls = getattr(module, class_name)
+
+            if not hasattr(cls, "model_validate"):
+                raise ValueError(f"Pydantic class {cls} does not support model_validate?")
+
+            return cls.model_validate(data["__data__"])
+        except (ImportError, AttributeError, TypeError, ValueError) as e:
+            logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}")
+            return data["__data__"]
+
     return data
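The __type__/__data__ envelope makes recordings self-describing: serialization stores the fully qualified class name next to a JSON-mode dump, and deserialization re-imports that class and calls model_validate on the stored data. A self-contained round-trip sketch of the same idea with simplified helpers and a toy model (not the module's actual code):

from typing import Any

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


def serialize(obj: Any) -> dict[str, Any]:
    # Record the import path alongside the JSON-mode dump.
    return {
        "__type__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}",
        "__data__": obj.model_dump(mode="json"),
    }


def deserialize(data: dict[str, Any]) -> Any:
    # Re-import the recorded class and validate the stored payload against it.
    module_path, class_name = data["__type__"].rsplit(".", 1)
    module = __import__(module_path, fromlist=[class_name])
    cls = getattr(module, class_name)
    return cls.model_validate(data["__data__"])


blob = serialize(Message(role="user", content="hi"))
assert deserialize(blob) == Message(role="user", content="hi")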
@@ -120,6 +150,7 @@ class ResponseStorage:
         with open(response_path, "w") as f:
             json.dump({"request": request, "response": serialized_response}, f, indent=2)
+            f.write("\n")
             f.flush()

         # Update SQLite index
         with sqlite3.connect(self.db_path) as conn:
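The added f.write("\n") only leaves each recording file newline-terminated; json.load ignores trailing whitespace when the file is read back. A quick sketch of the same write/read cycle with a made-up path and payload:

import json
from pathlib import Path

path = Path("/tmp/example-recording.json")  # illustrative path

with open(path, "w") as f:
    json.dump({"request": {"method": "POST"}, "response": {"body": "ok"}}, f, indent=2)
    f.write("\n")  # keep the file newline-terminated
    f.flush()

with open(path) as f:
    record = json.load(f)  # the trailing newline does not affect parsing

assert record["response"]["body"] == "ok"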
@@ -214,6 +245,8 @@ async def _patched_inference_method(original_method, self, client_type, method_n
             endpoint = "/api/chat"
         elif method_name == "embed":
             endpoint = "/api/embeddings"
+        elif method_name == "list":
+            endpoint = "/api/tags"
         else:
             endpoint = f"/api/{method_name}"
     else:
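This branch exists because a couple of Ollama client methods do not map one-to-one onto their HTTP paths: embed posts to /api/embeddings and list reads /api/tags, while everything else falls through to /api/<method>. A small helper capturing the same mapping, written only to illustrate the rule (it is not part of the module):

# Special cases first, then the generic /api/<method> fallback.
_OLLAMA_ENDPOINT_OVERRIDES = {
    "embed": "/api/embeddings",
    "list": "/api/tags",
}


def ollama_endpoint(method_name: str) -> str:
    return _OLLAMA_ENDPOINT_OVERRIDES.get(method_name, f"/api/{method_name}")


assert ollama_endpoint("chat") == "/api/chat"
assert ollama_endpoint("list") == "/api/tags"
assert ollama_endpoint("pull") == "/api/pull"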
@@ -295,8 +328,6 @@ def patch_inference_clients():
     """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
     global _original_methods

-    # Import here to avoid circular imports
-    # Also import Ollama AsyncClient
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
@@ -310,6 +341,9 @@ def patch_inference_clients():
         "ollama_generate": OllamaAsyncClient.generate,
         "ollama_chat": OllamaAsyncClient.chat,
         "ollama_embed": OllamaAsyncClient.embed,
+        "ollama_ps": OllamaAsyncClient.ps,
+        "ollama_pull": OllamaAsyncClient.pull,
+        "ollama_list": OllamaAsyncClient.list,
     }

     # Create patched methods for OpenAI client
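Stashing the unbound originals in _original_methods is what keeps the patching reversible: the patched wrappers call the stored functions, and unpatch_inference_clients() later reassigns them. A stripped-down sketch of the same save/patch/restore pattern on a toy async client (all names here are illustrative):

import asyncio
from typing import Any


class ToyClient:
    async def chat(self, **kwargs: Any) -> str:
        return "live response"


# Save the unbound original before patching, exactly so it can be restored later.
_originals: dict[str, Any] = {"chat": ToyClient.chat}


async def patched_chat(self: ToyClient, **kwargs: Any) -> str:
    # A real recorder would record or replay here; this just tags the call.
    return "recorded: " + await _originals["chat"](self, **kwargs)


def patch() -> None:
    ToyClient.chat = patched_chat


def unpatch() -> None:
    ToyClient.chat = _originals["chat"]


patch()
print(asyncio.run(ToyClient().chat()))  # recorded: live response
unpatch()
print(asyncio.run(ToyClient().chat()))  # live response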
@@ -339,10 +373,24 @@ def patch_inference_clients():
     async def patched_ollama_embed(self, **kwargs):
         return await _patched_inference_method(_original_methods["ollama_embed"], self, "ollama", "embed", **kwargs)

+    async def patched_ollama_ps(self, **kwargs):
+        logger.info("replay mode: ollama.ps() reporting success")
+        return []
+
+    async def patched_ollama_pull(self, *args, **kwargs):
+        logger.info("replay mode: ollama.pull() not actually pulling the model")
+        return None
+
+    async def patched_ollama_list(self, **kwargs):
+        return await _patched_inference_method(_original_methods["ollama_list"], self, "ollama", "list", **kwargs)
+
     # Apply Ollama patches
     OllamaAsyncClient.generate = patched_ollama_generate
     OllamaAsyncClient.chat = patched_ollama_chat
     OllamaAsyncClient.embed = patched_ollama_embed
+    OllamaAsyncClient.ps = patched_ollama_ps
+    OllamaAsyncClient.pull = patched_ollama_pull
+    OllamaAsyncClient.list = patched_ollama_list


 def unpatch_inference_clients():
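Once the patches are in place, the model-management calls that would otherwise need a running Ollama server are short-circuited: ps() reports an empty process list and pull() becomes a no-op, while list() goes through the same record/replay path as the inference methods. A hedged sketch of what a caller would observe, assuming the ollama package is installed; the import path for patch_inference_clients is an assumption, and the model name is only an example:

import asyncio

from ollama import AsyncClient

# Assumed import path for the module this diff modifies.
from llama_stack.testing.inference_recorder import patch_inference_clients


async def main() -> None:
    patch_inference_clients()
    client = AsyncClient()
    assert await client.ps() == []                    # stubbed: no server round-trip
    assert await client.pull("llama3.2:3b") is None   # stubbed: nothing is downloaded


asyncio.run(main())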
@@ -367,6 +415,8 @@ def unpatch_inference_clients():
     OllamaAsyncClient.generate = _original_methods["ollama_generate"]
     OllamaAsyncClient.chat = _original_methods["ollama_chat"]
     OllamaAsyncClient.embed = _original_methods["ollama_embed"]
+    OllamaAsyncClient.ps = _original_methods["ollama_ps"]
+    OllamaAsyncClient.list = _original_methods["ollama_list"]

     _original_methods.clear()
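unpatch_inference_clients() restores whatever was captured in _original_methods, so patching and unpatching are meant to bracket a recording or replay session. One common way to guarantee that pairing is a context manager; this is a generic sketch, not the inference_recording implementation referenced earlier in the module, and the import path is an assumption:

from collections.abc import Iterator
from contextlib import contextmanager

# Assumed import path for the functions shown in this diff.
from llama_stack.testing.inference_recorder import (
    patch_inference_clients,
    unpatch_inference_clients,
)


@contextmanager
def patched_clients() -> Iterator[None]:
    patch_inference_clients()
    try:
        yield
    finally:
        # Always restore the original client methods, even if the body raised.
        unpatch_inference_clients()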