Fix rerank model endpoint issue

This commit is contained in:
Jiayi 2025-09-28 17:45:54 -07:00
parent f85743dcca
commit 2fb8756fe2
2 changed files with 24 additions and 8 deletions

View file

@ -138,10 +138,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
model_obj = await self.model_store.get_model(model)
if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata:
ranking_url = model_obj.metadata["endpoint"]
if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints:
ranking_url = self._rerank_model_endpoints[provider_model_id]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")

View file

@ -54,7 +54,7 @@ class MockSession:
return PostContext(self.response)
def create_adapter(config=None, model_metadata=None):
def create_adapter(config=None, rerank_endpoints=None):
if config is None:
config = NVIDIAConfig(api_key="test-key")
@ -62,11 +62,14 @@ def create_adapter(config=None, model_metadata=None):
class MockModel:
provider_resource_id = "test-model"
metadata = model_metadata or {}
metadata = {}
adapter.model_store = AsyncMock()
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
if rerank_endpoints is not None:
adapter._rerank_model_endpoints = rerank_endpoints
return adapter
@ -101,7 +104,7 @@ async def test_missing_rankings_key():
async def test_hosted_with_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"}
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
@ -115,7 +118,7 @@ async def test_hosted_with_endpoint():
async def test_hosted_without_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
model_metadata={}, # No "endpoint" key
rerank_endpoints={}, # No endpoint mapping for test-model
)
mock_session = MockSession(MockResponse())
@ -126,10 +129,24 @@ async def test_hosted_without_endpoint():
assert "https://integrate.api.nvidia.com" in url
async def test_hosted_model_not_in_endpoint_mapping():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
assert url != "https://other.endpoint/rerank"
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
model_metadata={"endpoint": "https://model.endpoint/rerank"}, # This should be ignored.
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())