Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 12:07:34 +00:00.
Add rerank models to the dynamic model list; Fix integration tests
This commit is contained in:
parent
3538477070
commit
816b68fdc7
8 changed files with 247 additions and 25 deletions
|
@ -204,6 +204,6 @@ rerank_response = client.inference.rerank(
|
|||
],
|
||||
)
|
||||
|
||||
for i, result in enumerate(rerank_response.data):
|
||||
print(f"{i+1}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
|
||||
for i, result in enumerate(rerank_response):
|
||||
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||
```
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
@ -51,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
rerank_model_list = [
|
||||
"nv-rerank-qa-mistral-4b:1",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
]
|
||||
|
||||
_rerank_model_endpoints = {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
}
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
|
@ -69,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
# "Consider removing the api_key from the configuration."
|
||||
# )
|
||||
|
||||
super().__init__()
|
||||
|
||||
self._config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
|
@ -87,6 +102,30 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
"""
|
||||
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
"""
|
||||
List available NVIDIA models by combining:
|
||||
1. Dynamic models from https://integrate.api.nvidia.com/v1/models
|
||||
2. Static rerank models (which use different API endpoints)
|
||||
"""
|
||||
models = await super().list_models() or []
|
||||
|
||||
existing_ids = {m.identifier for m in models}
|
||||
for model_id, _ in self._rerank_model_endpoints.items():
|
||||
if self.allowed_models and model_id not in self.allowed_models:
|
||||
continue
|
||||
if model_id not in existing_ids:
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=model_id,
|
||||
identifier=model_id,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
models.append(model)
|
||||
self._model_cache[model_id] = model
|
||||
|
||||
return models
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue