fix nvidia provider

This commit is contained in:
Kai Wu 2025-10-06 17:38:28 -07:00
parent 597d405e13
commit e2e8e7f399

View file

@ -5,15 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
from openai import NOT_GIVEN
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIEmbeddingData, OpenAIEmbeddingData,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from openai import NOT_GIVEN
from . import NVIDIAConfig from . import NVIDIAConfig
from .utils import _is_nvidia_hosted from .utils import _is_nvidia_hosted
@ -21,9 +21,7 @@ from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia") logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin): class NVIDIAInferenceAdapter(OpenAIMixin, ModelRegistryHelper):
config: NVIDIAConfig
""" """
NVIDIA Inference Adapter for Llama Stack. NVIDIA Inference Adapter for Llama Stack.
@ -37,12 +35,29 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
""" """
def __init__(self, config: NVIDIAConfig) -> None:
"""Initialize the NVIDIA inference adapter with configuration."""
# Initialize ModelRegistryHelper with empty model entries since NVIDIA uses dynamic model discovery
ModelRegistryHelper.__init__(
self, model_entries=[], allowed_models=config.allowed_models
)
self.config = config
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html # source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata: dict[str, dict[str, int]] = { embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192}, "nvidia/llama-3.2-nv-embedqa-1b-v2": {
"embedding_dimension": 2048,
"context_length": 8192,
},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024}, "nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096}, "nvidia/nv-embedqa-mistral-7b-v2": {
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024}, "embedding_dimension": 512,
"context_length": 4096,
},
"snowflake/arctic-embed-l": {
"embedding_dimension": 512,
"context_length": 1024,
},
} }
async def initialize(self) -> None: async def initialize(self) -> None:
@ -60,7 +75,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API key :return: The NVIDIA API key
""" """
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY" return (
self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
)
def get_base_url(self) -> str: def get_base_url(self) -> str:
""" """
@ -68,7 +85,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL :return: The NVIDIA API base URL
""" """
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url return (
f"{self.config.url}/v1"
if self.config.append_api_version
else self.config.url
)
async def openai_embeddings( async def openai_embeddings(
self, self,
@ -95,7 +116,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
response = await self.client.embeddings.create( response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model), model=await self._get_provider_model_id(model),
input=input, input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN, encoding_format=(
encoding_format if encoding_format is not None else NOT_GIVEN
),
dimensions=dimensions if dimensions is not None else NOT_GIVEN, dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN, user=user if user is not None else NOT_GIVEN,
extra_body=extra_body, extra_body=extra_body,