Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
add NVIDIA NIM inference adapter (#355)
# What does this PR do?

This PR adds a basic inference adapter to NVIDIA NIMs.

What it does:
- chat completion API
- tool calls
- streaming
- structured output
- logprobs
- support for hosted NIMs on integrate.api.nvidia.com
- support for downloaded NIM containers

What it does not do:
- completion API
- embedding API
- vision models
- builtin tools
- have certainty that sampling strategies are correct

## Feature/Issue validation/testing/test plan

`pytest -s -v --providers inference=nvidia llama_stack/providers/tests/inference/ --env NVIDIA_API_KEY=...`

All tests should pass. There are pydantic v1 warnings.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [x] Did you write any new necessary tests?

Thanks for contributing 🎉!
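For orientation, here is a minimal, hypothetical smoke test of the new adapter outside the test suite. It is not part of this PR; it assumes an exported `NVIDIA_API_KEY`, an async entry point, and that the `model_id` below resolves through the adapter's registered model aliases.

```python
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig, get_adapter_impl


async def main() -> None:
    # NVIDIAConfig() defaults to the hosted endpoint and reads NVIDIA_API_KEY
    adapter = await get_adapter_impl(NVIDIAConfig(), {})  # second arg: provider deps, unused here
    response = await adapter.chat_completion(
        model_id="Llama3.1-8B-Instruct",  # assumed to map to meta/llama-3.1-8b-instruct
        messages=[UserMessage(content="What is the capital of France?")],
    )
    print(response.completion_message.content)


asyncio.run(main())
```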
Parent 2cfc41e13b, commit 4e6c984c26.

10 changed files with 934 additions and 10 deletions.
```diff
@@ -150,4 +150,15 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="nvidia",
+            pip_packages=[
+                "openai",
+            ],
+            module="llama_stack.providers.remote.inference.nvidia",
+            config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
+        ),
+    ),
 ]
```
`llama_stack/providers/remote/inference/nvidia/__init__.py` (new file, 22 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import NVIDIAConfig


async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
    # import dynamically so `llama stack build` does not fail due to missing dependencies
    from .nvidia import NVIDIAInferenceAdapter

    if not isinstance(config, NVIDIAConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")
    adapter = NVIDIAInferenceAdapter(config)
    return adapter


__all__ = ["get_adapter_impl", "NVIDIAConfig"]
```
`llama_stack/providers/remote/inference/nvidia/config.py` (new file, 48 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class NVIDIAConfig(BaseModel):
    """
    Configuration for the NVIDIA NIM inference endpoint.

    Attributes:
        url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
        api_key (str): The access key for the hosted NIM endpoints

    There are two ways to access NVIDIA NIMs -
     0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
     1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure

    By default the configuration is set to use the hosted APIs. This requires
    an API key which can be obtained from https://ngc.nvidia.com/.

    By default the configuration will attempt to read the NVIDIA_API_KEY environment
    variable to set the api_key. Please do not put your API key in code.

    If you are using a self-hosted NVIDIA NIM, you can set the url to the
    URL of your running NVIDIA NIM and do not need to set the api_key.
    """

    url: str = Field(
        default="https://integrate.api.nvidia.com",
        description="A base url for accessing the NVIDIA NIM",
    )
    api_key: Optional[str] = Field(
        default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
        description="The NVIDIA API key, only needed if using the hosted service",
    )
    timeout: int = Field(
        default=60,
        description="Timeout for the HTTP requests",
    )
```
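To illustrate the two access modes the docstring describes, a hedged sketch (not part of the diff) of constructing the config for each; the self-hosted URL is a placeholder.

```python
import os

from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig

# Hosted preview API: url defaults to https://integrate.api.nvidia.com and the
# key is picked up from the NVIDIA_API_KEY environment variable at construction.
hosted = NVIDIAConfig()
assert hosted.api_key == os.getenv("NVIDIA_API_KEY")

# Self-hosted NIM container: point url at your own deployment; no api_key needed.
self_hosted = NVIDIAConfig(url="http://localhost:8000")
print(hosted.url, self_hosted.timeout)  # timeout defaults to 60 seconds
```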
`llama_stack/providers/remote/inference/nvidia/nvidia.py` (new file, 183 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
    InterleavedTextMedia,
    Message,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_models.sku_list import CoreModelId
from openai import APIConnectionError, AsyncOpenAI

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    CompletionResponse,
    CompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    ResponseFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)

from . import NVIDIAConfig
from .openai_utils import (
    convert_chat_completion_request,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health

_MODEL_ALIASES = [
    build_model_alias(
        "meta/llama3-8b-instruct",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_model_alias(
        "meta/llama3-70b-instruct",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
]


class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    def __init__(self, config: NVIDIAConfig) -> None:
        # TODO(mf): filter by available models
        ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)

        print(f"Initializing NVIDIAInferenceAdapter({config.url})...")

        if _is_nvidia_hosted(config):
            if not config.api_key:
                raise RuntimeError(
                    "API key is required for hosted NVIDIA NIM. "
                    "Either provide an API key or use a self-hosted NIM."
                )
        # elif self._config.api_key:
        #
        #  we don't raise this warning because a user may have deployed their
        #  self-hosted NIM with an API key requirement.
        #
        #     warnings.warn(
        #         "API key is not required for self-hosted NVIDIA NIM. "
        #         "Consider removing the api_key from the configuration."
        #     )

        self._config = config
        # make sure the client lives longer than any async calls
        self._client = AsyncOpenAI(
            base_url=f"{self._config.url}/v1",
            api_key=self._config.api_key or "NO KEY",
            timeout=self._config.timeout,
        )

    def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        raise NotImplementedError()

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[
            ToolPromptFormat
        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[
        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
    ]:
        if tool_prompt_format:
            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")

        await check_health(self._config)  # this raises errors

        request = convert_chat_completion_request(
            request=ChatCompletionRequest(
                model=self.get_provider_model_id(model_id),
                messages=messages,
                sampling_params=sampling_params,
                response_format=response_format,
                tools=tools,
                tool_choice=tool_choice,
                tool_prompt_format=tool_prompt_format,
                stream=stream,
                logprobs=logprobs,
            ),
            n=1,
        )

        try:
            response = await self._client.chat.completions.create(**request)
        except APIConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
            ) from e

        if stream:
            return convert_openai_chat_completion_stream(response)
        else:
            # we pass n=1 to get only one completion
            return convert_openai_chat_completion_choice(response.choices[0])
```
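When `stream=True`, `chat_completion` returns an async iterator of `ChatCompletionResponseStreamChunk`. A hedged consumption sketch (illustrative only, under the same environment and model-identifier caveats as the earlier smoke test):

```python
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig, get_adapter_impl


async def stream_demo() -> None:
    adapter = await get_adapter_impl(NVIDIAConfig(), {})  # assumes NVIDIA_API_KEY is set
    stream = await adapter.chat_completion(
        model_id="Llama3.1-8B-Instruct",  # assumed registered alias
        messages=[UserMessage(content="Write a haiku about GPUs.")],
        stream=True,
    )
    async for chunk in stream:
        # events arrive as start -> progress... -> complete, mirroring
        # convert_openai_chat_completion_stream in openai_utils.py
        print(chunk.event.event_type, chunk.event.delta)


asyncio.run(stream_demo())
```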
`llama_stack/providers/remote/inference/nvidia/openai_utils.py` (new file, 581 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import warnings
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional

from llama_models.llama3.api.datatypes import (
    BuiltinTool,
    CompletionMessage,
    StopReason,
    TokenLogProbs,
    ToolCall,
    ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import (
    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
    ChatCompletionChunk as OpenAIChatCompletionChunk,
    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
)
from openai.types.chat.chat_completion import (
    Choice as OpenAIChoice,
    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call_param import (
    Function as OpenAIFunction,
)

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    JsonSchemaResponseFormat,
    Message,
    SystemMessage,
    ToolCallDelta,
    ToolCallParseStatus,
    ToolResponseMessage,
    UserMessage,
)


def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
    """
    Convert a ToolDefinition to an OpenAI API-compatible dictionary.

    ToolDefinition:
        tool_name: str | BuiltinTool
        description: Optional[str]
        parameters: Optional[Dict[str, ToolParamDefinition]]

    ToolParamDefinition:
        param_type: str
        description: Optional[str]
        required: Optional[bool]
        default: Optional[Any]


    OpenAI spec -

    {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    param_name: {
                        "type": param_type,
                        "description": description,
                        "default": default,
                    },
                    ...
                },
                "required": [param_name, ...],
            },
        },
    }
    """
    out = {
        "type": "function",
        "function": {},
    }
    function = out["function"]

    if isinstance(tool.tool_name, BuiltinTool):
        function.update(name=tool.tool_name.value)  # TODO(mf): is this sufficient?
    else:
        function.update(name=tool.tool_name)

    if tool.description:
        function.update(description=tool.description)

    if tool.parameters:
        parameters = {
            "type": "object",
            "properties": {},
        }
        properties = parameters["properties"]
        required = []
        for param_name, param in tool.parameters.items():
            properties[param_name] = {"type": param.param_type}
            if param.description:
                properties[param_name].update(description=param.description)
            if param.default:
                properties[param_name].update(default=param.default)
            if param.required:
                required.append(param_name)

        if required:
            parameters.update(required=required)

        function.update(parameters=parameters)

    return out


def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
    """
    Convert a Message to an OpenAI API-compatible dictionary.
    """
    # users can supply a dict instead of a Message object, we'll
    # convert it to a Message object and proceed with some type safety.
    if isinstance(message, dict):
        if "role" not in message:
            raise ValueError("role is required in message")
        if message["role"] == "user":
            message = UserMessage(**message)
        elif message["role"] == "assistant":
            message = CompletionMessage(**message)
        elif message["role"] == "ipython":
            message = ToolResponseMessage(**message)
        elif message["role"] == "system":
            message = SystemMessage(**message)
        else:
            raise ValueError(f"Unsupported message role: {message['role']}")

    out: OpenAIChatCompletionMessage = None
    if isinstance(message, UserMessage):
        out = OpenAIChatCompletionUserMessage(
            role="user",
            content=message.content,  # TODO(mf): handle image content
        )
    elif isinstance(message, CompletionMessage):
        out = OpenAIChatCompletionAssistantMessage(
            role="assistant",
            content=message.content,
            tool_calls=[
                OpenAIChatCompletionMessageToolCall(
                    id=tool.call_id,
                    function=OpenAIFunction(
                        name=tool.tool_name,
                        arguments=json.dumps(tool.arguments),
                    ),
                    type="function",
                )
                for tool in message.tool_calls
            ],
        )
    elif isinstance(message, ToolResponseMessage):
        out = OpenAIChatCompletionToolMessage(
            role="tool",
            tool_call_id=message.call_id,
            content=message.content,
        )
    elif isinstance(message, SystemMessage):
        out = OpenAIChatCompletionSystemMessage(
            role="system",
            content=message.content,
        )
    else:
        raise ValueError(f"Unsupported message type: {type(message)}")

    return out


def convert_chat_completion_request(
    request: ChatCompletionRequest,
    n: int = 1,
) -> dict:
    """
    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
    """
    # model -> model
    # messages -> messages
    # sampling_params  TODO(mattf): review strategy
    #  strategy=greedy -> nvext.top_k = -1, temperature = temperature
    #  strategy=top_p -> nvext.top_k = -1, top_p = top_p
    #  strategy=top_k -> nvext.top_k = top_k
    #  temperature -> temperature
    #  top_p -> top_p
    #  top_k -> nvext.top_k
    #  max_tokens -> max_tokens
    #  repetition_penalty -> nvext.repetition_penalty
    # response_format -> GrammarResponseFormat TODO(mf)
    # response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
    # tools -> tools
    # tool_choice ("auto", "required") -> tool_choice
    # tool_prompt_format -> TBD
    # stream -> stream
    # logprobs -> logprobs

    if request.response_format and not isinstance(
        request.response_format, JsonSchemaResponseFormat
    ):
        raise ValueError(
            f"Unsupported response format: {request.response_format}. "
            "Only JsonSchemaResponseFormat is supported."
        )

    nvext = {}
    payload: Dict[str, Any] = dict(
        model=request.model,
        messages=[_convert_message(message) for message in request.messages],
        stream=request.stream,
        n=n,
        extra_body=dict(nvext=nvext),
        extra_headers={
            b"User-Agent": b"llama-stack: nvidia-inference-adapter",
        },
    )

    if request.response_format:
        # server bug - setting guided_json changes the behavior of response_format resulting in an error
        # payload.update(response_format="json_object")
        nvext.update(guided_json=request.response_format.json_schema)

    if request.tools:
        payload.update(
            tools=[_convert_tooldef_to_openai_tool(tool) for tool in request.tools]
        )
        if request.tool_choice:
            payload.update(
                tool_choice=request.tool_choice.value
            )  # we cannot include tool_choice w/o tools, server will complain

    if request.logprobs:
        payload.update(logprobs=True)
        payload.update(top_logprobs=request.logprobs.top_k)

    if request.sampling_params:
        nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)

        if request.sampling_params.max_tokens:
            payload.update(max_tokens=request.sampling_params.max_tokens)

        if request.sampling_params.strategy == "top_p":
            nvext.update(top_k=-1)
            payload.update(top_p=request.sampling_params.top_p)
        elif request.sampling_params.strategy == "top_k":
            if (
                request.sampling_params.top_k != -1
                and request.sampling_params.top_k < 1
            ):
                warnings.warn("top_k must be -1 or >= 1")
            nvext.update(top_k=request.sampling_params.top_k)
        elif request.sampling_params.strategy == "greedy":
            nvext.update(top_k=-1)
            payload.update(temperature=request.sampling_params.temperature)

    return payload


def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
    """
    Convert an OpenAI chat completion finish_reason to a StopReason.

    finish_reason: Literal["stop", "length", "tool_calls", ...]
        - stop: model hit a natural stop point or a provided stop sequence
        - length: maximum number of tokens specified in the request was reached
        - tool_calls: model called a tool

    ->

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """

    # TODO(mf): are end_of_turn and end_of_message semantics correct?
    return {
        "stop": StopReason.end_of_turn,
        "length": StopReason.out_of_tokens,
        "tool_calls": StopReason.end_of_message,
    }.get(finish_reason, StopReason.end_of_turn)


def _convert_openai_tool_calls(
    tool_calls: List[OpenAIChatCompletionMessageToolCall],
) -> List[ToolCall]:
    """
    Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.

    OpenAI ChatCompletionMessageToolCall:
        id: str
        function: Function
        type: Literal["function"]

    OpenAI Function:
        arguments: str
        name: str

    ->

    ToolCall:
        call_id: str
        tool_name: str
        arguments: Dict[str, ...]
    """
    if not tool_calls:
        return []  # CompletionMessage tool_calls is not optional

    return [
        ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),
        )
        for call in tool_calls
    ]


def _convert_openai_logprobs(
    logprobs: OpenAIChoiceLogprobs,
) -> Optional[List[TokenLogProbs]]:
    """
    Convert an OpenAI ChoiceLogprobs into a list of TokenLogProbs.

    OpenAI ChoiceLogprobs:
        content: Optional[List[ChatCompletionTokenLogprob]]

    OpenAI ChatCompletionTokenLogprob:
        token: str
        logprob: float
        top_logprobs: List[TopLogprob]

    OpenAI TopLogprob:
        token: str
        logprob: float

    ->

    TokenLogProbs:
        logprobs_by_token: Dict[str, float]
         - token, logprob
    """
    if not logprobs:
        return None

    return [
        TokenLogProbs(
            logprobs_by_token={
                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
            }
        )
        for content in logprobs.content
    ]


def convert_openai_chat_completion_choice(
    choice: OpenAIChoice,
) -> ChatCompletionResponse:
    """
    Convert an OpenAI Choice into a ChatCompletionResponse.

    OpenAI Choice:
        message: ChatCompletionMessage
        finish_reason: str
        logprobs: Optional[ChoiceLogprobs]

    OpenAI ChatCompletionMessage:
        role: Literal["assistant"]
        content: Optional[str]
        tool_calls: Optional[List[ChatCompletionMessageToolCall]]

    ->

    ChatCompletionResponse:
        completion_message: CompletionMessage
        logprobs: Optional[List[TokenLogProbs]]

    CompletionMessage:
        role: Literal["assistant"]
        content: str | ImageMedia | List[str | ImageMedia]
        stop_reason: StopReason
        tool_calls: List[ToolCall]

    class StopReason(Enum):
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """
    assert (
        hasattr(choice, "message") and choice.message
    ), "error in server response: message not found"
    assert (
        hasattr(choice, "finish_reason") and choice.finish_reason
    ), "error in server response: finish_reason not found"

    return ChatCompletionResponse(
        completion_message=CompletionMessage(
            content=choice.message.content
            or "",  # CompletionMessage content is not optional
            stop_reason=_convert_openai_finish_reason(choice.finish_reason),
            tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
        ),
        logprobs=_convert_openai_logprobs(choice.logprobs),
    )


async def convert_openai_chat_completion_stream(
    stream: AsyncStream[OpenAIChatCompletionChunk],
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    """
    Convert a stream of OpenAI chat completion chunks into a stream
    of ChatCompletionResponseStreamChunk.

    OpenAI ChatCompletionChunk:
        choices: List[Choice]

    OpenAI Choice:  # different from the non-streamed Choice
        delta: ChoiceDelta
        finish_reason: Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]]
        logprobs: Optional[ChoiceLogprobs]

    OpenAI ChoiceDelta:
        content: Optional[str]
        role: Optional[Literal["system", "user", "assistant", "tool"]]
        tool_calls: Optional[List[ChoiceDeltaToolCall]]

    OpenAI ChoiceDeltaToolCall:
        index: int
        id: Optional[str]
        function: Optional[ChoiceDeltaToolCallFunction]
        type: Optional[Literal["function"]]

    OpenAI ChoiceDeltaToolCallFunction:
        name: Optional[str]
        arguments: Optional[str]

    ->

    ChatCompletionResponseStreamChunk:
        event: ChatCompletionResponseEvent

    ChatCompletionResponseEvent:
        event_type: ChatCompletionResponseEventType
        delta: Union[str, ToolCallDelta]
        logprobs: Optional[List[TokenLogProbs]]
        stop_reason: Optional[StopReason]

    ChatCompletionResponseEventType:
        start = "start"
        progress = "progress"
        complete = "complete"

    ToolCallDelta:
        content: Union[str, ToolCall]
        parse_status: ToolCallParseStatus

    ToolCall:
        call_id: str
        tool_name: str
        arguments: str

    ToolCallParseStatus:
        started = "started"
        in_progress = "in_progress"
        failure = "failure"
        success = "success"

    TokenLogProbs:
        logprobs_by_token: Dict[str, float]
         - token, logprob

    StopReason:
        end_of_turn = "end_of_turn"
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """

    # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
    def _event_type_generator() -> (
        Generator[ChatCompletionResponseEventType, None, None]
    ):
        yield ChatCompletionResponseEventType.start
        while True:
            yield ChatCompletionResponseEventType.progress

    event_type = _event_type_generator()

    # we implement NIM specific semantics, the main difference from OpenAI
    # is that tool_calls are always produced as a complete call. there is no
    # intermediate / partial tool call streamed. because of this, we can
    # simplify the logic and not concern ourselves with parse_status of
    # started/in_progress/failed. we can always assume success.
    #
    # a stream of ChatCompletionResponseStreamChunk consists of
    #  0. a start event
    #  1. zero or more progress events
    #   - each progress event has a delta
    #   - each progress event may have a stop_reason
    #   - each progress event may have logprobs
    #   - each progress event may have tool_calls
    #    if a progress event has tool_calls,
    #      it is fully formed and
    #      can be emitted with a parse_status of success
    #  2. a complete event

    stop_reason = None

    async for chunk in stream:
        choice = chunk.choices[0]  # assuming only one choice per chunk

        # we assume there's only one finish_reason in the stream
        stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason

        # if there's a tool call, emit an event for each tool in the list
        # if tool call and content, emit both separately

        if choice.delta.tool_calls:
            # the call may have content and a tool call. ChatCompletionResponseEvent
            # does not support both, so we emit the content first
            if choice.delta.content:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=next(event_type),
                        delta=choice.delta.content,
                        logprobs=_convert_openai_logprobs(choice.logprobs),
                    )
                )

            # it is possible to have parallel tool calls in stream, but
            # ChatCompletionResponseEvent only supports one per stream
            if len(choice.delta.tool_calls) > 1:
                warnings.warn(
                    "multiple tool calls found in a single delta, using the first, ignoring the rest"
                )

            # NIM only produces fully formed tool calls, so we can assume success
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
                    delta=ToolCallDelta(
                        content=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
                        parse_status=ToolCallParseStatus.success,
                    ),
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=next(event_type),
                    delta=choice.delta.content or "",  # content is not optional
                    logprobs=_convert_openai_logprobs(choice.logprobs),
                )
            )

    yield ChatCompletionResponseStreamChunk(
        event=ChatCompletionResponseEvent(
            event_type=ChatCompletionResponseEventType.complete,
            delta="",
            stop_reason=stop_reason,
        )
    )
```
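To make the ToolDefinition-to-OpenAI mapping concrete, a small hedged example (the tool and its fields are made up, and it assumes `ToolParamDefinition` lives alongside `ToolDefinition` in `llama_models.llama3.api.datatypes`, as the docstring above suggests):

```python
from llama_models.llama3.api.datatypes import ToolDefinition, ToolParamDefinition

from llama_stack.providers.remote.inference.nvidia.openai_utils import (
    _convert_tooldef_to_openai_tool,
)

# A hypothetical weather tool, using the ToolParamDefinition shape documented above.
tool = ToolDefinition(
    tool_name="get_weather",
    description="Look up the current weather for a city.",
    parameters={
        "city": ToolParamDefinition(
            param_type="string",
            description="City name",
            required=True,
        ),
    },
)

print(_convert_tooldef_to_openai_tool(tool))
# Expected shape (per the docstring): {"type": "function", "function": {"name": "get_weather",
#   "description": ..., "parameters": {"type": "object", "properties": {...}, "required": ["city"]}}}
```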
`llama_stack/providers/remote/inference/nvidia/utils.py` (new file, 54 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Tuple

import httpx

from . import NVIDIAConfig


def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
    return "integrate.api.nvidia.com" in config.url


async def _get_health(url: str) -> Tuple[bool, bool]:
    """
    Query {url}/v1/health/{live,ready} to check if the server is running and ready

    Args:
        url (str): URL of the server

    Returns:
        Tuple[bool, bool]: (is_live, is_ready)
    """
    async with httpx.AsyncClient() as client:
        live = await client.get(f"{url}/v1/health/live")
        ready = await client.get(f"{url}/v1/health/ready")
        return live.status_code == 200, ready.status_code == 200


async def check_health(config: NVIDIAConfig) -> None:
    """
    Check if the server is running and ready

    Args:
        url (str): URL of the server

    Raises:
        RuntimeError: If the server is not running or ready
    """
    if not _is_nvidia_hosted(config):
        print("Checking NVIDIA NIM health...")
        try:
            is_live, is_ready = await _get_health(config.url)
            if not is_live:
                raise ConnectionError("NVIDIA NIM is not running")
            if not is_ready:
                raise ConnectionError("NVIDIA NIM is not ready")
            # TODO(mf): should we wait for the server to be ready?
        except httpx.ConnectError as e:
            raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
```
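For a self-hosted deployment, the health helpers can also be driven standalone; a hedged sketch, assuming a NIM reachable at the placeholder URL below:

```python
import asyncio

from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.utils import check_health


async def main() -> None:
    config = NVIDIAConfig(url="http://localhost:8000")  # placeholder self-hosted NIM
    # raises ConnectionError if the NIM is unreachable, not live, or not ready;
    # hosted configs skip the check entirely
    await check_health(config)
    print("NIM is live and ready")


asyncio.run(main())
```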
```diff
@@ -6,6 +6,8 @@
 
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+
 from .fixtures import INFERENCE_FIXTURES
 
 
@@ -67,11 +69,12 @@ def pytest_generate_tests(metafunc):
             indirect=True,
         )
     if "inference_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "inference_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in INFERENCE_FIXTURES
-            ],
-            indirect=True,
-        )
+        fixtures = INFERENCE_FIXTURES
+        if filtered_stacks := get_provider_fixture_overrides(
+            metafunc.config,
+            {
+                "inference": INFERENCE_FIXTURES,
+            },
+        ):
+            fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
+        metafunc.parametrize("inference_stack", fixtures, indirect=True)
```
```diff
@@ -18,6 +18,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
 from llama_stack.providers.remote.inference.bedrock import BedrockConfig
 
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
+from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@@ -142,6 +143,19 @@ def inference_bedrock() -> ProviderFixture:
     )
 
 
+@pytest.fixture(scope="session")
+def inference_nvidia() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="nvidia",
+                provider_type="remote::nvidia",
+                config=NVIDIAConfig().model_dump(),
+            )
+        ],
+    )
+
+
 def get_model_short_name(model_name: str) -> str:
     """Convert model name to a short test identifier.
@@ -175,6 +189,7 @@ INFERENCE_FIXTURES = [
     "vllm_remote",
     "remote",
     "bedrock",
+    "nvidia",
 ]
```
```diff
@@ -198,6 +198,7 @@ class TestInference:
             "remote::fireworks",
             "remote::tgi",
             "remote::together",
+            "remote::nvidia",
         ):
             pytest.skip("Other inference providers don't support structured output yet")
 
```
```diff
@@ -361,7 +362,10 @@ class TestInference:
             for chunk in grouped[ChatCompletionResponseEventType.progress]
         )
         first = grouped[ChatCompletionResponseEventType.progress][0]
-        assert first.event.delta.parse_status == ToolCallParseStatus.started
+        if not isinstance(
+            first.event.delta.content, ToolCall
+        ):  # first chunk may contain entire call
+            assert first.event.delta.parse_status == ToolCallParseStatus.started
 
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
```
```diff
@@ -29,7 +29,6 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
     return ModelAlias(
         provider_model_id=provider_model_id,
         aliases=[
-            model_descriptor,
             get_huggingface_repo(model_descriptor),
         ],
         llama_model=model_descriptor,
@@ -57,6 +56,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
                 alias_obj.provider_model_id
             )
+            # ensure we can go from llama model to provider model id
+            self.alias_to_provider_id_map[alias_obj.llama_model] = (
+                alias_obj.provider_model_id
+            )
             self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
                 alias_obj.llama_model
             )
```