# What does this PR do?

Add a /v1/inference/embeddings implementation to the NVIDIA provider.

**Open topics**

- *Asymmetric models*. NeMo Retriever includes asymmetric models, which embed differently depending on whether the input is destined for storage or for lookup against storage. The /v1/inference/embeddings API does not allow the user to indicate which type of embedding to perform. See https://github.com/meta-llama/llama-stack/issues/934
- *Truncation*. Embedding models typically have a limited context window, e.g. 1024 tokens is common, though newer models have 8k windows. When the input is larger than this window, the endpoint cannot perform its designed function. There are two options: 0. return an error so the user can reduce the input size and retry; 1. perform truncation for the user and proceed (common strategies are left or right truncation). Many users encounter context window size limits and will struggle to write reliable programs around them; this struggle is especially acute without access to the model's tokenizer. The /v1/inference/embeddings API does not allow the user to delegate truncation policy. See https://github.com/meta-llama/llama-stack/issues/933
- *Dimensions*. "Matryoshka" embedding models are available. They allow users to control the number of embedding dimensions the model produces, which is a critical feature for managing storage constraints: embeddings of 1024 dimensions that achieve 95% recall for an application may not be worth the storage cost if 512 dimensions can achieve 93% recall. Controlling embedding dimensions lets applications choose their recall/storage tradeoff. The /v1/inference/embeddings API does not allow the user to control the output dimensions. See https://github.com/meta-llama/llama-stack/issues/932

A hedged sketch of how these three knobs might look on a request appears after the checklist below.

## Test Plan

- `llama stack run llama_stack/templates/nvidia/run.yaml`
- `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3`

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Ran pre-commit to handle lint / formatting issues.
- [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [x] Wrote necessary unit or integration tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
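As context for the open topics above, here is a minimal sketch of what the three missing knobs could look like if a caller talked to an OpenAI-compatible NIM embeddings endpoint directly. The `input_type` and `truncate` fields (and their values) are assumptions about NIM-specific `extra_body` extensions, mirroring the TODO left in the adapter below; `dimensions` is the standard OpenAI client parameter and is only honored by Matryoshka-capable models. The endpoint URL and model name are placeholders taken from the Test Plan. None of these knobs are exposed by /v1/inference/embeddings today.

```python
# Illustrative only -- not part of this PR's adapter code.
from openai import OpenAI

# Assumed self-hosted NIM endpoint; adjust base_url/api_key for your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="NO KEY")

response = client.embeddings.create(
    model="baai/bge-m3",                 # model from the Test Plan, purely for illustration
    input=["What is the capital of France?"],
    dimensions=512,                      # Matryoshka-style output size (issue #932); model-dependent
    extra_body={
        "input_type": "query",           # asymmetric storage vs. lookup embedding (issue #934); assumed NIM extension
        "truncate": "END",               # delegate truncation instead of erroring (issue #933); assumed NIM extension
    },
)

print(len(response.data[0].embedding))
```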
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import warnings
from typing import AsyncIterator, List, Optional, Union

from openai import APIConnectionError, AsyncOpenAI

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    InterleavedContentItem,
    TextContentItem,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ToolChoice,
    ToolConfig,
)
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

from . import NVIDIAConfig
from .models import _MODEL_ENTRIES
from .openai_utils import (
    convert_chat_completion_request,
    convert_completion_request,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
    convert_openai_completion_choice,
    convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health

logger = logging.getLogger(__name__)


class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)

        logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")

        if _is_nvidia_hosted(config):
            if not config.api_key:
                raise RuntimeError(
                    "API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
                )
        # elif self._config.api_key:
        #
        # we don't raise this warning because a user may have deployed their
        # self-hosted NIM with an API key requirement.
        #
        #     warnings.warn(
        #         "API key is not required for self-hosted NVIDIA NIM. "
        #         "Consider removing the api_key from the configuration."
        #     )

        self._config = config
        # make sure the client lives longer than any async calls
        self._client = AsyncOpenAI(
            base_url=f"{self._config.url}/v1",
            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
            timeout=self._config.timeout,
        )

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        if content_has_media(content):
            raise NotImplementedError("Media is not supported")

        await check_health(self._config)  # this raises errors

        request = convert_completion_request(
            request=CompletionRequest(
                model=self.get_provider_model_id(model_id),
                content=content,
                sampling_params=sampling_params,
                response_format=response_format,
                stream=stream,
                logprobs=logprobs,
            ),
            n=1,
        )

        try:
            response = await self._client.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

        if stream:
            return convert_openai_completion_stream(response)
        else:
            # we pass n=1 to get only one completion
            return convert_openai_completion_choice(response.choices[0])

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
    ) -> EmbeddingsResponse:
        if any(content_has_media(content) for content in contents):
            raise NotImplementedError("Media is not supported")

        #
        # Llama Stack: contents = List[str] | List[InterleavedContentItem]
        #  ->
        # OpenAI: input = str | List[str]
        #
        # we can ignore str and always pass List[str] to OpenAI
        #
        flat_contents = [
            item.text if isinstance(item, TextContentItem) else item
            for content in contents
            for item in (content if isinstance(content, list) else [content])
        ]
        input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
        model = self.get_provider_model_id(model_id)

        response = await self._client.embeddings.create(
            model=model,
            input=input,
            # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
        )

        #
        # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
        #  ->
        # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]])
        #
        return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        if tool_prompt_format:
            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")

        await check_health(self._config)  # this raises errors

        request = await convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=self.get_provider_model_id(model_id),
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                stream=stream,
                logprobs=logprobs,
                tool_config=tool_config,
            ),
            n=1,
        )

        try:
            response = await self._client.chat.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

        if stream:
            return convert_openai_chat_completion_stream(response)
        else:
            # we pass n=1 to get only one completion
            return convert_openai_chat_completion_choice(response.choices[0])