mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-02 20:40:36 +00:00
update the code with async iterator as suggested by Ben
This commit is contained in:
parent
b413c7562b
commit
3840ef7a98
3 changed files with 22 additions and 7 deletions
|
@@ -650,15 +650,23 @@ async def test_health_status_success(vllm_inference_adapter):
|
|||
This test verifies that the health method returns a HealthResponse with status OK, only
|
||||
when the connection to the vLLM server is successful.
|
||||
"""
|
||||
# Mock the client.models.list method to return successfully
|
||||
# Set vllm_inference_adapter.client to None to ensure _create_client is called
|
||||
vllm_inference_adapter.client = None
|
||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||
# Create mock client and models
|
||||
mock_client = MagicMock()
|
||||
mock_models = MagicMock()
|
||||
|
||||
# Create a mock async iterator that yields a model when iterated
|
||||
async def mock_list():
|
||||
for model in [MagicMock()]:
|
||||
yield model
|
||||
|
||||
# Set up the models.list to return our mock async iterator
|
||||
mock_models.list.return_value = mock_list()
|
||||
mock_client.models = mock_models
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
# Verify the response
|
||||
|
@@ -677,14 +685,21 @@ async def test_health_status_failure(vllm_inference_adapter):
|
|||
and an appropriate error message when the connection to the vLLM server fails.
|
||||
"""
|
||||
vllm_inference_adapter.client = None
|
||||
# Mock the client.models.list method to raise an exception
|
||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||
# Create mock client and models
|
||||
mock_client = MagicMock()
|
||||
mock_models = MagicMock()
|
||||
mock_models.list.side_effect = Exception("Connection failed")
|
||||
|
||||
# Create a mock async iterator that raises an exception when iterated
|
||||
async def mock_list():
|
||||
raise Exception("Connection failed")
|
||||
yield # Unreachable code
|
||||
|
||||
# Set up the models.list to return our mock async iterator
|
||||
mock_models.list.return_value = mock_list()
|
||||
mock_client.models = mock_models
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Call the health method
|
||||
health_response = await vllm_inference_adapter.health()
|
||||
# Verify the response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue