mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
fix tests
This commit is contained in:
parent
34f5da6cdc
commit
48f91f596b
2 changed files with 6 additions and 6 deletions
|
@ -445,7 +445,7 @@ def unpatch_inference_clients():
|
|||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||
AsyncCompletions.create = _original_methods["completions_create"]
|
||||
AsyncEmbeddings.create = _original_methods["embeddings_create"]
|
||||
AsyncModels.list = _original_methods["openai_models_list"]
|
||||
AsyncModels.list = _original_methods["models_list"]
|
||||
|
||||
# Restore Ollama client methods if they were patched
|
||||
OllamaAsyncClient.generate = _original_methods["ollama_generate"]
|
||||
|
@ -459,12 +459,10 @@ def unpatch_inference_clients():
|
|||
|
||||
|
||||
@contextmanager
|
||||
def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, None, None]:
|
||||
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
"""Context manager for inference recording/replaying."""
|
||||
global _current_mode, _current_storage
|
||||
|
||||
storage_dir_path = Path(storage_dir)
|
||||
|
||||
# Store previous state
|
||||
prev_mode = _current_mode
|
||||
prev_storage = _current_storage
|
||||
|
@ -473,7 +471,9 @@ def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, N
|
|||
_current_mode = mode
|
||||
|
||||
if mode in ["record", "replay"]:
|
||||
_current_storage = ResponseStorage(storage_dir_path)
|
||||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record and replay modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
patch_inference_clients()
|
||||
|
||||
yield
|
||||
|
|
|
@ -266,7 +266,7 @@ class TestInferenceRecording:
|
|||
return real_openai_chat_response
|
||||
|
||||
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
|
||||
with inference_recording(mode=InferenceMode.LIVE):
|
||||
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
|
||||
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue