From 48f91f596bec52ea4b2f4eee2602dca4cffd18dd Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 3 Sep 2025 11:27:21 -0700
Subject: [PATCH] fix tests

---
 llama_stack/testing/inference_recorder.py            | 10 +++++-----
 tests/unit/distribution/test_inference_recordings.py |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index a34445ffb..5b64e26d3 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -445,7 +445,7 @@ def unpatch_inference_clients():
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
-    AsyncModels.list = _original_methods["openai_models_list"]
+    AsyncModels.list = _original_methods["models_list"]
 
     # Restore Ollama client methods if they were patched
     OllamaAsyncClient.generate = _original_methods["ollama_generate"]
@@ -459,12 +459,10 @@ def unpatch_inference_clients():
 
 
 @contextmanager
-def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, None, None]:
+def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
     """Context manager for inference recording/replaying."""
     global _current_mode, _current_storage
 
-    storage_dir_path = Path(storage_dir)
-
     # Store previous state
     prev_mode = _current_mode
     prev_storage = _current_storage
@@ -473,7 +471,9 @@ def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, N
     _current_mode = mode
 
     if mode in ["record", "replay"]:
-        _current_storage = ResponseStorage(storage_dir_path)
+        if storage_dir is None:
+            raise ValueError("storage_dir is required for record and replay modes")
+        _current_storage = ResponseStorage(Path(storage_dir))
         patch_inference_clients()
 
     yield
diff --git a/tests/unit/distribution/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py
index dd80b0caf..c69cf319b 100644
--- a/tests/unit/distribution/test_inference_recordings.py
+++ b/tests/unit/distribution/test_inference_recordings.py
@@ -266,7 +266,7 @@ class TestInferenceRecording:
             return real_openai_chat_response
 
         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode=InferenceMode.LIVE):
+            with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
 
                 response = await client.chat.completions.create(