Mirror of https://github.com/meta-llama/llama-stack.git
fix tests
commit 48f91f596b (parent 34f5da6cdc)
2 changed files with 6 additions and 6 deletions
@@ -445,7 +445,7 @@ def unpatch_inference_clients():
     AsyncChatCompletions.create = _original_methods["chat_completions_create"]
     AsyncCompletions.create = _original_methods["completions_create"]
     AsyncEmbeddings.create = _original_methods["embeddings_create"]
-    AsyncModels.list = _original_methods["openai_models_list"]
+    AsyncModels.list = _original_methods["models_list"]

     # Restore Ollama client methods if they were patched
     OllamaAsyncClient.generate = _original_methods["ollama_generate"]
@@ -459,12 +459,10 @@ def unpatch_inference_clients():


 @contextmanager
-def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, None, None]:
+def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
     """Context manager for inference recording/replaying."""
     global _current_mode, _current_storage

-    storage_dir_path = Path(storage_dir)
-
     # Store previous state
     prev_mode = _current_mode
     prev_storage = _current_storage
@@ -473,7 +471,9 @@ def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, None, None]:
     _current_mode = mode

     if mode in ["record", "replay"]:
-        _current_storage = ResponseStorage(storage_dir_path)
+        if storage_dir is None:
+            raise ValueError("storage_dir is required for record and replay modes")
+        _current_storage = ResponseStorage(Path(storage_dir))
         patch_inference_clients()

     yield
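Taken together, these hunks relax the signature while tightening the runtime contract: storage_dir is optional for live use but mandatory, and now validated eagerly, for record and replay. A minimal usage sketch of the new contract follows; the import path and the literal mode strings are assumptions inferred from the identifiers in the diff, not shown by it.

    from pathlib import Path

    # Assumed import path for illustration; the diff does not show the module name.
    from llama_stack.testing.inference_recorder import inference_recording

    # Live mode: storage_dir can now be omitted, since no ResponseStorage is
    # created and no clients are patched outside of record/replay.
    with inference_recording(mode="live"):
        ...  # calls go straight to the real inference clients

    # Record (and replay) modes: storage_dir is required and validated up front,
    # so a missing value raises ValueError instead of failing later in Path(None).
    with inference_recording(mode="record", storage_dir=Path("recordings")):
        ...  # patched clients persist responses under recordings/

The second changed file, the recorder's unit tests, is updated to match: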
@@ -266,7 +266,7 @@ class TestInferenceRecording:
             return real_openai_chat_response

         with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
-            with inference_recording(mode=InferenceMode.LIVE):
+            with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
                 client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

                 response = await client.chat.completions.create(
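Note that under InferenceMode.LIVE the context manager never constructs a ResponseStorage, so the placeholder storage_dir="foo" satisfies the signature without ever touching disk; with the new default it could equally be omitted.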