mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Fix rerank model endpoint issue
This commit is contained in:
parent
f85743dcca
commit
2fb8756fe2
2 changed files with 24 additions and 8 deletions
|
@ -138,10 +138,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
ranking_url = self.get_base_url()
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata:
|
||||
ranking_url = model_obj.metadata["endpoint"]
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints:
|
||||
ranking_url = self._rerank_model_endpoints[provider_model_id]
|
||||
|
||||
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ class MockSession:
|
|||
return PostContext(self.response)
|
||||
|
||||
|
||||
def create_adapter(config=None, model_metadata=None):
|
||||
def create_adapter(config=None, rerank_endpoints=None):
|
||||
if config is None:
|
||||
config = NVIDIAConfig(api_key="test-key")
|
||||
|
||||
|
@ -62,11 +62,14 @@ def create_adapter(config=None, model_metadata=None):
|
|||
|
||||
class MockModel:
|
||||
provider_resource_id = "test-model"
|
||||
metadata = model_metadata or {}
|
||||
metadata = {}
|
||||
|
||||
adapter.model_store = AsyncMock()
|
||||
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||
|
||||
if rerank_endpoints is not None:
|
||||
adapter._rerank_model_endpoints = rerank_endpoints
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
|
@ -101,7 +104,7 @@ async def test_missing_rankings_key():
|
|||
|
||||
async def test_hosted_with_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"}
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
|
@ -115,7 +118,7 @@ async def test_hosted_with_endpoint():
|
|||
async def test_hosted_without_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||
model_metadata={}, # No "endpoint" key
|
||||
rerank_endpoints={}, # No endpoint mapping for test-model
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
|
@ -126,10 +129,24 @@ async def test_hosted_without_endpoint():
|
|||
assert "https://integrate.api.nvidia.com" in url
|
||||
|
||||
|
||||
async def test_hosted_model_not_in_endpoint_mapping():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
assert url != "https://other.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
model_metadata={"endpoint": "https://model.endpoint/rerank"}, # This should be ignored.
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue