Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
Fix rerank model endpoint issue
This commit is contained in:
parent: f85743dcca
commit: 2fb8756fe2
2 changed files with 24 additions and 8 deletions
@@ -138,10 +138,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
         provider_model_id = await self._get_provider_model_id(model)

         ranking_url = self.get_base_url()
-        model_obj = await self.model_store.get_model(model)
-        if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata:
-            ranking_url = model_obj.metadata["endpoint"]
+        if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints:
+            ranking_url = self._rerank_model_endpoints[provider_model_id]

         logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")

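For context, the new selection rule reduces to the standalone sketch below. The names _is_nvidia_hosted, get_base_url(), and _rerank_model_endpoints come from the diff above; the helper function, its signature, and the example values are illustrative only, not the adapter's actual API.

def resolve_rerank_url(provider_model_id: str, is_hosted: bool,
                       rerank_model_endpoints: dict[str, str], base_url: str) -> str:
    # Hosted deployments with a per-model mapping entry use that endpoint;
    # anything else (no entry, or a self-hosted deployment) falls back to the base URL.
    if is_hosted and provider_model_id in rerank_model_endpoints:
        return rerank_model_endpoints[provider_model_id]
    return base_url


# Example values mirror the tests in this commit.
assert resolve_rerank_url(
    "test-model", True, {"test-model": "https://model.endpoint/rerank"}, "https://integrate.api.nvidia.com"
) == "https://model.endpoint/rerank"
assert resolve_rerank_url(
    "other-model", True, {}, "https://integrate.api.nvidia.com"
) == "https://integrate.api.nvidia.com"
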
@@ -54,7 +54,7 @@ class MockSession:
         return PostContext(self.response)


-def create_adapter(config=None, model_metadata=None):
+def create_adapter(config=None, rerank_endpoints=None):
     if config is None:
         config = NVIDIAConfig(api_key="test-key")

@@ -62,11 +62,14 @@ def create_adapter(config=None, model_metadata=None):

     class MockModel:
         provider_resource_id = "test-model"
-        metadata = model_metadata or {}
+        metadata = {}

     adapter.model_store = AsyncMock()
     adapter.model_store.get_model = AsyncMock(return_value=MockModel())

+    if rerank_endpoints is not None:
+        adapter._rerank_model_endpoints = rerank_endpoints
+
     return adapter

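The helper now injects the mapping directly onto the adapter as _rerank_model_endpoints instead of routing an "endpoint" key through model metadata. A minimal usage sketch, assuming the helper exactly as defined in this hunk; the values mirror the tests that follow.

# Hosted config plus a mapping entry for the requested model: the per-model endpoint is used.
adapter = create_adapter(rerank_endpoints={"test-model": "https://model.endpoint/rerank"})

# Empty mapping (or a self-hosted NVIDIAConfig): the adapter falls back to get_base_url().
adapter_default = create_adapter(rerank_endpoints={})
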
@@ -101,7 +104,7 @@ async def test_missing_rankings_key():

 async def test_hosted_with_endpoint():
     adapter = create_adapter(
-        config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"}
+        config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
     )
     mock_session = MockSession(MockResponse())

@@ -115,7 +118,7 @@ async def test_hosted_with_endpoint():
 async def test_hosted_without_endpoint():
     adapter = create_adapter(
         config=NVIDIAConfig(api_key="key"),  # This creates hosted config (integrate.api.nvidia.com).
-        model_metadata={},  # No "endpoint" key
+        rerank_endpoints={},  # No endpoint mapping for test-model
     )
     mock_session = MockSession(MockResponse())

@@ -126,10 +129,24 @@ async def test_hosted_without_endpoint():
     assert "https://integrate.api.nvidia.com" in url


+async def test_hosted_model_not_in_endpoint_mapping():
+    adapter = create_adapter(
+        config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
+    )
+    mock_session = MockSession(MockResponse())
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        await adapter.rerank(model="test-model", query="q", items=["a"])
+
+    url, _ = mock_session.post_calls[0]
+    assert "https://integrate.api.nvidia.com" in url
+    assert url != "https://other.endpoint/rerank"
+
+
 async def test_self_hosted_ignores_endpoint():
     adapter = create_adapter(
         config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
-        model_metadata={"endpoint": "https://model.endpoint/rerank"},  # This should be ignored.
+        rerank_endpoints={"test-model": "https://model.endpoint/rerank"},  # This should be ignored for self-hosted.
     )
     mock_session = MockSession(MockResponse())
