mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore(recorder): update mocks to be closer to non-mock environment (#3442)
# What does this PR do?

The @required_args decorator in openai-python masks the async nature of the {AsyncCompletions,chat.AsyncCompletions}.create methods; see https://github.com/openai/openai-python/issues/996. This means two things:

0. we cannot use iscoroutine in the recorder to detect async vs. non-async methods, and
1. our mocks are inappropriately introducing identifiably async behavior.

For (0), we update the iscoroutine check with detection of /v1/models, which is the only non-async function we mock & record. For (1), we could leave everything as is and assume (0) will catch errors; to be defensive, we instead update the unit tests to mock below the create methods, allowing the true openai-python create() methods to be exercised. Sketches of both points follow.
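To illustrate point (0), here is a minimal, self-contained sketch — not openai-python's actual implementation — of how a synchronous validation wrapper hides a coroutine function from introspection even though calls still return awaitables. The `_is_async_call` helper at the end is a hypothetical name for the endpoint-based detection this PR describes:

```python
import asyncio
import functools


def required_args(*names):
    """Simplified stand-in for openai-python's @required_args decorator: a
    synchronous wrapper validates kwargs, then calls the async method and
    returns its coroutine without awaiting it."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            missing = [n for n in names if n not in kwargs]
            if missing:
                raise TypeError(f"missing required arguments: {missing}")
            return func(*args, **kwargs)  # a coroutine, created but not awaited

        return wrapper

    return decorator


@required_args("model")
async def create(**kwargs):
    return {"model": kwargs["model"]}


# The sync wrapper hides the coroutine function from introspection...
assert not asyncio.iscoroutinefunction(create)
# ...even though calling it still yields an awaitable as usual.
assert asyncio.run(create(model="llama3.2:3b")) == {"model": "llama3.2:3b"}


def _is_async_call(endpoint: str) -> bool:
    # Hypothetical helper: route on the endpoint instead of iscoroutine();
    # /v1/models is the only non-async method the recorder mocks & records.
    return endpoint != "/v1/models"
```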
This commit is contained in:
parent b6cb817897
commit 01bdcce4d2

2 changed files with 113 additions and 109 deletions
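For point (1), the updated tests stub the client's transport one level below create() — the private `_post` method of the resource — so the genuine create() path, decorator included, still executes. A condensed sketch of the pattern the diff below applies to chat completions, embeddings, and completions (the function name and response fixture are illustrative):

```python
from unittest.mock import AsyncMock

from openai import NOT_GIVEN, AsyncOpenAI


async def record_with_low_level_mock(real_openai_chat_response):
    client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")

    # Stub one level *below* create(): the real create(), including its
    # @required_args wrapper, still runs, so the recorder observes the same
    # (non-)async surface it would against a live server.
    client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

    response = await client.chat.completions.create(
        model="llama3.2:3b",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        user=NOT_GIVEN,
    )

    # The request reached the stubbed transport exactly once.
    client.chat.completions._post.assert_called_once()
    return response
```

Because `_post` is a private helper of openai-python's resource classes, this couples the tests to library internals; the trade-off is that the public create() surface is exercised for real instead of being replaced by a patch.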
@@ -155,27 +155,22 @@ class TestInferenceRecording:
     async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that recording mode captures and stores responses."""

-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )

-                # Verify the response was returned correctly
-                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            # Verify the response was returned correctly
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            client.chat.completions._post.assert_called_once()

         # Verify recording was stored
         storage = ResponseStorage(temp_storage_dir)
@@ -183,43 +178,38 @@ class TestInferenceRecording:
     async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that replay mode returns stored responses without making real calls."""

-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+            client.chat.completions._post.assert_called_once()

         # Now test replay mode - should not call the original method
-        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+            )

-                # Verify we got the recorded response
-                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            # Verify we got the recorded response
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

-            # Verify the original method was NOT called
-            mock_create_patch.assert_not_called()
+            # Verify the original method was NOT called
+            client.chat.completions._post.assert_not_called()

     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
@@ -272,43 +262,50 @@ class TestInferenceRecording:
     async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
         """Test recording and replay of embeddings calls."""

-        async def mock_create(*args, **kwargs):
-            return real_embeddings_response
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+        response = await client.embeddings.create(
+            model=real_embeddings_response.model,
+            input=["Hello world", "Test embedding"],
+            encoding_format=NOT_GIVEN,
+        )
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        client.embeddings._post.assert_called_once()

         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
-        with patch(
-            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                    encoding_format=NOT_GIVEN,
-                    dimensions=NOT_GIVEN,
-                    user=NOT_GIVEN,
-                )
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+                encoding_format=NOT_GIVEN,
+                dimensions=NOT_GIVEN,
+                user=NOT_GIVEN,
+            )

-                assert len(response.data) == 2
+            assert len(response.data) == 2

         # Replay
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                )
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+            )

-                # Verify we got the recorded response
-                assert len(response.data) == 2
-                assert response.data[0].embedding == [0.1, 0.2, 0.3]
+            # Verify we got the recorded response
+            assert len(response.data) == 2
+            assert response.data[0].embedding == [0.1, 0.2, 0.3]

-            # Verify original method was not called
-            mock_create_patch.assert_not_called()
+            # Verify original method was not called
+            client.embeddings._post.assert_not_called()

     async def test_completions_recording(self, temp_storage_dir):
         real_completions_response = OpenAICompletion(
@@ -326,40 +323,49 @@ class TestInferenceRecording:
             ],
         )

-        async def mock_create(*args, **kwargs):
-            return real_completions_response
-
         temp_storage_dir = temp_storage_dir / "test_completions_recording"

+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.completions._post = AsyncMock(return_value=real_completions_response)
+        response = await client.completions.create(
+            model=real_completions_response.model,
+            prompt="Hello, how are you?",
+            temperature=0.7,
+            max_tokens=50,
+            user=NOT_GIVEN,
+        )
+        assert response.choices[0].text == real_completions_response.choices[0].text
+        client.completions._post.assert_called_once()
+
         # Record
-        with patch(
-            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)

-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )

-                assert response.choices[0].text == real_completions_response.choices[0].text
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_called_once()

         # Replay
-        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                )
-                assert response.choices[0].text == real_completions_response.choices[0].text
-            mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+            )
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_not_called()

     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""