mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
fix nvidia provider
This commit is contained in:
parent
597d405e13
commit
e2e8e7f399
1 changed files with 34 additions and 11 deletions
|
|
@ -5,15 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
|
@ -21,9 +21,7 @@ from .utils import _is_nvidia_hosted
|
|||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin):
|
||||
config: NVIDIAConfig
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, ModelRegistryHelper):
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
|
||||
|
|
@ -37,12 +35,29 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
|
||||
"""
|
||||
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
"""Initialize the NVIDIA inference adapter with configuration."""
|
||||
# Initialize ModelRegistryHelper with empty model entries since NVIDIA uses dynamic model discovery
|
||||
ModelRegistryHelper.__init__(
|
||||
self, model_entries=[], allowed_models=config.allowed_models
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {
|
||||
"embedding_dimension": 2048,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
|
||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {
|
||||
"embedding_dimension": 512,
|
||||
"context_length": 4096,
|
||||
},
|
||||
"snowflake/arctic-embed-l": {
|
||||
"embedding_dimension": 512,
|
||||
"context_length": 1024,
|
||||
},
|
||||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -60,7 +75,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
|
||||
return (
|
||||
self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"
|
||||
)
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
|
|
@ -68,7 +85,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
return (
|
||||
f"{self.config.url}/v1"
|
||||
if self.config.append_api_version
|
||||
else self.config.url
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
|
@ -95,7 +116,9 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(model),
|
||||
input=input,
|
||||
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
|
||||
encoding_format=(
|
||||
encoding_format if encoding_format is not None else NOT_GIVEN
|
||||
),
|
||||
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
|
||||
user=user if user is not None else NOT_GIVEN,
|
||||
extra_body=extra_body,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue