commit f534b4c2ea
Merge branch 'main' of https://github.com/meta-llama/llama-stack into add_nemo_customizer

571 changed files with 229651 additions and 12956 deletions. Only the hunks touching the NVIDIA inference adapter (NVIDIAInferenceAdapter) are reproduced below.
@@ -6,9 +6,10 @@
 import logging
 import warnings
+from functools import lru_cache
 from typing import AsyncIterator, List, Optional, Union

-from openai import APIConnectionError, AsyncOpenAI
+from openai import APIConnectionError, AsyncOpenAI, BadRequestError

 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -40,15 +41,17 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    convert_openai_chat_completion_choice,
+    convert_openai_chat_completion_stream,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

 from . import NVIDIAConfig
-from .models import _MODEL_ENTRIES
+from .models import MODEL_ENTRIES
 from .openai_utils import (
     convert_chat_completion_request,
     convert_completion_request,
-    convert_openai_chat_completion_choice,
-    convert_openai_chat_completion_stream,
     convert_openai_completion_choice,
     convert_openai_completion_stream,
 )
@@ -60,7 +63,7 @@ logger = logging.getLogger(__name__)
 class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
-        ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)

         logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
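Note: the rename from _MODEL_ENTRIES to MODEL_ENTRIES makes the provider's model table public; ModelRegistryHelper uses it to resolve a Llama Stack model_id to the NVIDIA provider model ID. A minimal sketch of that resolution step, with a plain dict standing in for the real entry objects in .models and an assumed alias/ID pair:

    # Illustrative only: the real MODEL_ENTRIES holds llama-stack entry
    # objects, not a bare dict, and the alias/ID pair below is an assumption.
    ALIAS_TO_PROVIDER_ID = {
        "meta-llama/Llama-3.1-8B-Instruct": "meta/llama-3.1-8b-instruct",
    }

    def get_provider_model_id(model_id: str) -> str:
        # ModelRegistryHelper.get_provider_model_id performs a lookup of this shape
        return ALIAS_TO_PROVIDER_ID.get(model_id, model_id)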
@@ -80,22 +83,54 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # )

         self._config = config
-        # make sure the client lives longer than any async calls
-        self._client = AsyncOpenAI(
-            base_url=f"{self._config.url}/v1",
-            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-            timeout=self._config.timeout,
-        )
+
+    @lru_cache  # noqa: B019
+    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+        """
+        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
+        some models are hosted on different URLs. This function returns the appropriate client
+        for the given provider_model_id.
+
+        This relies on lru_cache and self._default_client to avoid creating a new client for each request
+        or for each model that is hosted on https://integrate.api.nvidia.com/v1.
+
+        :param provider_model_id: The provider model ID
+        :return: An OpenAI client
+        """
+
+        @lru_cache  # noqa: B019
+        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
+            """
+            Maintain a single OpenAI client per base_url.
+            """
+            return AsyncOpenAI(
+                base_url=base_url,
+                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+                timeout=self._config.timeout,
+            )
+
+        special_model_urls = {
+            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
+            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
+        }
+
+        base_url = f"{self._config.url}/v1"
+        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
+            base_url = special_model_urls[provider_model_id]
+
+        return _get_client_for_base_url(base_url)

     async def completion(
         self,
         model_id: str,
         content: InterleavedContent,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if content_has_media(content):
             raise NotImplementedError("Media is not supported")
@@ -103,9 +138,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         # removing this health check as NeMo customizer endpoint health check is returning 404
         # await check_health(self._config)  # this raises errors

+        provider_model_id = self.get_provider_model_id(model_id)
         request = convert_completion_request(
             request=CompletionRequest(
-                model=self.get_provider_model_id(model_id),
+                model=provider_model_id,
                 content=content,
                 sampling_params=sampling_params,
                 response_format=response_format,
@@ -116,7 +152,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )

         try:
-            response = await self._client.completions.create(**request)
+            response = await self._get_client(provider_model_id).completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@@ -144,19 +180,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         #
         # we can ignore str and always pass List[str] to OpenAI
         #
-        flat_contents = [
-            item.text if isinstance(item, TextContentItem) else item
-            for content in contents
-            for item in (content if isinstance(content, list) else [content])
-        ]
+        flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
         input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
         model = self.get_provider_model_id(model_id)

-        response = await self._client.embeddings.create(
-            model=model,
-            input=input,
-            # extra_body={"input_type": "passage"|"query"},  # TODO(mf): how to tell caller's intent?
-        )
+        extra_body = {}
+
+        if text_truncation is not None:
+            text_truncation_options = {
+                TextTruncation.none: "NONE",
+                TextTruncation.end: "END",
+                TextTruncation.start: "START",
+            }
+            extra_body["truncate"] = text_truncation_options[text_truncation]
+
+        if output_dimension is not None:
+            extra_body["dimensions"] = output_dimension
+
+        if task_type is not None:
+            task_type_options = {
+                EmbeddingTaskType.document: "passage",
+                EmbeddingTaskType.query: "query",
+            }
+            extra_body["input_type"] = task_type_options[task_type]
+
+        try:
+            response = await self._client.embeddings.create(
+                model=model,
+                input=input,
+                extra_body=extra_body,
+            )
+        except BadRequestError as e:
+            raise ValueError(f"Failed to get embeddings: {e}") from e

         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...)
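Note: the OpenAI SDK serializes extra_body fields into the request JSON alongside the standard parameters, which is how the NIM-specific truncate, dimensions, and input_type options ride along on an otherwise standard /v1/embeddings call; a BadRequestError (HTTP 400) is surfaced to the caller as a ValueError. A hedged sketch of what the adapter's call amounts to; the model name and API key are assumed examples, not from this diff:

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI(base_url="https://integrate.api.nvidia.com/v1", api_key="NO KEY")
        response = await client.embeddings.create(
            model="nvidia/nv-embedqa-e5-v5",  # assumed example model ID
            input=["The quick brown fox"],
            # extra_body keys are merged into the JSON body; NIM reads them
            # as its own truncation/dimension/intent options.
            extra_body={"truncate": "END", "input_type": "query"},
        )
        print(len(response.data[0].embedding))

    asyncio.run(main())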
@@ -169,7 +224,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@@ -178,11 +233,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
         if tool_prompt_format:
-            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
+            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)

         # await check_health(self._config)  # this raises errors

+        provider_model_id = self.get_provider_model_id(model_id)
         request = await convert_chat_completion_request(
             request=ChatCompletionRequest(
                 model=self.get_provider_model_id(model_id),
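Note: both completion and chat_completion move the sampling_params default from SamplingParams() to None with an in-body guard. A default argument is evaluated once, at function definition time, so the old signature handed the same SamplingParams instance to every call; the None-plus-guard idiom creates a fresh instance per call. (The added stacklevel=2 simply makes the warning point at the caller of chat_completion rather than at the warn() line itself.) A minimal sketch of the hazard being avoided, with an illustrative Params class:

    from typing import Optional

    class Params:
        def __init__(self) -> None:
            self.stop: list = []

    def bad(params: Params = Params()) -> Params:
        return params  # the same instance is handed out on every call

    def good(params: Optional[Params] = None) -> Params:
        if params is None:
            params = Params()  # fresh instance per call
        return params

    bad().stop.append("###")
    assert bad().stop == ["###"]  # mutation leaked across calls
    assert good().stop == []      # isolated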
@@ -198,12 +256,12 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )

         try:
-            response = await self._client.chat.completions.create(**request)
+            response = await self._get_client(provider_model_id).chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

         if stream:
-            return convert_openai_chat_completion_stream(response)
+            return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
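Note: with this change the chat path also routes through the per-model client from _get_client, and the stream converter is called with enable_incremental_tool_calls=False (presumably buffering tool-call deltas and emitting them whole; the converter itself is not shown in this diff). A hedged caller-side sketch of the two return shapes; the adapter setup, message list, and model ID are assumptions:

    # Illustrative usage only; assumes an initialized NVIDIAInferenceAdapter
    # and a list of llama-stack Message objects.
    async def demo(adapter, messages) -> None:
        # stream=False returns one complete ChatCompletionResponse
        response = await adapter.chat_completion(
            model_id="meta/llama-3.1-8b-instruct",  # assumed example ID
            messages=messages,
            stream=False,
        )
        print(response)

        # stream=True returns an async iterator of stream chunks
        stream = await adapter.chat_completion(
            model_id="meta/llama-3.1-8b-instruct",
            messages=messages,
            stream=True,
        )
        async for chunk in stream:
            print(chunk)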