From 2fb8756fe2e1f4d157d2bcae0363bd74345ca9dc Mon Sep 17 00:00:00 2001 From: Jiayi Date: Sun, 28 Sep 2025 17:45:54 -0700 Subject: [PATCH] Fix rerank model endpoint issue --- .../remote/inference/nvidia/nvidia.py | 5 ++-- .../providers/nvidia/test_rerank_inference.py | 27 +++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index f6fca4014..15e50ff97 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -138,10 +138,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): provider_model_id = await self._get_provider_model_id(model) ranking_url = self.get_base_url() - model_obj = await self.model_store.get_model(model) - if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata: - ranking_url = model_obj.metadata["endpoint"] + if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints: + ranking_url = self._rerank_model_endpoints[provider_model_id] logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}") diff --git a/tests/unit/providers/nvidia/test_rerank_inference.py b/tests/unit/providers/nvidia/test_rerank_inference.py index f34518609..60891e496 100644 --- a/tests/unit/providers/nvidia/test_rerank_inference.py +++ b/tests/unit/providers/nvidia/test_rerank_inference.py @@ -54,7 +54,7 @@ class MockSession: return PostContext(self.response) -def create_adapter(config=None, model_metadata=None): +def create_adapter(config=None, rerank_endpoints=None): if config is None: config = NVIDIAConfig(api_key="test-key") @@ -62,11 +62,14 @@ def create_adapter(config=None, model_metadata=None): class MockModel: provider_resource_id = "test-model" - metadata = model_metadata or {} + metadata = {} adapter.model_store = AsyncMock() adapter.model_store.get_model = AsyncMock(return_value=MockModel()) + if rerank_endpoints is not None: + adapter._rerank_model_endpoints = rerank_endpoints + return adapter @@ -101,7 +104,7 @@ async def test_missing_rankings_key(): async def test_hosted_with_endpoint(): adapter = create_adapter( - config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"} + config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"} ) mock_session = MockSession(MockResponse()) @@ -115,7 +118,7 @@ async def test_hosted_with_endpoint(): async def test_hosted_without_endpoint(): adapter = create_adapter( config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com). - model_metadata={}, # No "endpoint" key + rerank_endpoints={}, # No endpoint mapping for test-model ) mock_session = MockSession(MockResponse()) @@ -126,10 +129,24 @@ async def test_hosted_without_endpoint(): assert "https://integrate.api.nvidia.com" in url +async def test_hosted_model_not_in_endpoint_mapping(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"} + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "https://integrate.api.nvidia.com" in url + assert url != "https://other.endpoint/rerank" + + async def test_self_hosted_ignores_endpoint(): adapter = create_adapter( config=NVIDIAConfig(url="http://localhost:8000", api_key=None), - model_metadata={"endpoint": "https://model.endpoint/rerank"}, # This should be ignored. + rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted. ) mock_session = MockSession(MockResponse())