Add rerank API for NVIDIA Inference Provider

This commit is contained in:
Jiayi 2025-09-03 17:34:05 -07:00
parent ce77c27ff8
commit bab9d7aaea
9 changed files with 9213 additions and 1 deletions

View file

@ -41,9 +41,14 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
RerankResponse,
StopReason,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger
@ -179,6 +184,25 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def rerank(
    self,
    model: str,
    query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
    items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
    max_num_results: int | None = None,
) -> RerankResponse:
    """Dispatch a rerank request to the provider that serves ``model``.

    The model is first resolved through ``self._get_model`` with
    ``ModelType.rerank`` as the expected type. The provider implementation
    is then fetched from ``self.routing_table`` using the resolved model's
    identifier, and the request is handed to that provider's ``rerank``.

    :param model: Model name used for the lookup; the provider receives the
        resolved model's ``identifier`` rather than this raw value.
    :param query: Query to rank against; forwarded to the provider unchanged.
    :param items: Candidate items to rank; forwarded to the provider unchanged.
    :param max_num_results: Optional result limit; forwarded to the provider
        unchanged (``None`` when the caller does not supply one).
    :returns: Whatever the provider's ``rerank`` call returns.
    """
    logger.debug(f"InferenceRouter.rerank: {model}")

    # Resolve once and reuse the identifier for both the provider lookup and
    # the downstream call.
    resolved_model = await self._get_model(model, ModelType.rerank)
    resolved_id = resolved_model.identifier
    backend = await self.routing_table.get_provider_impl(resolved_id)

    response = await backend.rerank(
        model=resolved_id,
        query=query,
        items=items,
        max_num_results=max_num_results,
    )
    return response
async def openai_completion(
self,
model: str,