update the code with aysnc iterator as suggested by Ben

This commit is contained in:
Sumit Jaiswal 2025-06-02 23:49:08 +05:30
parent b413c7562b
commit 3840ef7a98
No known key found for this signature in database
GPG key ID: A4604B39D64D6AEC
3 changed files with 22 additions and 7 deletions

View file

@ -602,7 +602,7 @@ class InferenceRouter(Inference):
async def health(self) -> dict[str, HealthResponse]: async def health(self) -> dict[str, HealthResponse]:
health_statuses = {} health_statuses = {}
timeout = 0.5 timeout = 1 # increasing the timeout to 1 second for health checks
for provider_id, impl in self.routing_table.impls_by_provider_id.items(): for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try: try:
# check if the provider has a health method # check if the provider has a health method

View file

@ -313,10 +313,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
""" """
try: try:
client = self._create_client() if self.client is None else self.client client = self._create_client() if self.client is None else self.client
client.models.list() # Ensure the client is initialized _ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK) return HealthResponse(status=HealthStatus.OK)
except Exception as ex: except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(ex)}") return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def _get_model(self, model_id: str) -> Model: async def _get_model(self, model_id: str) -> Model:
if not self.model_store: if not self.model_store:

View file

@ -650,15 +650,23 @@ async def test_health_status_success(vllm_inference_adapter):
This test verifies that the health method returns a HealthResponse with status OK, only This test verifies that the health method returns a HealthResponse with status OK, only
when the connection to the vLLM server is successful. when the connection to the vLLM server is successful.
""" """
# Mock the client.models.list method to return successfully
# Set vllm_inference_adapter.client to None to ensure _create_client is called # Set vllm_inference_adapter.client to None to ensure _create_client is called
vllm_inference_adapter.client = None vllm_inference_adapter.client = None
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client: with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
# Create mock client and models # Create mock client and models
mock_client = MagicMock() mock_client = MagicMock()
mock_models = MagicMock() mock_models = MagicMock()
# Create a mock async iterator that yields a model when iterated
async def mock_list():
for model in [MagicMock()]:
yield model
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models mock_client.models = mock_models
mock_create_client.return_value = mock_client mock_create_client.return_value = mock_client
# Call the health method # Call the health method
health_response = await vllm_inference_adapter.health() health_response = await vllm_inference_adapter.health()
# Verify the response # Verify the response
@ -677,14 +685,21 @@ async def test_health_status_failure(vllm_inference_adapter):
and an appropriate error message when the connection to the vLLM server fails. and an appropriate error message when the connection to the vLLM server fails.
""" """
vllm_inference_adapter.client = None vllm_inference_adapter.client = None
# Mock the client.models.list method to raise an exception
with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client: with patch.object(vllm_inference_adapter, "_create_client") as mock_create_client:
# Create mock client and models # Create mock client and models
mock_client = MagicMock() mock_client = MagicMock()
mock_models = MagicMock() mock_models = MagicMock()
mock_models.list.side_effect = Exception("Connection failed")
# Create a mock async iterator that raises an exception when iterated
async def mock_list():
raise Exception("Connection failed")
yield # Unreachable code
# Set up the models.list to return our mock async iterator
mock_models.list.return_value = mock_list()
mock_client.models = mock_models mock_client.models = mock_models
mock_create_client.return_value = mock_client mock_create_client.return_value = mock_client
# Call the health method # Call the health method
health_response = await vllm_inference_adapter.health() health_response = await vllm_inference_adapter.health()
# Verify the response # Verify the response