diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py
index f899d73d3..674016fb1 100644
--- a/llama_stack/testing/inference_recorder.py
+++ b/llama_stack/testing/inference_recorder.py
@@ -7,7 +7,6 @@
 from __future__ import annotations  # for forward references

 import hashlib
-import inspect
 import json
 import os
 from collections.abc import Generator
@@ -243,11 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     global _current_mode, _current_storage

     if _current_mode == InferenceMode.LIVE or _current_storage is None:
-        # Normal operation
-        if inspect.iscoroutinefunction(original_method):
-            return await original_method(self, *args, **kwargs)
-        else:
+        if endpoint == "/v1/models":
             return original_method(self, *args, **kwargs)
+        else:
+            return await original_method(self, *args, **kwargs)

     # Get base URL based on client type
     if client_type == "openai":
@@ -298,10 +296,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         )

     elif _current_mode == InferenceMode.RECORD:
-        if inspect.iscoroutinefunction(original_method):
-            response = await original_method(self, *args, **kwargs)
-        else:
+        if endpoint == "/v1/models":
             response = original_method(self, *args, **kwargs)
+        else:
+            response = await original_method(self, *args, **kwargs)

         # we want to store the result of the iterator, not the iterator itself
         if endpoint == "/v1/models":
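Reviewer note on the recorder change above: the old code chose between awaiting and plainly calling `original_method` via `inspect.iscoroutinefunction`; the patch branches on the endpoint string instead. The apparent rationale (an assumption on my part, not stated in the patch) is that among the patched endpoints only `/v1/models` is served by a plain `def` that returns an async-iterable paginator, and introspecting a method that may already be wrapped or replaced by a test double is a less reliable signal than the endpoint itself. A minimal, self-contained sketch of the dispatch pattern, with hypothetical `FakeAsyncClient`, `list_models`, and `create_chat_completion` standing in for the real client methods:

import asyncio


class FakeAsyncClient:
    # Plain `def`, like a models-list method: it returns an async
    # iterator and must NOT be awaited at call time.
    def list_models(self):
        async def _gen():
            yield {"id": "llama3.2:3b"}

        return _gen()

    # True coroutine function, like a chat-completions create().
    async def create_chat_completion(self, **kwargs):
        return {"content": "Hello!"}


async def dispatch(client, endpoint, **kwargs):
    # Branch on the endpoint, mirroring the patched method above.
    if endpoint == "/v1/models":
        return client.list_models()
    return await client.create_chat_completion(**kwargs)


async def main():
    models = await dispatch(FakeAsyncClient(), "/v1/models")
    print([m["id"] async for m in models])  # ['llama3.2:3b']
    chat = await dispatch(FakeAsyncClient(), "/v1/chat/completions")
    print(chat["content"])  # Hello!


asyncio.run(main())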
diff --git a/tests/unit/distribution/test_inference_recordings.py b/tests/unit/distribution/test_inference_recordings.py
index 4909bbe1e..5740357c1 100644
--- a/tests/unit/distribution/test_inference_recordings.py
+++ b/tests/unit/distribution/test_inference_recordings.py
@@ -155,27 +155,22 @@ class TestInferenceRecording:

     async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that recording mode captures and stores responses."""
-
-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )

-                # Verify the response was returned correctly
-                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            # Verify the response was returned correctly
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            client.chat.completions._post.assert_called_once()

         # Verify recording was stored
         storage = ResponseStorage(temp_storage_dir)
@@ -183,43 +178,38 @@ class TestInferenceRecording:
     async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that replay mode returns stored responses without making real calls."""
-
-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_replay_mode"

         # First, record a response
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+            client.chat.completions._post.assert_called_once()

         # Now test replay mode - should not call the original method
-        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)

-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                )
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+            )

-                # Verify we got the recorded response
-                assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+            # Verify we got the recorded response
+            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

-                # Verify the original method was NOT called
-                mock_create_patch.assert_not_called()
+            # Verify the original method was NOT called
+            client.chat.completions._post.assert_not_called()

     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
@@ -272,43 +262,50 @@ class TestInferenceRecording:
     async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
         """Test recording and replay of embeddings calls."""
-        async def mock_create(*args, **kwargs):
-            return real_embeddings_response
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+        response = await client.embeddings.create(
+            model=real_embeddings_response.model,
+            input=["Hello world", "Test embedding"],
+            encoding_format=NOT_GIVEN,
+        )
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        client.embeddings._post.assert_called_once()

         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"

         # Record
-        with patch(
-            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                    encoding_format=NOT_GIVEN,
-                    dimensions=NOT_GIVEN,
-                    user=NOT_GIVEN,
-                )
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+                encoding_format=NOT_GIVEN,
+                dimensions=NOT_GIVEN,
+                user=NOT_GIVEN,
+            )

-                assert len(response.data) == 2
+            assert len(response.data) == 2

         # Replay
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)

-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                )
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+            )

-                # Verify we got the recorded response
-                assert len(response.data) == 2
-                assert response.data[0].embedding == [0.1, 0.2, 0.3]
+            # Verify we got the recorded response
+            assert len(response.data) == 2
+            assert response.data[0].embedding == [0.1, 0.2, 0.3]

-                # Verify original method was not called
-                mock_create_patch.assert_not_called()
+            # Verify original method was not called
+            client.embeddings._post.assert_not_called()

     async def test_completions_recording(self, temp_storage_dir):
         real_completions_response = OpenAICompletion(
@@ -326,40 +323,49 @@ class TestInferenceRecording:
             ],
         )

-        async def mock_create(*args, **kwargs):
-            return real_completions_response
-
         temp_storage_dir = temp_storage_dir / "test_completions_recording"

+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.completions._post = AsyncMock(return_value=real_completions_response)
+        response = await client.completions.create(
+            model=real_completions_response.model,
+            prompt="Hello, how are you?",
+            temperature=0.7,
+            max_tokens=50,
+            user=NOT_GIVEN,
+        )
+        assert response.choices[0].text == real_completions_response.choices[0].text
+        client.completions._post.assert_called_once()
+
         # Record
-        with patch(
-            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)

-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )

-                assert response.choices[0].text == real_completions_response.choices[0].text
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_called_once()

         # Replay
-        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                )
-                assert response.choices[0].text == real_completions_response.choices[0].text
-                mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+            )
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_not_called()

     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""
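Reviewer note on the test changes: patching `openai.resources.chat.completions.AsyncCompletions.create` (or its embeddings/completions counterparts) would replace the very method the recorder wraps, so the record/replay path would never execute. Stubbing the resource instance's `_post` — the transport-level helper that, as the diff relies on, `create()` ultimately calls — keeps the patched `create()` in the call path while avoiding real network I/O, and makes call counts directly assertable. A self-contained sketch of that layering, with hypothetical `Completions` and `record` stand-ins for the real client and recorder:

import asyncio
from unittest.mock import AsyncMock


class Completions:
    async def _post(self, path, body):
        raise RuntimeError("would hit the network")

    async def create(self, **body):
        # The public method delegates to the transport-level helper.
        return await self._post("/v1/chat/completions", body)


def record(method):
    # Stand-in for the recorder's class-level patch of create().
    async def wrapper(self, **kwargs):
        response = await method(self, **kwargs)
        print("recorded:", response)
        return response

    return wrapper


Completions.create = record(Completions.create)

completions = Completions()
# Stub the transport on the instance: the recording wrapper still runs,
# no network call is made, and the stub's call count stays assertable.
completions._post = AsyncMock(return_value={"content": "Hello!"})

print(asyncio.run(completions.create(model="llama3.2:3b")))
completions._post.assert_awaited_once()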