mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
update the code with async iterator as suggested by Ben
This commit is contained in:
parent
b413c7562b
commit
3840ef7a98
3 changed files with 22 additions and 7 deletions
|
@@ -602,7 +602,7 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
async def health(self) -> dict[str, HealthResponse]:
|
async def health(self) -> dict[str, HealthResponse]:
|
||||||
health_statuses = {}
|
health_statuses = {}
|
||||||
timeout = 0.5
|
timeout = 1 # increasing the timeout to 1 second for health checks
|
||||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||||
try:
|
try:
|
||||||
# check if the provider has a health method
|
# check if the provider has a health method
|
||||||
|
|
|
@@ -313,10 +313,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = self._create_client() if self.client is None else self.client
|
client = self._create_client() if self.client is None else self.client
|
||||||
client.models.list() # Ensure the client is initialized
|
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||||
return HealthResponse(status=HealthStatus.OK)
|
return HealthResponse(status=HealthStatus.OK)
|
||||||
except Exception as ex:
|
except Exception as e:
|
||||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}")
|
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||||
|
|
||||||
async def _get_model(self, model_id: str) -> Model:
|
async def _get_model(self, model_id: str) -> Model:
|
||||||
if not self.model_store:
|
if not self.model_store:
|
||||||
|
|
|
@@ -650,15 +650,23 @@ async def test_health_status_success(vllm_inference_adapter):
|
||||||
This test verifies that the health method returns a HealthResponse with status OK, only
|
This test verifies that the health method returns a HealthResponse with status OK, only
|
||||||
when the connection to the vLLM server is successful.
|
when the connection to the vLLM server is successful.
|
||||||
"""
|
"""
|
||||||
# Mock the client.models.list method to return successfully
|
|
||||||
# Set vllm_inference_adapter.client to None to ensure _create_client is called
|
# Set vllm_inference_adapter.client to None to ensure _create_client is called
|
||||||
vllm_inference_adapter.client = None
|
vllm_inference_adapter.client = None
|
||||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||||
# Create mock client and models
|
# Create mock client and models
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_models = MagicMock()
|
mock_models = MagicMock()
|
||||||
|
|
||||||
|
# Create a mock async iterator that yields a model when iterated
|
||||||
|
async def mock_list():
|
||||||
|
for model in [MagicMock()]:
|
||||||
|
yield model
|
||||||
|
|
||||||
|
# Set up the models.list to return our mock async iterator
|
||||||
|
mock_models.list.return_value = mock_list()
|
||||||
mock_client.models = mock_models
|
mock_client.models = mock_models
|
||||||
mock_create_client.return_value = mock_client
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
# Call the health method
|
# Call the health method
|
||||||
health_response = await vllm_inference_adapter.health()
|
health_response = await vllm_inference_adapter.health()
|
||||||
# Verify the response
|
# Verify the response
|
||||||
|
@@ -677,14 +685,21 @@ async def test_health_status_failure(vllm_inference_adapter):
|
||||||
and an appropriate error message when the connection to the vLLM server fails.
|
and an appropriate error message when the connection to the vLLM server fails.
|
||||||
"""
|
"""
|
||||||
vllm_inference_adapter.client = None
|
vllm_inference_adapter.client = None
|
||||||
# Mock the client.models.list method to raise an exception
|
|
||||||
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
|
||||||
# Create mock client and models
|
# Create mock client and models
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_models = MagicMock()
|
mock_models = MagicMock()
|
||||||
mock_models.list.side_effect = Exception("Connection failed")
|
|
||||||
|
# Create a mock async iterator that raises an exception when iterated
|
||||||
|
async def mock_list():
|
||||||
|
raise Exception("Connection failed")
|
||||||
|
yield # Unreachable code
|
||||||
|
|
||||||
|
# Set up the models.list to return our mock async iterator
|
||||||
|
mock_models.list.return_value = mock_list()
|
||||||
mock_client.models = mock_models
|
mock_client.models = mock_models
|
||||||
mock_create_client.return_value = mock_client
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
# Call the health method
|
# Call the health method
|
||||||
health_response = await vllm_inference_adapter.health()
|
health_response = await vllm_inference_adapter.health()
|
||||||
# Verify the response
|
# Verify the response
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue