Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-05 12:21:52 +00:00
feat(tests): make inference_recorder into api_recorder (include tool_invoke)
This commit is contained in: parent dd1f946b3e, commit 8da9ff352c
45 changed files with 25144 additions and 41 deletions
@@ -19,7 +19,7 @@ from llama_stack.log import get_logger

 logger = get_logger(__name__, category="testing")

-# Global state for the recording system
+# Global state for the API recording system
 _current_mode: str | None = None
 _current_storage: ResponseStorage | None = None
 _original_methods: dict[str, Any] = {}
@@ -34,7 +34,7 @@ REPO_ROOT = Path(__file__).parent.parent.parent
 DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"


-class InferenceMode(StrEnum):
+class APIRecordingMode(StrEnum):
     LIVE = "live"
     RECORD = "record"
     REPLAY = "replay"
@@ -53,16 +53,27 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
     return hashlib.sha256(normalized_json.encode()).hexdigest()


-def get_inference_mode() -> InferenceMode:
-    return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
+def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
+    """Create a normalized hash of the tool request for consistent matching."""
+    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
+
+    # Create hash - sort_keys=True ensures deterministic ordering
+    normalized_json = json.dumps(normalized, sort_keys=True)
+    return hashlib.sha256(normalized_json.encode()).hexdigest()


-def setup_inference_recording():
+def get_api_recording_mode() -> APIRecordingMode:
+    return APIRecordingMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
+
+
+def setup_api_recording():
     """
-    Returns a context manager that can be used to record or replay inference requests. This is to be used in tests
-    to increase their reliability and reduce reliance on expensive, external services.
+    Returns a context manager that can be used to record or replay API requests (inference and tools).
+    This is to be used in tests to increase their reliability and reduce reliance on expensive, external services.

-    Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
+    Currently supports:
+    - Inference: OpenAI, Ollama, and LiteLLM clients
+    - Tools: Search providers (Tavily for now)

     Two environment variables are supported:
     - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
@@ -70,12 +81,12 @@ def setup_inference_recording():

     The recordings are stored as JSON files.
     """
-    mode = get_inference_mode()
-    if mode == InferenceMode.LIVE:
+    mode = get_api_recording_mode()
+    if mode == APIRecordingMode.LIVE:
         return None

     storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
-    return inference_recording(mode=mode, storage_dir=storage_dir)
+    return api_recording(mode=mode, storage_dir=storage_dir)


 def _serialize_response(response: Any) -> Any:
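For orientation, here is a minimal sketch of how a test suite might consume setup_api_recording(); the pytest fixture and the import path are assumptions for illustration, not part of this commit:

import pytest

from llama_stack.testing.api_recorder import setup_api_recording  # import path assumed


@pytest.fixture(scope="session", autouse=True)
def api_recordings():
    # setup_api_recording() returns None in "live" mode, otherwise a context manager
    recorder = setup_api_recording()
    if recorder is None:
        yield
    else:
        with recorder:
            yield

With the default LLAMA_STACK_TEST_INFERENCE_MODE=replay, any patched inference or tool call made inside the fixture's scope is served from the stored recordings instead of hitting external services.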
@@ -112,7 +123,7 @@ def _deserialize_response(data: dict[str, Any]) -> Any:


 class ResponseStorage:
-    """Handles SQLite index + JSON file storage/retrieval for inference recordings."""
+    """Handles storage/retrieval for API recordings (inference and tools)."""

     def __init__(self, test_dir: Path):
         self.test_dir = test_dir
@@ -243,12 +254,55 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}


+async def _patched_tool_invoke_method(
+    original_method, provider_name: str, self, tool_name: str, kwargs: dict[str, Any]
+):
+    """Patched version of tool runtime invoke_tool method for recording/replay."""
+    global _current_mode, _current_storage
+
+    if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
+        # Normal operation
+        return await original_method(self, tool_name, kwargs)
+
+    request_hash = normalize_tool_request(provider_name, tool_name, kwargs)
+
+    if _current_mode == APIRecordingMode.REPLAY:
+        recording = _current_storage.find_recording(request_hash)
+        if recording:
+            return recording["response"]["body"]
+        else:
+            raise RuntimeError(
+                f"No recorded tool result found for {provider_name}.{tool_name}\n"
+                f"Request: {kwargs}\n"
+                f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
+            )
+
+    elif _current_mode == APIRecordingMode.RECORD:
+        result = await original_method(self, tool_name, kwargs)
+
+        request_data = {
+            "provider": provider_name,
+            "tool_name": tool_name,
+            "kwargs": kwargs,
+        }
+        response_data = {"body": result, "is_streaming": False}
+
+        _current_storage.store_recording(request_hash, request_data, response_data)
+        return result
+
+    else:
+        raise AssertionError(f"Invalid mode: {_current_mode}")
+
+
 async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
     global _current_mode, _current_storage

-    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+    if _current_mode == APIRecordingMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if client_type == "litellm":
+            return await original_method(*args, **kwargs)
+        else:
+            return await original_method(self, *args, **kwargs)

     # Get base URL based on client type
     if client_type == "openai":
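To make the replay lookup concrete: a recorded tool result is found purely by its request hash, so identical calls must normalize to identical keys. A standalone re-implementation of that contract for demonstration (the tool name below is only an example):

import hashlib
import json


def demo_tool_hash(provider_name: str, tool_name: str, kwargs: dict) -> str:
    # Same recipe as normalize_tool_request above: sort_keys=True gives a stable serialization
    normalized = {"provider": provider_name, "tool_name": tool_name, "kwargs": kwargs}
    return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()


a = demo_tool_hash("tavily", "web_search", {"query": "llama stack"})
b = demo_tool_hash("tavily", "web_search", {"query": "llama stack"})
assert a == b  # identical requests map to the same recording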
@@ -258,6 +312,9 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         base_url = getattr(self, "host", "http://localhost:11434")
         if not base_url.startswith("http"):
             base_url = f"http://{base_url}"
+    elif client_type == "litellm":
+        # For LiteLLM, extract base URL from kwargs if available
+        base_url = kwargs.get("api_base", "https://api.openai.com")
     else:
         raise ValueError(f"Unknown client type: {client_type}")

@@ -268,7 +325,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint

     request_hash = normalize_request(method, url, headers, body)

-    if _current_mode == InferenceMode.REPLAY:
+    if _current_mode == APIRecordingMode.REPLAY:
         # Special handling for model-list endpoints: return union of all responses
         if endpoint in ("/api/tags", "/v1/models"):
             records = _current_storage._model_list_responses(request_hash[:12])
@@ -295,8 +352,11 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
                 f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
             )

-    elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+    elif _current_mode == APIRecordingMode.RECORD:
+        if client_type == "litellm":
+            response = await original_method(*args, **kwargs)
+        else:
+            response = await original_method(self, *args, **kwargs)

         request_data = {
             "method": method,
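The new litellm branches exist because litellm.acompletion and litellm.atext_completion are module-level functions, while the OpenAI and Ollama originals are methods captured from their classes and therefore still need the instance passed explicitly. A small illustrative sketch of the difference (not part of the diff):

import litellm
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions

# Module-level coroutine function: invoked without an instance.
module_level_original = litellm.acompletion            # await module_level_original(**kwargs)

# Function captured from a class: the wrapper must forward `self`.
class_level_original = AsyncChatCompletions.create     # await class_level_original(self, **kwargs)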
@@ -336,17 +396,21 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         raise AssertionError(f"Invalid mode: {_current_mode}")


-def patch_inference_clients():
-    """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods."""
+def patch_api_clients():
+    """Install monkey patches for inference clients and tool runtime methods."""
     global _original_methods

+    import litellm
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
     from openai.resources.models import AsyncModels

-    # Store original methods for both OpenAI and Ollama clients
+    from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
+    from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
+    # Store original methods for OpenAI, Ollama, LiteLLM clients, and tool runtimes
     _original_methods = {
         "chat_completions_create": AsyncChatCompletions.create,
         "completions_create": AsyncCompletions.create,
@@ -358,6 +422,10 @@ def patch_inference_clients():
         "ollama_ps": OllamaAsyncClient.ps,
         "ollama_pull": OllamaAsyncClient.pull,
         "ollama_list": OllamaAsyncClient.list,
+        "litellm_acompletion": litellm.acompletion,
+        "litellm_atext_completion": litellm.atext_completion,
+        "litellm_openai_mixin_get_api_key": LiteLLMOpenAIMixin.get_api_key,
+        "tavily_invoke_tool": TavilySearchToolRuntimeImpl.invoke_tool,
     }

     # Create patched methods for OpenAI client
@@ -426,21 +494,61 @@ def patch_inference_clients():
     OllamaAsyncClient.pull = patched_ollama_pull
     OllamaAsyncClient.list = patched_ollama_list

+    # Create patched methods for LiteLLM
+    async def patched_litellm_acompletion(*args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["litellm_acompletion"], None, "litellm", "/chat/completions", *args, **kwargs
+        )

-def unpatch_inference_clients():
-    """Remove monkey patches and restore original OpenAI and Ollama client methods."""
+    async def patched_litellm_atext_completion(*args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["litellm_atext_completion"], None, "litellm", "/completions", *args, **kwargs
+        )
+
+    # Apply LiteLLM patches
+    litellm.acompletion = patched_litellm_acompletion
+    litellm.atext_completion = patched_litellm_atext_completion
+
+    # Create patched method for LiteLLMOpenAIMixin.get_api_key
+    def patched_litellm_get_api_key(self):
+        global _current_mode
+        if _current_mode != APIRecordingMode.REPLAY:
+            return _original_methods["litellm_openai_mixin_get_api_key"](self)
+        else:
+            # For record/replay modes, return a fake API key to avoid exposing real credentials
+            return "fake-api-key-for-testing"
+
+    # Apply LiteLLMOpenAIMixin patch
+    LiteLLMOpenAIMixin.get_api_key = patched_litellm_get_api_key
+
+    # Create patched methods for tool runtimes
+    async def patched_tavily_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
+        return await _patched_tool_invoke_method(
+            _original_methods["tavily_invoke_tool"], "tavily", self, tool_name, kwargs
+        )
+
+    # Apply tool runtime patches
+    TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
+
+
+def unpatch_api_clients():
+    """Remove monkey patches and restore original client methods and tool runtimes."""
     global _original_methods

     if not _original_methods:
         return

     # Import here to avoid circular imports
+    import litellm
     from ollama import AsyncClient as OllamaAsyncClient
     from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
     from openai.resources.completions import AsyncCompletions
     from openai.resources.embeddings import AsyncEmbeddings
     from openai.resources.models import AsyncModels

+    from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl
+    from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
     # Restore OpenAI client methods
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
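patch_api_clients() and unpatch_api_clients() follow the usual save-original, install-wrapper, restore-on-teardown pattern. A toy sketch of that shape in isolation (illustrative only, not llama-stack code):

from typing import Any


class FakeToolRuntime:
    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
        return {"source": "live"}


_saved = {"invoke_tool": FakeToolRuntime.invoke_tool}  # save the original once


async def patched_invoke_tool(self, tool_name: str, kwargs: dict[str, Any]):
    # A real wrapper would consult _current_mode / _current_storage here
    return {"source": "replay"}


FakeToolRuntime.invoke_tool = patched_invoke_tool    # what patch_api_clients() does per target
FakeToolRuntime.invoke_tool = _saved["invoke_tool"]  # what unpatch_api_clients() does on teardown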
@@ -455,12 +563,20 @@ def unpatch_inference_clients():
     OllamaAsyncClient.pull = _original_methods["ollama_pull"]
     OllamaAsyncClient.list = _original_methods["ollama_list"]

+    # Restore LiteLLM methods
+    litellm.acompletion = _original_methods["litellm_acompletion"]
+    litellm.atext_completion = _original_methods["litellm_atext_completion"]
+    LiteLLMOpenAIMixin.get_api_key = _original_methods["litellm_openai_mixin_get_api_key"]
+
+    # Restore tool runtime methods
+    TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
+
     _original_methods.clear()


 @contextmanager
-def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
-    """Context manager for inference recording/replaying."""
+def api_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
+    """Context manager for API recording/replaying (inference and tools)."""
     global _current_mode, _current_storage

     # Store previous state
@@ -474,14 +590,14 @@ def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Gen
             if storage_dir is None:
                 raise ValueError("storage_dir is required for record and replay modes")
             _current_storage = ResponseStorage(Path(storage_dir))
-            patch_inference_clients()
+            patch_api_clients()

         yield

     finally:
         # Restore previous state
         if mode in ["record", "replay"]:
-            unpatch_inference_clients()
+            unpatch_api_clients()

         _current_mode = prev_mode
         _current_storage = prev_storage
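The renamed context manager can also be used directly when a test needs explicit control over mode and storage location; the import path below is an assumption for illustration:

from pathlib import Path

from llama_stack.testing.api_recorder import api_recording  # import path assumed

with api_recording(mode="record", storage_dir=Path("tests/integration/recordings")):
    ...  # OpenAI/Ollama/LiteLLM calls and Tavily invoke_tool calls made here are captured to disk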