add NVIDIA NIM inference adapter

Matthew Farrellee 2024-10-22 14:31:11 -04:00
parent ac93dd89cf
commit 2dd8c4bcb6
12 changed files with 1115 additions and 0 deletions

View file

@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from ._config import NVIDIAConfig
from ._nvidia import NVIDIAInferenceAdapter
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> NVIDIAInferenceAdapter:
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "NVIDIAConfig"]

View file

@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class NVIDIAConfig(BaseModel):
"""
Configuration for the NVIDIA NIM inference endpoint.
Attributes:
base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
By default the configuration is set to use the hosted APIs. This requires
an API key which can be obtained from https://ngc.nvidia.com/.
By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code.
If you are using a self-hosted NVIDIA NIM, you can set the base_url to the
URL of your running NVIDIA NIM and do not need to set the api_key.
"""
base_url: str = Field(
default="https://integrate.api.nvidia.com",
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
@property
def is_hosted(self) -> bool:
return "integrate.api.nvidia.com" in self.base_url

View file

@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from typing import Dict, List, Optional, Union
import httpx
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_models.sku_list import CoreModelId
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
ModelDef,
ResponseFormat,
)
from ._config import NVIDIAConfig
from ._utils import check_health, convert_chat_completion_request, parse_completion
SUPPORTED_MODELS: Dict[CoreModelId, str] = {
CoreModelId.llama3_8b_instruct: "meta/llama3-8b-instruct",
CoreModelId.llama3_70b_instruct: "meta/llama3-70b-instruct",
CoreModelId.llama3_1_8b_instruct: "meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_70b_instruct: "meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_405b_instruct: "meta/llama-3.1-405b-instruct",
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct": "meta/llama-3.1-nemotron-51b-instruct",
CoreModelId.llama3_2_1b_instruct: "meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_3b_instruct: "meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_11b_vision_instruct: "meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct: "meta/llama-3.2-90b-vision-instruct",
}
class NVIDIAInferenceAdapter(Inference):
def __init__(self, config: NVIDIAConfig) -> None:
print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...")
if config.is_hosted:
if not config.api_key:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. "
"Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
@property
def _headers(self) -> dict:
return {
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
**(
{b"Authorization": f"Bearer {self._config.api_key}"}
if self._config.api_key
else {}
),
}
async def list_models(self) -> List[ModelDef]:
# TODO(mf): filter by available models
return [
ModelDef(identifier=model, llama_model=id_)
for model, id_ in SUPPORTED_MODELS.items()
]
def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
async def embeddings(
self,
model: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
if stream:
raise ValueError("Streamed completions are not supported")
await check_health(self._config) # this raises errors
request = ChatCompletionRequest(
model=SUPPORTED_MODELS[CoreModelId(model)],
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
async with httpx.AsyncClient(timeout=self._config.timeout) as client:
try:
response = await client.post(
f"{self._config.base_url}/v1/chat/completions",
headers=self._headers,
json=convert_chat_completion_request(request, n=1),
)
except httpx.ReadTimeout as e:
raise TimeoutError(
f"Request timed out. timeout set to {self._config.timeout}. Use `llama stack configure ...` to adjust it."
) from e
if response.status_code == 401:
raise PermissionError(
"Unauthorized. Please check your API key, reconfigure, and try again."
)
if response.status_code == 400:
raise ValueError(
f"Bad request. Please check the request and try again. Detail: {response.text}"
)
if response.status_code == 404:
raise ValueError(
"Model not found. Please check the model name and try again."
)
assert (
response.status_code == 200
), f"Failed to get completion: {response.text}"
# we pass n=1 to get only one completion
return parse_completion(response.json()["choices"][0])
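
A minimal sketch of driving the adapter end to end, mirroring the integration tests added later in this commit; the model name `Llama-3-8B-Instruct` is the tests' default and assumes a matching `CoreModelId` entry in `SUPPORTED_MODELS`:

```python
import asyncio

from llama_stack.apis.inference import UserMessage
from llama_stack.providers.adapters.inference.nvidia import (
    NVIDIAConfig,
    get_adapter_impl,
)


async def main() -> None:
    # Hosted by default, so NVIDIA_API_KEY must be set in the environment.
    adapter = await get_adapter_impl(NVIDIAConfig(), {})
    response = await adapter.chat_completion(
        model="Llama-3-8B-Instruct",
        messages=[UserMessage(content="Hello")],
        stream=False,  # streaming is not supported by this adapter
    )
    print(response.completion_message.content)


asyncio.run(main())
```
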

View file

@@ -0,0 +1,328 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from typing import Any, Dict, List, Optional, Tuple
import httpx
from llama_models.llama3.api.datatypes import (
CompletionMessage,
StopReason,
TokenLogProbs,
ToolCall,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Message,
)
from ._config import NVIDIAConfig
def convert_message(message: Message) -> dict:
"""
Convert a Message to an OpenAI API-compatible dictionary.
"""
out_dict = message.dict()
# Llama Stack uses role="ipython" for tool call messages, OpenAI uses "tool"
if out_dict["role"] == "ipython":
out_dict.update(role="tool")
if "stop_reason" in out_dict:
out_dict.update(stop_reason=out_dict["stop_reason"].value)
# TODO(mf): tool_calls
return out_dict
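
For illustration, a small sketch of the role remapping for a tool response message; it assumes `ToolResponseMessage` defaults its role to `ipython`, as the comment above implies:

```python
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.adapters.inference.nvidia._utils import convert_message

msg = ToolResponseMessage(call_id="call_1", tool_name="tool_1", content="42")
# The Llama Stack "ipython" role is rewritten to the OpenAI "tool" role.
assert convert_message(msg)["role"] == "tool"
```
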
async def _get_health(url: str) -> Tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NVIDIAConfig) -> None:
"""
Check if the server is running and ready
Args:
config (NVIDIAConfig): The NVIDIA NIM configuration, including base_url
Raises:
RuntimeError: If the server is not running or ready
"""
if not config.is_hosted:
print("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.base_url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e
def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# messages -> messages
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# tools -> tools
# tool_choice ("auto", "required") -> tool_choice
# tool_prompt_format -> TBD
# stream -> stream
# logprobs -> logprobs
print(f"sampling_params: {request.sampling_params}")
payload: Dict[str, Any] = dict(
model=request.model,
messages=[convert_message(message) for message in request.messages],
stream=request.stream,
nvext={},
n=n,
)
nvext = payload["nvext"]
if request.tools:
payload.update(tools=request.tools)
if request.tool_choice:
payload.update(
tool_choice=request.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:
payload.update(logprobs=True)
payload.update(top_logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
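
To make the field mapping concrete, a sketch of the payload produced for a `top_p` request; the exact `SamplingParams` defaults (e.g. `repetition_penalty`) are assumptions, so the commented output is approximate:

```python
from llama_models.datatypes import SamplingParams
from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.adapters.inference.nvidia._utils import (
    convert_chat_completion_request,
)

request = ChatCompletionRequest(
    model="meta/llama-3.1-8b-instruct",
    messages=[UserMessage(content="Hello")],
    sampling_params=SamplingParams(strategy="top_p", top_p=0.9, max_tokens=64),
    stream=False,
)
payload = convert_chat_completion_request(request, n=1)
# Roughly:
# {
#   "model": "meta/llama-3.1-8b-instruct",
#   "messages": [{"role": "user", "content": "Hello", ...}],
#   "stream": False,
#   "n": 1,
#   "max_tokens": 64,
#   "top_p": 0.9,
#   "nvext": {"repetition_penalty": ..., "top_k": -1},
# }
```
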
def _parse_content(completion: dict) -> str:
"""
Get the content from an OpenAI completion response.
OpenAI completion response format -
{
...
"message": {"role": "assistant", "content": ..., ...},
...
}
"""
# content is nullable in the OpenAI response, common for tool calls
return completion["message"]["content"] or ""
def _parse_stop_reason(completion: dict) -> StopReason:
"""
Get the StopReason from an OpenAI completion response.
OpenAI completion response format -
{
...
"finish_reason": "length" or "stop" or "tool_calls",
...
}
"""
# StopReason options are end_of_turn, end_of_message, out_of_tokens
# TODO(mf): is end_of_turn and end_of_message usage correct?
stop_reason = StopReason.end_of_turn
if completion["finish_reason"] == "length":
stop_reason = StopReason.out_of_tokens
elif completion["finish_reason"] == "stop":
stop_reason = StopReason.end_of_message
elif completion["finish_reason"] == "tool_calls":
stop_reason = StopReason.end_of_turn
return stop_reason
def _parse_tool_calls(completion: dict) -> List[ToolCall]:
"""
Get the tool calls from an OpenAI completion response.
OpenAI completion response format -
{
...,
"message": {
...,
"tool_calls": [
{
"id": X,
"type": "function",
"function": {
"name": Y,
"arguments": Z,
},
}*
],
},
}
->
[
ToolCall(call_id=X, tool_name=Y, arguments=Z),
...
]
"""
tool_calls = []
if "tool_calls" in completion["message"]:
assert isinstance(
completion["message"]["tool_calls"], list
), "error in server response: tool_calls not a list"
for call in completion["message"]["tool_calls"]:
assert "id" in call, "error in server response: tool call id not found"
assert (
"function" in call
), "error in server response: tool call function not found"
assert (
"name" in call["function"]
), "error in server response: tool call function name not found"
assert (
"arguments" in call["function"]
), "error in server response: tool call function arguments not found"
tool_calls.append(
ToolCall(
call_id=call["id"],
tool_name=call["function"]["name"],
arguments=call["function"]["arguments"],
)
)
return tool_calls
def _parse_logprobs(completion: dict) -> Optional[List[TokenLogProbs]]:
"""
Extract logprobs from OpenAI as a list of TokenLogProbs.
OpenAI completion response format -
{
...
"logprobs": {
content: [
{
...,
top_logprobs: [{token: X, logprob: Y, bytes: [...]}+]
}+
]
},
...
}
->
[
TokenLogProbs(
logprobs_by_token={X: Y, ...}
),
...
]
"""
if not (logprobs := completion.get("logprobs")):
return None
return [
TokenLogProbs(
logprobs_by_token={
logprobs["token"]: logprobs["logprob"]
for logprobs in content["top_logprobs"]
}
)
for content in logprobs["content"]
]
def parse_completion(
completion: dict,
) -> ChatCompletionResponse:
"""
Parse an OpenAI chat completion choice into a ChatCompletionResponse, i.e. a CompletionMessage plus optional logprobs.
OpenAI completion response format -
{
"message": {
"role": "assistant",
"content": ...,
"tool_calls": [
{
...
"id": ...,
"function": {
"name": ...,
"arguments": ...,
},
}*
]?,
"finish_reason": ...,
"logprobs": {
"content": [
{
...,
"top_logprobs": [{"token": ..., "logprob": ..., ...}+]
}+
]
}?
}
"""
assert "message" in completion, "error in server response: message not found"
assert (
"finish_reason" in completion
), "error in server response: finish_reason not found"
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=_parse_content(completion),
stop_reason=_parse_stop_reason(completion),
tool_calls=_parse_tool_calls(completion),
),
logprobs=_parse_logprobs(completion),
)
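
Finally, a small sketch of the overall parsing path, reusing the mock choice shape from the unit tests below:

```python
from llama_models.llama3.api.datatypes import StopReason
from llama_stack.providers.adapters.inference.nvidia._utils import parse_completion

choice = {
    "index": 0,
    "message": {"role": "assistant", "content": "WORKED"},
    "finish_reason": "length",
}
response = parse_completion(choice)
assert response.completion_message.content == "WORKED"
# "length" maps to StopReason.out_of_tokens in _parse_stop_reason above.
assert response.completion_message.stop_reason == StopReason.out_of_tokens
assert response.completion_message.tool_calls == []
assert not response.logprobs  # no logprobs in the choice
```
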

View file

@@ -140,6 +140,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.databricks.DatabricksImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[],  # TODO(mf): do we need to specify httpx if it's already a llama-stack dep?
module="llama_stack.providers.adapters.inference.nvidia",
config_class="llama_stack.providers.adapters.inference.nvidia.NVIDIAConfig",
),
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",

tests/nvidia/README.md
View file

@@ -0,0 +1,26 @@
# NVIDIA tests
## Running tests
**Install the required dependencies:**
```bash
pip install pytest pytest-asyncio pytest-httpx
```
There are three modes for testing:
1. Unit tests - this mode checks the provider functionality and does not require a network connection or running distribution
```bash
pytest tests/nvidia/unit
```
2. Integration tests against hosted preview APIs - this mode checks the provider functionality against a live system and requires an API key. To get an API key, go to https://build.nvidia.com, select a Llama model (e.g. https://build.nvidia.com/meta/llama-3_1-8b-instruct), and click "Get API Key". Store the key in the `NVIDIA_API_KEY` environment variable.
```bash
export NVIDIA_API_KEY=...
pytest tests/nvidia/integration --base-url https://integrate.api.nvidia.com
```
3. Integration tests against a running distribution - this mode checks the provider functionality in the context of a running distribution. This involves running a local NIM (see https://build.nvidia.com/meta/llama-3_1-8b-instruct?snippet_tab=Docker) and creating & configuring a distribution to use it. Details to come.

View file

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
from llama_stack.apis.inference import Inference
from llama_stack.providers.adapters.inference.nvidia import (
get_adapter_impl,
NVIDIAConfig,
)
def pytest_collection_modifyitems(config, items):
"""
Skip all integration tests if NVIDIA_API_KEY is not set and --base-url
includes "https://integrate.api.nvidia.com". It is needed to access the
hosted preview APIs.
"""
if "integrate.api.nvidia.com" in config.getoption(
"--base-url"
) and not os.environ.get("NVIDIA_API_KEY"):
skip_nvidia = pytest.mark.skip(
reason="NVIDIA_API_KEY environment variable must be set to access integrate.api.nvidia.com"
)
for item in items:
item.add_marker(skip_nvidia)
def pytest_addoption(parser):
parser.addoption(
"--base-url",
action="store",
default="http://localhost:8000",
help="Base URL for the tests",
)
parser.addoption(
"--model",
action="store",
default="Llama-3-8B-Instruct",
help="Model option for the tests",
)
@pytest.fixture
def base_url(request):
return request.config.getoption("--base-url")
@pytest.fixture
def model(request):
return request.config.getoption("--model")
@pytest.fixture
def client(base_url: str) -> Inference:
return get_adapter_impl(
NVIDIAConfig(
base_url=base_url,
api_key=os.environ.get("NVIDIA_API_KEY"),
),
{},
)

View file

@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import itertools
from typing import Generator, List, Tuple
import pytest
from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionMessage,
Inference,
Message,
StopReason,
SystemMessage,
ToolResponseMessage,
UserMessage,
)
from llama_stack.providers.adapters.inference.nvidia import (
get_adapter_impl,
NVIDIAConfig,
)
pytestmark = pytest.mark.asyncio
# TODO(mf): test bad creds raises PermissionError
# TODO(mf): test bad params, e.g. max_tokens=0 raises ValidationError
# TODO(mf): test bad model name raises ValueError
# TODO(mf): test short timeout raises TimeoutError
# TODO(mf): new file, test cli model listing
# TODO(mf): test streaming
# TODO(mf): test tool calls w/ tool_choice
def message_combinations(
length: int,
) -> Generator[Tuple[List[Message], str], None, None]:
"""
Generate all possible combinations of message types of given length.
"""
message_types = [
UserMessage,
SystemMessage,
ToolResponseMessage,
CompletionMessage,
]
for count in range(1, length + 1):
for combo in itertools.product(message_types, repeat=count):
messages = []
for i, msg in enumerate(combo):
if msg == ToolResponseMessage:
messages.append(
msg(
content=f"Message {i + 1}",
call_id=f"call_{i + 1}",
tool_name=f"tool_{i + 1}",
)
)
elif msg == CompletionMessage:
messages.append(
msg(content=f"Message {i + 1}", stop_reason="end_of_message")
)
else:
messages.append(msg(content=f"Message {i + 1}"))
id_str = "-".join([msg.__name__ for msg in combo])
yield messages, id_str
@pytest.mark.parametrize("combo", message_combinations(3), ids=lambda x: x[1])
async def test_chat_completion_messages(
client: Inference,
model: str,
combo: Tuple[List[Message], str],
):
"""
Test the chat completion endpoint with different message combinations.
"""
client = await client
messages, _ = combo
response = await client.chat_completion(
model=model,
messages=messages,
stream=False,
)
assert isinstance(response, ChatCompletionResponse)
assert isinstance(response.completion_message.content, str)
# we're not testing accuracy, so no assertions on the result.completion_message.content
assert response.completion_message.role == "assistant"
assert isinstance(response.completion_message.stop_reason, StopReason)
assert response.completion_message.tool_calls == []
async def test_bad_base_url(
model: str,
):
"""
Test that a bad base_url raises a ConnectionError.
"""
client = await get_adapter_impl(
NVIDIAConfig(
base_url="http://localhost:32123",
),
{},
)
with pytest.raises(ConnectionError):
await client.chat_completion(
model=model,
messages=[UserMessage(content="Hello")],
stream=False,
)

View file

@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
from llama_stack.apis.inference import Inference
from llama_stack.providers.adapters.inference.nvidia import (
get_adapter_impl,
NVIDIAConfig,
)
from pytest_httpx import HTTPXMock
pytestmark = pytest.mark.asyncio
@pytest.fixture
def base_url():
return "http://endpoint.mocked"
@pytest.fixture
def client(base_url: str) -> Inference:
return get_adapter_impl(
NVIDIAConfig(
base_url=base_url,
api_key=os.environ.get("NVIDIA_API_KEY"),
),
{},
)
@pytest.fixture
def mock_health(
httpx_mock: HTTPXMock,
base_url: str,
) -> HTTPXMock:
for path in [
"/v1/health/live",
"/v1/health/ready",
]:
httpx_mock.add_response(
url=f"{base_url}{path}",
status_code=200,
)
return httpx_mock
@pytest.fixture
def mock_chat_completion(httpx_mock: HTTPXMock, base_url: str) -> HTTPXMock:
httpx_mock.add_response(
url=f"{base_url}/v1/chat/completions",
json={
"id": "mock-id",
"created": 1234567890,
"object": "chat.completion",
"model": "mock-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "WORKED"},
"finish_reason": "length",
}
],
},
status_code=200,
)
return httpx_mock

View file

@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_models.llama3.api.datatypes import TokenLogProbs, ToolCall
from llama_stack.apis.inference import Inference
from pytest_httpx import HTTPXMock
pytestmark = pytest.mark.asyncio
async def test_content(
mock_health: HTTPXMock,
httpx_mock: HTTPXMock,
client: Inference,
base_url: str,
) -> None:
"""
Test that response content makes it through to the completion message.
"""
httpx_mock.add_response(
url=f"{base_url}/v1/chat/completions",
json={
"id": "mock-id",
"created": 1234567890,
"object": "chat.completion",
"model": "mock-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "RESPONSE"},
"finish_reason": "length",
}
],
},
status_code=200,
)
client = await client
response = await client.chat_completion(
model="Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "BOGUS"}],
stream=False,
)
assert response.completion_message.content == "RESPONSE"
async def test_logprobs(
mock_health: HTTPXMock,
httpx_mock: HTTPXMock,
client: Inference,
base_url: str,
) -> None:
"""
Test that logprobs are parsed correctly.
"""
httpx_mock.add_response(
url=f"{base_url}/v1/chat/completions",
json={
"id": "mock-id",
"object": "chat.completion",
"created": 1234567890,
"model": "mock-model",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Hello there"},
"logprobs": {
"content": [
{
"token": "Hello",
"logprob": -0.1,
"bytes": [72, 101, 108, 108, 111],
"top_logprobs": [
{"token": "Hello", "logprob": -0.1},
{"token": "Hi", "logprob": -1.2},
{"token": "Greetings", "logprob": -2.1},
],
},
{
"token": "there",
"logprob": -0.2,
"bytes": [116, 104, 101, 114, 101],
"top_logprobs": [
{"token": "there", "logprob": -0.2},
{"token": "here", "logprob": -1.3},
{"token": "where", "logprob": -2.2},
],
},
]
},
"finish_reason": "length",
}
],
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
},
status_code=200,
)
client = await client
response = await client.chat_completion(
model="Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Hello"}],
logprobs={"top_k": 3},
stream=False,
)
assert response.logprobs == [
TokenLogProbs(
logprobs_by_token={
"Hello": -0.1,
"Hi": -1.2,
"Greetings": -2.1,
}
),
TokenLogProbs(
logprobs_by_token={
"there": -0.2,
"here": -1.3,
"where": -2.2,
}
),
]
async def test_tools(
mock_health: HTTPXMock,
httpx_mock: HTTPXMock,
client: Inference,
base_url: str,
) -> None:
"""
Test that tools are passed correctly.
"""
httpx_mock.add_response(
url=f"{base_url}/v1/chat/completions",
json={
"id": "mock-id",
"object": "chat.completion",
"created": 1234567890,
"model": "mock-model",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "tool-id",
"type": "function",
"function": {
"name": "magic",
"arguments": {"input": 3},
},
},
{
"id": "tool-id!",
"type": "function",
"function": {
"name": "magic!",
"arguments": {"input": 42},
},
},
],
},
"logprobs": None,
"finish_reason": "tool_calls",
}
],
},
status_code=200,
)
client = await client
response = await client.chat_completion(
model="Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Hello"}],
stream=False,
)
assert response.completion_message.tool_calls == [
ToolCall(
call_id="tool-id",
tool_name="magic",
arguments={"input": 3},
),
ToolCall(
call_id="tool-id!",
tool_name="magic!",
arguments={"input": 42},
),
]
# TODO(mf): test stream=True for each case

View file

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import Inference
from pytest_httpx import HTTPXMock
pytestmark = pytest.mark.asyncio
async def test_chat_completion(
mock_health: HTTPXMock,
mock_chat_completion: HTTPXMock,
client: Inference,
base_url: str,
) -> None:
"""
Test that health endpoints are checked when chat_completion is called.
"""
client = await client
await client.chat_completion(
model="Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "BOGUS"}],
stream=False,
)
# TODO(mf): test stream=True for each case
# TODO(mf): test completion
# TODO(mf): test embedding

View file

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.adapters.inference.nvidia import __all__
def test_import():
assert set(__all__) == {"get_adapter_impl", "NVIDIAConfig"}