fix tests

This commit is contained in:
Ashwin Bharambe 2025-09-03 11:27:21 -07:00
parent 34f5da6cdc
commit 48f91f596b
2 changed files with 6 additions and 6 deletions

View file

@ -445,7 +445,7 @@ def unpatch_inference_clients():
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
AsyncCompletions.create = _original_methods["completions_create"]
AsyncEmbeddings.create = _original_methods["embeddings_create"]
AsyncModels.list = _original_methods["openai_models_list"]
AsyncModels.list = _original_methods["models_list"]
# Restore Ollama client methods if they were patched
OllamaAsyncClient.generate = _original_methods["ollama_generate"]
@ -459,12 +459,10 @@ def unpatch_inference_clients():
@contextmanager
def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, None, None]:
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
"""Context manager for inference recording/replaying."""
global _current_mode, _current_storage
storage_dir_path = Path(storage_dir)
# Store previous state
prev_mode = _current_mode
prev_storage = _current_storage
@ -473,7 +471,9 @@ def inference_recording(mode: str, storage_dir: str | Path) -> Generator[None, N
_current_mode = mode
if mode in ["record", "replay"]:
_current_storage = ResponseStorage(storage_dir_path)
if storage_dir is None:
raise ValueError("storage_dir is required for record and replay modes")
_current_storage = ResponseStorage(Path(storage_dir))
patch_inference_clients()
yield

View file

@ -266,7 +266,7 @@ class TestInferenceRecording:
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE):
with inference_recording(mode=InferenceMode.LIVE, storage_dir="foo"):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(