Nutanix AI on!

This commit is contained in:
Jinan Zhou 2024-10-29 22:43:52 +00:00
parent 1e2faa461f
commit 64c5d38ae9
10 changed files with 234 additions and 2 deletions

View file

@ -88,6 +88,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Chroma | Single Node | | | :heavy_check_mark: | | |
| PG Vector | Single Node | | | :heavy_check_mark: | | |
| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
| Nutanix AI | Hosted | | :heavy_check_mark: | | | |
### Distributions

View file

@ -0,0 +1,40 @@
# Nutanix Distribution
The `llamastack/distribution-nutanix` distribution consists of the following provider configurations.
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|----------------- |--------------- |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | remote::nutanix | meta-reference | meta-reference | meta-reference | meta-reference |
### Start the Distribution (Hosted remote)
> [!NOTE]
> This assumes you have a hosted Nutanix AI endpoint and an API key.
1. Clone the repo
```
git clone git@github.com:meta-llama/llama-stack.git
cd llama-stack
```
2. Configure the model name
Adjust the `NUTANIX_SUPPORTED_MODELS` mapping at line 29 of `llama_stack/providers/adapters/inference/nutanix/nutanix.py` so that each Llama Stack model name maps to the model name served by your deployment.
3. Build the distribution
```
pip install -e .
llama stack build --template nutanix --name ntnx --image-type conda
```
4. Set the endpoint URL and API Key
```
llama stack configure ntnx
```
5. Serve and enjoy!
```
llama stack run ntnx --port 174
```
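
Once the server is up, you can send a quick test request. The sketch below is illustrative only: it assumes the `llama-stack-client` Python package is installed, that the server from step 5 is listening on port 174, and that the model name matches your `NUTANIX_SUPPORTED_MODELS` entry; parameter names may differ slightly across client versions.
```
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:174")
response = client.inference.chat_completion(
    model="Llama3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello from Nutanix AI!"}],
)
print(response.completion_message.content)
```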

View file

@ -0,0 +1 @@
../../llama_stack/templates/nutanix/build.yaml

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import NutanixImplConfig
from .nutanix import NutanixInferenceAdapter


async def get_adapter_impl(config: NutanixImplConfig, _deps):
    assert isinstance(
        config, NutanixImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = NutanixInferenceAdapter(config)
    await impl.initialize()
    return impl

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class NutanixImplConfig(BaseModel):
    url: Optional[str] = Field(
        default=None,
        description="The URL of the Nutanix AI endpoint",
    )
    api_token: Optional[str] = Field(
        default=None,
        description="The API token of the Nutanix AI endpoint",
    )
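
For reference, these two fields are what the `llama stack configure` step fills in. A hypothetical run-config snippet (the provider id and exact YAML layout are illustrative, not taken from this commit) might look like:
```
inference:
  - provider_id: nutanix
    provider_type: remote::nutanix
    config:
      url: https://your-nutanix-endpoint.example.com/v1
      api_token: <YOUR_NUTANIX_API_KEY>
```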

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator

from openai import OpenAI

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

from .config import NutanixImplConfig

# Maps Llama Stack model names to the model names served by the Nutanix AI endpoint.
NUTANIX_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "vllm-llama-3-1",
}


class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: NutanixImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=NUTANIX_SUPPORTED_MODELS
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        # The Nutanix AI endpoint is OpenAI-compatible, so reuse the OpenAI client.
        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.chat.completions.create(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> AsyncGenerator:
        params = self._get_params(request)

        async def _to_async_generator():
            s = client.chat.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        params = {
            "model": self.map_to_provider_model(request.model),
            # Convert the augmented messages to role/content dicts for the
            # OpenAI-compatible endpoint.
            "messages": chat_completion_request_to_messages(
                request, request.model, return_dict=True
            ),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }
        return params

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
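
As a smoke test outside the full stack, the adapter can be exercised directly. This is a minimal sketch under assumptions not present in this commit: the import path mirrors the registry module (`llama_stack.providers.remote.inference.nutanix`), the endpoint URL and API key are placeholders, and `UserMessage` comes from `llama_models.llama3.api.datatypes`.
```
import asyncio

from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.remote.inference.nutanix.config import NutanixImplConfig
from llama_stack.providers.remote.inference.nutanix.nutanix import NutanixInferenceAdapter


async def main():
    # Placeholder endpoint and key; substitute your Nutanix AI deployment values.
    config = NutanixImplConfig(
        url="https://your-nutanix-endpoint.example.com/v1",
        api_token="<API_KEY>",
    )
    adapter = NutanixInferenceAdapter(config)
    await adapter.initialize()

    response = await adapter.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Say hello in one sentence.")],
    )
    print(response.completion_message.content)


asyncio.run(main())
```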

View file

@ -161,4 +161,15 @@ def available_providers() -> List[ProviderSpec]:
            config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="nutanix",
            pip_packages=[
                "openai",
            ],
            module="llama_stack.providers.remote.inference.nutanix",
            config_class="llama_stack.providers.remote.inference.nutanix.NutanixImplConfig",
        ),
    ),
]

View file

@ -43,6 +43,8 @@ def get_sampling_options(params: SamplingParams) -> dict:
def text_from_choice(choice) -> str:
    if hasattr(choice, "message") and choice.message:
        return choice.message.content

    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content
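
To illustrate the two branches above, a tiny hypothetical check with stand-in objects (not part of this change):
```
from types import SimpleNamespace

# Non-streaming choice: content is read from choice.message.content.
full = SimpleNamespace(message=SimpleNamespace(content="hello"), delta=None)
assert text_from_choice(full) == "hello"

# Streaming chunk choice: content is read from choice.delta.content.
chunk = SimpleNamespace(delta=SimpleNamespace(content="he"))
assert text_from_choice(chunk) == "he"
```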

View file

@ -169,10 +169,12 @@ def chat_completion_request_to_model_input_info(
def chat_completion_request_to_messages(
    request: ChatCompletionRequest,
    llama_model: str,
    return_dict: bool = False,
) -> Union[List[Message], List[Dict[str, str]]]:
    """Reads chat completion request and augments the messages to handle tools.

    For example, for llama_3_1, add a system message with the appropriate tools or
    add a user message for custom tools, etc.

    If return_dict is set, returns a list of message dictionaries instead of Message objects.
    """
    model = resolve_model(llama_model)
    if model is None:
@ -199,7 +201,10 @@ def chat_completion_request_to_messages(
    if fmt_prompt := response_format_prompt(request.response_format):
        messages.append(UserMessage(content=fmt_prompt))

    if return_dict:
        return [{"role": m.role, "content": m.content} for m in messages]
    else:
        return messages
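
A short hypothetical example of the new `return_dict` path (model name and message content are placeholders; it assumes `ChatCompletionRequest`, `UserMessage`, and this function are already in scope, as they are inside this module's callers):
```
request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="What is Nutanix AI?")],
)

# Message objects, as before.
msgs = chat_completion_request_to_messages(request, "Llama3.1-8B-Instruct")

# Plain dicts, ready for an OpenAI-compatible client.
dict_msgs = chat_completion_request_to_messages(
    request, "Llama3.1-8B-Instruct", return_dict=True
)
# e.g. [{"role": "user", "content": "What is Nutanix AI?"}, ...]
```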
def response_format_prompt(fmt: Optional[ResponseFormat]):

View file

@ -0,0 +1,9 @@
name: nutanix
distribution_spec:
  description: Use Nutanix AI Endpoint for running LLM inference
  providers:
    inference: remote::nutanix
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference