From 2dd8c4bcb6216daebeaafac282add176cb7b5047 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 22 Oct 2024 14:31:11 -0400 Subject: [PATCH] add NVIDIA NIM inference adapter --- .../adapters/inference/nvidia/__init__.py | 18 + .../adapters/inference/nvidia/_config.py | 52 +++ .../adapters/inference/nvidia/_nvidia.py | 176 ++++++++++ .../adapters/inference/nvidia/_utils.py | 328 ++++++++++++++++++ llama_stack/providers/registry/inference.py | 9 + tests/nvidia/README.md | 26 ++ tests/nvidia/integration/conftest.py | 67 ++++ tests/nvidia/integration/test_inference.py | 117 +++++++ tests/nvidia/unit/conftest.py | 73 ++++ tests/nvidia/unit/test_chat_completion.py | 203 +++++++++++ tests/nvidia/unit/test_health.py | 35 ++ tests/nvidia/unit/test_import.py | 11 + 12 files changed, 1115 insertions(+) create mode 100644 llama_stack/providers/adapters/inference/nvidia/__init__.py create mode 100644 llama_stack/providers/adapters/inference/nvidia/_config.py create mode 100644 llama_stack/providers/adapters/inference/nvidia/_nvidia.py create mode 100644 llama_stack/providers/adapters/inference/nvidia/_utils.py create mode 100644 tests/nvidia/README.md create mode 100644 tests/nvidia/integration/conftest.py create mode 100644 tests/nvidia/integration/test_inference.py create mode 100644 tests/nvidia/unit/conftest.py create mode 100644 tests/nvidia/unit/test_chat_completion.py create mode 100644 tests/nvidia/unit/test_health.py create mode 100644 tests/nvidia/unit/test_import.py diff --git a/llama_stack/providers/adapters/inference/nvidia/__init__.py b/llama_stack/providers/adapters/inference/nvidia/__init__.py new file mode 100644 index 000000000..63b466933 --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from ._config import NVIDIAConfig +from ._nvidia import NVIDIAInferenceAdapter + + +async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter: + if not isinstance(config, NVIDIAConfig): + raise RuntimeError(f"Unexpected config type: {type(config)}") + adapter = NVIDIAInferenceAdapter(config) + return adapter + + +__all__ = ["get_adapter_impl", "NVIDIAConfig"] diff --git a/llama_stack/providers/adapters/inference/nvidia/_config.py b/llama_stack/providers/adapters/inference/nvidia/_config.py new file mode 100644 index 000000000..46ac3fa5b --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/_config.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class NVIDIAConfig(BaseModel): + """ + Configuration for the NVIDIA NIM inference endpoint. + + Attributes: + base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 + api_key (str): The access key for the hosted NIM endpoints + + There are two ways to access NVIDIA NIMs - + 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com + 1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure + + By default the configuration is set to use the hosted APIs. 
This requires
+    an API key which can be obtained from https://ngc.nvidia.com/.
+
+    By default the configuration will attempt to read the NVIDIA_API_KEY environment
+    variable to set the api_key. Please do not put your API key in code.
+
+    If you are using a self-hosted NVIDIA NIM, you can set the base_url to the
+    URL of your running NVIDIA NIM and do not need to set the api_key.
+    """
+
+    base_url: str = Field(
+        default="https://integrate.api.nvidia.com",
+        description="A base url for accessing the NVIDIA NIM",
+    )
+    api_key: Optional[str] = Field(
+        default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
+        description="The NVIDIA API key, only needed if using the hosted service",
+    )
+    timeout: int = Field(
+        default=60,
+        description="Timeout for the HTTP requests",
+    )
+
+    @property
+    def is_hosted(self) -> bool:
+        return "integrate.api.nvidia.com" in self.base_url
diff --git a/llama_stack/providers/adapters/inference/nvidia/_nvidia.py b/llama_stack/providers/adapters/inference/nvidia/_nvidia.py
new file mode 100644
index 000000000..621e3e0db
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/nvidia/_nvidia.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import warnings
+from typing import Dict, List, Optional, Union
+
+import httpx
+from llama_models.datatypes import SamplingParams
+from llama_models.llama3.api.datatypes import (
+    InterleavedTextMedia,
+    Message,
+    ToolChoice,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_models.sku_list import CoreModelId
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
+    CompletionResponse,
+    CompletionResponseStreamChunk,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    ModelDef,
+    ResponseFormat,
+)
+
+from ._config import NVIDIAConfig
+from ._utils import check_health, convert_chat_completion_request, parse_completion
+
+SUPPORTED_MODELS: Dict[CoreModelId, str] = {
+    CoreModelId.llama3_8b_instruct: "meta/llama3-8b-instruct",
+    CoreModelId.llama3_70b_instruct: "meta/llama3-70b-instruct",
+    CoreModelId.llama3_1_8b_instruct: "meta/llama-3.1-8b-instruct",
+    CoreModelId.llama3_1_70b_instruct: "meta/llama-3.1-70b-instruct",
+    CoreModelId.llama3_1_405b_instruct: "meta/llama-3.1-405b-instruct",
+    # TODO(mf): how do we handle Nemotron models?
+    # "Llama3.1-Nemotron-51B-Instruct": "meta/llama-3.1-nemotron-51b-instruct",
+    CoreModelId.llama3_2_1b_instruct: "meta/llama-3.2-1b-instruct",
+    CoreModelId.llama3_2_3b_instruct: "meta/llama-3.2-3b-instruct",
+    CoreModelId.llama3_2_11b_vision_instruct: "meta/llama-3.2-11b-vision-instruct",
+    CoreModelId.llama3_2_90b_vision_instruct: "meta/llama-3.2-90b-vision-instruct",
+}
+
+
+class NVIDIAInferenceAdapter(Inference):
+    def __init__(self, config: NVIDIAConfig) -> None:
+
+        print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...")
+
+        if config.is_hosted:
+            if not config.api_key:
+                raise RuntimeError(
+                    "API key is required for hosted NVIDIA NIM. "
+                    "Either provide an API key or use a self-hosted NIM."
+                )
+        # elif self._config.api_key:
+        #
+        #  we don't raise this warning because a user may have deployed their
+        #  self-hosted NIM with an API key requirement.
+        #
+        # warnings.warn(
+        #     "API key is not required for self-hosted NVIDIA NIM. "
+        #     "Consider removing the api_key from the configuration."
+ # ) + + self._config = config + + @property + def _headers(self) -> dict: + return { + b"User-Agent": b"llama-stack: nvidia-inference-adapter", + **( + {b"Authorization": f"Bearer {self._config.api_key}"} + if self._config.api_key + else {} + ), + } + + async def list_models(self) -> List[ModelDef]: + # TODO(mf): filter by available models + return [ + ModelDef(identifier=model, llama_model=id_) + for model, id_ in SUPPORTED_MODELS.items() + ] + + def completion( + self, + model: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + raise NotImplementedError() + + async def embeddings( + self, + model: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() + + async def chat_completion( + self, + model: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ + ToolPromptFormat + ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: + if tool_prompt_format: + warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") + + if stream: + raise ValueError("Streamed completions are not supported") + + await check_health(self._config) # this raises errors + + request = ChatCompletionRequest( + model=SUPPORTED_MODELS[CoreModelId(model)], + messages=messages, + sampling_params=sampling_params, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ) + + async with httpx.AsyncClient(timeout=self._config.timeout) as client: + try: + response = await client.post( + f"{self._config.base_url}/v1/chat/completions", + headers=self._headers, + json=convert_chat_completion_request(request, n=1), + ) + except httpx.ReadTimeout as e: + raise TimeoutError( + f"Request timed out. timeout set to {self._config.timeout}. Use `llama stack configure ...` to adjust it." + ) from e + + if response.status_code == 401: + raise PermissionError( + "Unauthorized. Please check your API key, reconfigure, and try again." + ) + + if response.status_code == 400: + raise ValueError( + f"Bad request. Please check the request and try again. Detail: {response.text}" + ) + + if response.status_code == 404: + raise ValueError( + "Model not found. Please check the model name and try again." + ) + + assert ( + response.status_code == 200 + ), f"Failed to get completion: {response.text}" + + # we pass n=1 to get only one completion + return parse_completion(response.json()["choices"][0]) diff --git a/llama_stack/providers/adapters/inference/nvidia/_utils.py b/llama_stack/providers/adapters/inference/nvidia/_utils.py new file mode 100644 index 000000000..6b9075050 --- /dev/null +++ b/llama_stack/providers/adapters/inference/nvidia/_utils.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+import warnings
+from typing import Any, Dict, List, Optional, Tuple
+
+import httpx
+from llama_models.llama3.api.datatypes import (
+    CompletionMessage,
+    StopReason,
+    TokenLogProbs,
+    ToolCall,
+)
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    Message,
+)
+
+from ._config import NVIDIAConfig
+
+
+def convert_message(message: Message) -> dict:
+    """
+    Convert a Message to an OpenAI API-compatible dictionary.
+    """
+    out_dict = message.dict()
+    # Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
+    if out_dict["role"] == "ipython":
+        out_dict.update(role="tool")
+
+    if "stop_reason" in out_dict:
+        out_dict.update(stop_reason=out_dict["stop_reason"].value)
+
+    # TODO(mf): tool_calls
+
+    return out_dict
+
+
+async def _get_health(url: str) -> Tuple[bool, bool]:
+    """
+    Query {url}/v1/health/{live,ready} to check if the server is running and ready
+
+    Args:
+        url (str): URL of the server
+
+    Returns:
+        Tuple[bool, bool]: (is_live, is_ready)
+    """
+    async with httpx.AsyncClient() as client:
+        live = await client.get(f"{url}/v1/health/live")
+        ready = await client.get(f"{url}/v1/health/ready")
+        return live.status_code == 200, ready.status_code == 200
+
+
+async def check_health(config: NVIDIAConfig) -> None:
+    """
+    Check if the server is running and ready
+
+    Args:
+        config (NVIDIAConfig): Configuration of the adapter, including the NIM base_url
+
+    Raises:
+        ConnectionError: If the server is not running or ready
+    """
+    if not config.is_hosted:
+        print("Checking NVIDIA NIM health...")
+        try:
+            is_live, is_ready = await _get_health(config.base_url)
+            if not is_live:
+                raise ConnectionError("NVIDIA NIM is not running")
+            if not is_ready:
+                raise ConnectionError("NVIDIA NIM is not ready")
+            # TODO(mf): should we wait for the server to be ready?
+        except httpx.ConnectError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
+
+
+def convert_chat_completion_request(
+    request: ChatCompletionRequest,
+    n: int = 1,
+) -> dict:
+    """
+    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
+ """ + # model -> model + # messages -> messages + # sampling_params TODO(mattf): review strategy + # strategy=greedy -> nvext.top_k = -1, temperature = temperature + # strategy=top_p -> nvext.top_k = -1, top_p = top_p + # strategy=top_k -> nvext.top_k = top_k + # temperature -> temperature + # top_p -> top_p + # top_k -> nvext.top_k + # max_tokens -> max_tokens + # repetition_penalty -> nvext.repetition_penalty + # tools -> tools + # tool_choice ("auto", "required") -> tool_choice + # tool_prompt_format -> TBD + # stream -> stream + # logprobs -> logprobs + + print(f"sampling_params: {request.sampling_params}") + + payload: Dict[str, Any] = dict( + model=request.model, + messages=[convert_message(message) for message in request.messages], + stream=request.stream, + nvext={}, + n=n, + ) + nvext = payload["nvext"] + + if request.tools: + payload.update(tools=request.tools) + if request.tool_choice: + payload.update( + tool_choice=request.tool_choice.value + ) # we cannot include tool_choice w/o tools, server will complain + + if request.logprobs: + payload.update(logprobs=True) + payload.update(top_logprobs=request.logprobs.top_k) + + if request.sampling_params: + nvext.update(repetition_penalty=request.sampling_params.repetition_penalty) + + if request.sampling_params.max_tokens: + payload.update(max_tokens=request.sampling_params.max_tokens) + + if request.sampling_params.strategy == "top_p": + nvext.update(top_k=-1) + payload.update(top_p=request.sampling_params.top_p) + elif request.sampling_params.strategy == "top_k": + if ( + request.sampling_params.top_k != -1 + and request.sampling_params.top_k < 1 + ): + warnings.warn("top_k must be -1 or >= 1") + nvext.update(top_k=request.sampling_params.top_k) + elif request.sampling_params.strategy == "greedy": + nvext.update(top_k=-1) + payload.update(temperature=request.sampling_params.temperature) + + return payload + + +def _parse_content(completion: dict) -> str: + """ + Get the content from an OpenAI completion response. + + OpenAI completion response format - + { + ... + "message": {"role": "assistant", "content": ..., ...}, + ... + } + """ + # content is nullable in the OpenAI response, common for tool calls + return completion["message"]["content"] or "" + + +def _parse_stop_reason(completion: dict) -> StopReason: + """ + Get the StopReason from an OpenAI completion response. + + OpenAI completion response format - + { + ... + "finish_reason": "length" or "stop" or "tool_calls", + ... + } + """ + + # StopReason options are end_of_turn, end_of_message, out_of_tokens + # TODO(mf): is end_of_turn and end_of_message usage correct? + stop_reason = StopReason.end_of_turn + if completion["finish_reason"] == "length": + stop_reason = StopReason.out_of_tokens + elif completion["finish_reason"] == "stop": + stop_reason = StopReason.end_of_message + elif completion["finish_reason"] == "tool_calls": + stop_reason = StopReason.end_of_turn + return stop_reason + + +def _parse_tool_calls(completion: dict) -> List[ToolCall]: + """ + Get the tool calls from an OpenAI completion response. + + OpenAI completion response format - + { + ..., + "message": { + ..., + "tool_calls": [ + { + "id": X, + "type": "function", + "function": { + "name": Y, + "arguments": Z, + }, + }* + ], + }, + } + -> + [ + ToolCall(call_id=X, tool_name=Y, arguments=Z), + ... 
+ ] + """ + tool_calls = [] + if "tool_calls" in completion["message"]: + assert isinstance( + completion["message"]["tool_calls"], list + ), "error in server response: tool_calls not a list" + for call in completion["message"]["tool_calls"]: + assert "id" in call, "error in server response: tool call id not found" + assert ( + "function" in call + ), "error in server response: tool call function not found" + assert ( + "name" in call["function"] + ), "error in server response: tool call function name not found" + assert ( + "arguments" in call["function"] + ), "error in server response: tool call function arguments not found" + tool_calls.append( + ToolCall( + call_id=call["id"], + tool_name=call["function"]["name"], + arguments=call["function"]["arguments"], + ) + ) + + return tool_calls + + +def _parse_logprobs(completion: dict) -> Optional[List[TokenLogProbs]]: + """ + Extract logprobs from OpenAI as a list of TokenLogProbs. + + OpenAI completion response format - + { + ... + "logprobs": { + content: [ + { + ..., + top_logprobs: [{token: X, logprob: Y, bytes: [...]}+] + }+ + ] + }, + ... + } + -> + [ + TokenLogProbs( + logprobs_by_token={X: Y, ...} + ), + ... + ] + """ + if not (logprobs := completion.get("logprobs")): + return None + + return [ + TokenLogProbs( + logprobs_by_token={ + logprobs["token"]: logprobs["logprob"] + for logprobs in content["top_logprobs"] + } + ) + for content in logprobs["content"] + ] + + +def parse_completion( + completion: dict, +) -> ChatCompletionResponse: + """ + Parse an OpenAI completion response into a CompletionMessage and logprobs. + + OpenAI completion response format - + { + "message": { + "role": "assistant", + "content": ..., + "tool_calls": [ + { + ... + "id": ..., + "function": { + "name": ..., + "arguments": ..., + }, + }* + ]?, + "finish_reason": ..., + "logprobs": { + "content": [ + { + ..., + "top_logprobs": [{"token": ..., "logprob": ..., ...}+] + }+ + ] + }? + } + """ + assert "message" in completion, "error in server response: message not found" + assert ( + "finish_reason" in completion + ), "error in server response: finish_reason not found" + + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=_parse_content(completion), + stop_reason=_parse_stop_reason(completion), + tool_calls=_parse_tool_calls(completion), + ), + logprobs=_parse_logprobs(completion), + ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 88265f1b4..18397a08d 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -140,6 +140,15 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="nvidia", + pip_packages=[], # TODO(mf): need to specify httpx if it's already a llama-stack dep? + module="llama_stack.providers.adapters.inference.nvidia", + config_class="llama_stack.providers.adapters.inference.nvidia.NVIDIAConfig", + ), + ), InlineProviderSpec( api=Api.inference, provider_type="vllm", diff --git a/tests/nvidia/README.md b/tests/nvidia/README.md new file mode 100644 index 000000000..939a998d7 --- /dev/null +++ b/tests/nvidia/README.md @@ -0,0 +1,26 @@ +# NVIDIA tests + +## Running tests + +**Install the required dependencies:** + ```bash + pip install pytest pytest-asyncio pytest-httpx + ``` + +There are three modes for testing: + +1. 
Unit tests - this mode checks the provider functionality and does not require a network connection or running distribution + + ```bash + pytest tests/nvidia/unit + ``` + +2. Integration tests against hosted preview APIs - this mode checks the provider functionality against a live system and requires an API key. Get an API key by 0. going to https://build.nvidia.com, 1. selecting a Llama model, e.g. https://build.nvidia.com/meta/llama-3_1-8b-instruct, and 2. clicking "Get API Key". Store the API key in the `NVIDIA_API_KEY` environment variable. + + ```bash + export NVIDIA_API_KEY=... + + pytest tests/nvidia/integration --base-url https://integrate.api.nvidia.com + ``` + +3. Integration tests against a running distribution - this mode checks the provider functionality in the context of a running distribution. This involves running a local NIM, see https://build.nvidia.com/meta/llama-3_1-8b-instruct?snippet_tab=Docker, and creating & configuring a distribution to use it. Details to come. diff --git a/tests/nvidia/integration/conftest.py b/tests/nvidia/integration/conftest.py new file mode 100644 index 000000000..0691b7453 --- /dev/null +++ b/tests/nvidia/integration/conftest.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os + +import pytest + +from llama_stack.apis.inference import Inference +from llama_stack.providers.adapters.inference.nvidia import ( + get_adapter_impl, + NVIDIAConfig, +) + + +def pytest_collection_modifyitems(config, items): + """ + Skip all integration tests if NVIDIA_API_KEY is not set and --base-url + includes "https://integrate.api.nvidia.com". It is needed to access the + hosted preview APIs. + """ + if "integrate.api.nvidia.com" in config.getoption( + "--base-url" + ) and not os.environ.get("NVIDIA_API_KEY"): + skip_nvidia = pytest.mark.skip( + reason="NVIDIA_API_KEY environment variable must be set to access integrate.api.nvidia.com" + ) + for item in items: + item.add_marker(skip_nvidia) + + +def pytest_addoption(parser): + parser.addoption( + "--base-url", + action="store", + default="http://localhost:8000", + help="Base URL for the tests", + ) + parser.addoption( + "--model", + action="store", + default="Llama-3-8B-Instruct", + help="Model option for the tests", + ) + + +@pytest.fixture +def base_url(request): + return request.config.getoption("--base-url") + + +@pytest.fixture +def model(request): + return request.config.getoption("--model") + + +@pytest.fixture +def client(base_url: str) -> Inference: + return get_adapter_impl( + NVIDIAConfig( + base_url=base_url, + api_key=os.environ.get("NVIDIA_API_KEY"), + ), + {}, + ) diff --git a/tests/nvidia/integration/test_inference.py b/tests/nvidia/integration/test_inference.py new file mode 100644 index 000000000..2e7b33e4f --- /dev/null +++ b/tests/nvidia/integration/test_inference.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import itertools +from typing import Generator, List, Tuple + +import pytest + +from llama_stack.apis.inference import ( + ChatCompletionResponse, + CompletionMessage, + Inference, + Message, + StopReason, + SystemMessage, + ToolResponseMessage, + UserMessage, +) +from llama_stack.providers.adapters.inference.nvidia import ( + get_adapter_impl, + NVIDIAConfig, +) + +pytestmark = pytest.mark.asyncio + + +# TODO(mf): test bad creds raises PermissionError +# TODO(mf): test bad params, e.g. max_tokens=0 raises ValidationError +# TODO(mf): test bad model name raises ValueError +# TODO(mf): test short timeout raises TimeoutError +# TODO(mf): new file, test cli model listing +# TODO(mf): test streaming +# TODO(mf): test tool calls w/ tool_choice + + +def message_combinations( + length: int, +) -> Generator[Tuple[List[Message], str], None, None]: + """ + Generate all possible combinations of message types of given length. + """ + message_types = [ + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, + ] + for count in range(1, length + 1): + for combo in itertools.product(message_types, repeat=count): + messages = [] + for i, msg in enumerate(combo): + if msg == ToolResponseMessage: + messages.append( + msg( + content=f"Message {i + 1}", + call_id=f"call_{i + 1}", + tool_name=f"tool_{i + 1}", + ) + ) + elif msg == CompletionMessage: + messages.append( + msg(content=f"Message {i + 1}", stop_reason="end_of_message") + ) + else: + messages.append(msg(content=f"Message {i + 1}")) + id_str = "-".join([msg.__name__ for msg in combo]) + yield messages, id_str + + +@pytest.mark.parametrize("combo", message_combinations(3), ids=lambda x: x[1]) +async def test_chat_completion_messages( + client: Inference, + model: str, + combo: Tuple[List[Message], str], +): + """ + Test the chat completion endpoint with different message combinations. + """ + client = await client + messages, _ = combo + + response = await client.chat_completion( + model=model, + messages=messages, + stream=False, + ) + + assert isinstance(response, ChatCompletionResponse) + assert isinstance(response.completion_message.content, str) + # we're not testing accuracy, so no assertions on the result.completion_message.content + assert response.completion_message.role == "assistant" + assert isinstance(response.completion_message.stop_reason, StopReason) + assert response.completion_message.tool_calls == [] + + +async def test_bad_base_url( + model: str, +): + """ + Test that a bad base_url raises a ConnectionError. + """ + client = await get_adapter_impl( + NVIDIAConfig( + base_url="http://localhost:32123", + ), + {}, + ) + + with pytest.raises(ConnectionError): + await client.chat_completion( + model=model, + messages=[UserMessage(content="Hello")], + stream=False, + ) diff --git a/tests/nvidia/unit/conftest.py b/tests/nvidia/unit/conftest.py new file mode 100644 index 000000000..cdc0c50d7 --- /dev/null +++ b/tests/nvidia/unit/conftest.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os + +import pytest + +from llama_stack.apis.inference import Inference +from llama_stack.providers.adapters.inference.nvidia import ( + get_adapter_impl, + NVIDIAConfig, +) +from pytest_httpx import HTTPXMock + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def base_url(): + return "http://endpoint.mocked" + + +@pytest.fixture +def client(base_url: str) -> Inference: + return get_adapter_impl( + NVIDIAConfig( + base_url=base_url, + api_key=os.environ.get("NVIDIA_API_KEY"), + ), + {}, + ) + + +@pytest.fixture +def mock_health( + httpx_mock: HTTPXMock, + base_url: str, +) -> HTTPXMock: + for path in [ + "/v1/health/live", + "/v1/health/ready", + ]: + httpx_mock.add_response( + url=f"{base_url}{path}", + status_code=200, + ) + return httpx_mock + + +@pytest.fixture +def mock_chat_completion(httpx_mock: HTTPXMock, base_url: str) -> HTTPXMock: + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "created": 1234567890, + "object": "chat.completion", + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "WORKED"}, + "finish_reason": "length", + } + ], + }, + status_code=200, + ) + + return httpx_mock diff --git a/tests/nvidia/unit/test_chat_completion.py b/tests/nvidia/unit/test_chat_completion.py new file mode 100644 index 000000000..1608ad39a --- /dev/null +++ b/tests/nvidia/unit/test_chat_completion.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from llama_models.llama3.api.datatypes import TokenLogProbs, ToolCall + +from llama_stack.apis.inference import Inference +from pytest_httpx import HTTPXMock + +pytestmark = pytest.mark.asyncio + + +async def test_content( + mock_health: HTTPXMock, + httpx_mock: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that response content makes it through to the completion message. + """ + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "created": 1234567890, + "object": "chat.completion", + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "RESPONSE"}, + "finish_reason": "length", + } + ], + }, + status_code=200, + ) + + client = await client + + response = await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "BOGUS"}], + stream=False, + ) + assert response.completion_message.content == "RESPONSE" + + +async def test_logprobs( + mock_health: HTTPXMock, + httpx_mock: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that logprobs are parsed correctly. 
+ """ + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "object": "chat.completion", + "created": 1234567890, + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello there"}, + "logprobs": { + "content": [ + { + "token": "Hello", + "logprob": -0.1, + "bytes": [72, 101, 108, 108, 111], + "top_logprobs": [ + {"token": "Hello", "logprob": -0.1}, + {"token": "Hi", "logprob": -1.2}, + {"token": "Greetings", "logprob": -2.1}, + ], + }, + { + "token": "there", + "logprob": -0.2, + "bytes": [116, 104, 101, 114, 101], + "top_logprobs": [ + {"token": "there", "logprob": -0.2}, + {"token": "here", "logprob": -1.3}, + {"token": "where", "logprob": -2.2}, + ], + }, + ] + }, + "finish_reason": "length", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }, + status_code=200, + ) + + client = await client + + response = await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "Hello"}], + logprobs={"top_k": 3}, + stream=False, + ) + + assert response.logprobs == [ + TokenLogProbs( + logprobs_by_token={ + "Hello": -0.1, + "Hi": -1.2, + "Greetings": -2.1, + } + ), + TokenLogProbs( + logprobs_by_token={ + "there": -0.2, + "here": -1.3, + "where": -2.2, + } + ), + ] + + +async def test_tools( + mock_health: HTTPXMock, + httpx_mock: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that tools are passed correctly. + """ + httpx_mock.add_response( + url=f"{base_url}/v1/chat/completions", + json={ + "id": "mock-id", + "object": "chat.completion", + "created": 1234567890, + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tool-id", + "type": "function", + "function": { + "name": "magic", + "arguments": {"input": 3}, + }, + }, + { + "id": "tool-id!", + "type": "function", + "function": { + "name": "magic!", + "arguments": {"input": 42}, + }, + }, + ], + }, + "logprobs": None, + "finish_reason": "tool_calls", + } + ], + }, + status_code=200, + ) + + client = await client + + response = await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "Hello"}], + stream=False, + ) + + assert response.completion_message.tool_calls == [ + ToolCall( + call_id="tool-id", + tool_name="magic", + arguments={"input": 3}, + ), + ToolCall( + call_id="tool-id!", + tool_name="magic!", + arguments={"input": 42}, + ), + ] + + +# TODO(mf): test stream=True for each case diff --git a/tests/nvidia/unit/test_health.py b/tests/nvidia/unit/test_health.py new file mode 100644 index 000000000..0e3d146a3 --- /dev/null +++ b/tests/nvidia/unit/test_health.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.inference import Inference +from pytest_httpx import HTTPXMock + +pytestmark = pytest.mark.asyncio + + +async def test_chat_completion( + mock_health: HTTPXMock, + mock_chat_completion: HTTPXMock, + client: Inference, + base_url: str, +) -> None: + """ + Test that health endpoints are checked when chat_completion is called. 
+ """ + client = await client + + await client.chat_completion( + model="Llama-3-8B-Instruct", + messages=[{"role": "user", "content": "BOGUS"}], + stream=False, + ) + + +# TODO(mf): test stream=True for each case +# TODO(mf): test completion +# TODO(mf): test embedding diff --git a/tests/nvidia/unit/test_import.py b/tests/nvidia/unit/test_import.py new file mode 100644 index 000000000..87e667239 --- /dev/null +++ b/tests/nvidia/unit/test_import.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.adapters.inference.nvidia import __all__ + + +def test_import(): + assert set(__all__) == {"get_adapter_impl", "NVIDIAConfig"}