Add rerank models to the dynamic model list; Fix integration tests

This commit is contained in:
Jiayi 2025-09-28 14:45:16 -07:00
parent 3538477070
commit 816b68fdc7
8 changed files with 247 additions and 25 deletions

View file

@ -204,6 +204,6 @@ rerank_response = client.inference.rerank(
],
)
for i, result in enumerate(rerank_response.data):
print(f"{i+1}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
for i, result in enumerate(rerank_response):
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
```

View file

@ -20,6 +20,7 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -51,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
rerank_model_list = [
"nv-rerank-qa-mistral-4b:1",
"nvidia/nv-rerankqa-mistral-4b-v3",
"nvidia/llama-3.2-nv-rerankqa-1b-v2",
]
_rerank_model_endpoints = {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
@ -69,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
# "Consider removing the api_key from the configuration."
# )
super().__init__()
self._config = config
def get_api_key(self) -> str:
@ -87,6 +102,30 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def list_models(self) -> list[Model] | None:
"""
List available NVIDIA models by combining:
1. Dynamic models from https://integrate.api.nvidia.com/v1/models
2. Static rerank models (which use different API endpoints)
"""
models = await super().list_models() or []
existing_ids = {m.identifier for m in models}
for model_id, _ in self._rerank_model_endpoints.items():
if self.allowed_models and model_id not in self.allowed_models:
continue
if model_id not in existing_ids:
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=model_id,
identifier=model_id,
model_type=ModelType.rerank,
)
models.append(model)
self._model_cache[model_id] = model
return models
async def rerank(
self,
model: str,