chore(recorder): update mocks to be closer to non-mock environment (#3442)

# What does this PR do?

the @required_args decorator in openai-python masks the async
nature of the {AsyncCompletions,chat.AsyncCompletions}.create methods;
see https://github.com/openai/openai-python/issues/996

this means two things -

 0. we cannot use inspect.iscoroutinefunction in the recorder to detect async vs non-async methods
 1. our mocks inappropriately replace create() with an identifiably async callable, unlike the real client

for (0), we replace the iscoroutinefunction check with detection of /v1/models,
which is the only non-async method we mock & record.

for (1), we could leave everything as is and assume (0) will catch
errors. to be defensive, we update the unit tests to mock below the create
methods (at the private _post transport layer), so the true openai-python
create() methods are exercised.
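For context on (0): a sync decorator that forwards to an async function hides the coroutine flag from `inspect.iscoroutinefunction`, even though the call still returns an awaitable. A minimal reproduction, using a simplified stand-in for openai-python's `@required_args` (the real decorator also takes the names of the required arguments; see the linked issue):

```python
import asyncio
import functools
import inspect


def required_args(func):  # simplified stand-in for openai-python's decorator
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # argument validation would happen here, synchronously
        return func(*args, **kwargs)  # returns the coroutine without awaiting it

    return wrapper


@required_args
async def create(prompt: str) -> str:
    return f"completion for {prompt!r}"


# the sync wrapper hides the coroutine flag from introspection...
assert not inspect.iscoroutinefunction(create)
# ...even though calling it still yields an awaitable coroutine object
coro = create("hi")
assert inspect.iscoroutine(coro)
print(asyncio.run(coro))  # completion for 'hi'
```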
Matthew Farrellee, 2025-09-15 15:25:53 -04:00, committed by GitHub
commit 01bdcce4d2 (parent b6cb817897)
2 changed files with 113 additions and 109 deletions

File 1 of 2 — the recorder implementation:

```diff
@@ -7,7 +7,6 @@
 from __future__ import annotations  # for forward references

 import hashlib
-import inspect
 import json
 import os
 from collections.abc import Generator
@@ -243,11 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     global _current_mode, _current_storage

     if _current_mode == InferenceMode.LIVE or _current_storage is None:
-        # Normal operation
-        if inspect.iscoroutinefunction(original_method):
-            return await original_method(self, *args, **kwargs)
-        else:
+        if endpoint == "/v1/models":
             return original_method(self, *args, **kwargs)
+        else:
+            return await original_method(self, *args, **kwargs)

     # Get base URL based on client type
     if client_type == "openai":
@@ -298,10 +296,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )

     elif _current_mode == InferenceMode.RECORD:
-        if inspect.iscoroutinefunction(original_method):
-            response = await original_method(self, *args, **kwargs)
-        else:
+        if endpoint == "/v1/models":
             response = original_method(self, *args, **kwargs)
+        else:
+            response = await original_method(self, *args, **kwargs)

         # we want to store the result of the iterator, not the iterator itself
         if endpoint == "/v1/models":
```

File 2 of 2 — the recorder unit tests:

```diff
@@ -155,27 +155,22 @@ class TestInferenceRecording:
     async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that recording mode captures and stores responses."""

-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_recording_mode"

-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
-
-                # Verify the response was returned correctly
-                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+
+            # Verify the response was returned correctly
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            client.chat.completions._post.assert_called_once()

         # Verify recording was stored
         storage = ResponseStorage(temp_storage_dir)
@@ -183,43 +178,38 @@ class TestInferenceRecording:
     async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that replay mode returns stored responses without making real calls."""

-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_replay_mode"

         # First, record a response
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+            client.chat.completions._post.assert_called_once()

         # Now test replay mode - should not call the original method
-        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                )
-
-                # Verify we got the recorded response
-                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
-
-                # Verify the original method was NOT called
-                mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+            )
+
+            # Verify we got the recorded response
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+
+            # Verify the original method was NOT called
+            client.chat.completions._post.assert_not_called()

     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
```
```diff
@@ -272,43 +262,50 @@ class TestInferenceRecording:
     async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
         """Test recording and replay of embeddings calls."""

-        async def mock_create(*args, **kwargs):
-            return real_embeddings_response
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+        response = await client.embeddings.create(
+            model=real_embeddings_response.model,
+            input=["Hello world", "Test embedding"],
+            encoding_format=NOT_GIVEN,
+        )
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        client.embeddings._post.assert_called_once()

         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"

         # Record
-        with patch(
-            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                    encoding_format=NOT_GIVEN,
-                    dimensions=NOT_GIVEN,
-                    user=NOT_GIVEN,
-                )
-
-                assert len(response.data) == 2
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+                encoding_format=NOT_GIVEN,
+                dimensions=NOT_GIVEN,
+                user=NOT_GIVEN,
+            )
+
+            assert len(response.data) == 2

         # Replay
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                )
-
-                # Verify we got the recorded response
-                assert len(response.data) == 2
-                assert response.data[0].embedding == [0.1, 0.2, 0.3]
-
-                # Verify original method was not called
-                mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+            )
+
+            # Verify we got the recorded response
+            assert len(response.data) == 2
+            assert response.data[0].embedding == [0.1, 0.2, 0.3]
+
+            # Verify original method was not called
+            client.embeddings._post.assert_not_called()

     async def test_completions_recording(self, temp_storage_dir):
         real_completions_response = OpenAICompletion(
```
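Since the same client-with-mocked-`_post` setup now repeats across the chat, embeddings, and completions tests, a pytest fixture could factor it out. A hypothetical sketch, not part of this PR (`make_mocked_client` and its dotted `path` argument are inventions for illustration):

```python
from operator import attrgetter
from unittest.mock import AsyncMock

import pytest
from openai import AsyncOpenAI


@pytest.fixture
def make_mocked_client():
    def _make(path: str, canned_response):
        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
        # e.g. path="chat.completions" patches client.chat.completions._post
        attrgetter(path)(client)._post = AsyncMock(return_value=canned_response)
        return client

    return _make
```

A test would then build its client with, say, `make_mocked_client("embeddings", real_embeddings_response)`.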
```diff
@@ -326,40 +323,49 @@ class TestInferenceRecording:
             ],
         )

-        async def mock_create(*args, **kwargs):
-            return real_completions_response
-
         temp_storage_dir = temp_storage_dir / "test_completions_recording"

+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.completions._post = AsyncMock(return_value=real_completions_response)
+        response = await client.completions.create(
+            model=real_completions_response.model,
+            prompt="Hello, how are you?",
+            temperature=0.7,
+            max_tokens=50,
+            user=NOT_GIVEN,
+        )
+        assert response.choices[0].text == real_completions_response.choices[0].text
+        client.completions._post.assert_called_once()
+
         # Record
-        with patch(
-            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
-
-                assert response.choices[0].text == real_completions_response.choices[0].text
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_called_once()

         # Replay
-        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                )
-
-                assert response.choices[0].text == real_completions_response.choices[0].text
-                mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+            )
+
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_not_called()

     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""
```