Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
Nutanix AI on!
commit 64c5d38ae9 (parent 1e2faa461f)
10 changed files with 234 additions and 2 deletions
@@ -88,6 +88,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Chroma | Single Node | | | :heavy_check_mark: | | |
| PG Vector | Single Node | | | :heavy_check_mark: | | |
| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
| Nutanix AI | Hosted | | :heavy_check_mark: | | | |

### Distributions
distributions/nutanix/README.md (new file, +40 lines)

@@ -0,0 +1,40 @@
# Nutanix Distribution

The `llamastack/distribution-nutanix` distribution consists of the following provider configurations.

| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|-----------------|-----------------|----------------|----------------|----------------|----------------|
| **Provider(s)** | remote::nutanix | meta-reference | meta-reference | meta-reference | meta-reference |

### Start the Distribution (Hosted remote)

> [!NOTE]
> This assumes you have a hosted Nutanix AI endpoint and an API key.

1. Clone the repo
```
git clone git@github.com:meta-llama/llama-stack.git
cd llama-stack
```

2. Configure the model name

Please adjust the `NUTANIX_SUPPORTED_MODELS` variable at line 29 in `llama_stack/providers/adapters/inference/nutanix/nutanix.py` according to your deployment (see the sketch below for an example).
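For illustration only, the edited mapping might look like the following; the deployment names are placeholders, and `vllm-llama-3-1` is simply the default shipped in this commit.

```python
# nutanix.py -- hypothetical deployment names; use the model names exposed by your Nutanix AI endpoint
NUTANIX_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "my-nai-llama-3-1-8b",
    "Llama3.1-70B-Instruct": "my-nai-llama-3-1-70b",
}
```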
3. Build the distribution
```
pip install -e .
llama stack build --template nutanix --name ntnx --image-type conda
```

4. Set the endpoint URL and API key
```
llama stack configure ntnx
```

5. Serve and enjoy! (An example query follows these steps.)
```
llama stack run ntnx --port 174
```
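Once the server is up, one way to sanity-check it is with the `llama-stack-client` Python package. This is a minimal sketch, not part of the commit; it assumes the client package is installed and that the server is listening on the port chosen in step 5.

```python
# pip install llama-stack-client  (assumed; adjust to your client version)
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:174")
response = client.inference.chat_completion(
    model="Llama3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello from Nutanix AI!"}],
)
print(response)
```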
distributions/nutanix/build.yaml (new file, +1 line)

@@ -0,0 +1 @@
../../llama_stack/templates/nutanix/build.yaml
llama_stack/providers/adapters/inference/nutanix/__init__.py (new file, +16 lines)

@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import NutanixImplConfig
from .nutanix import NutanixInferenceAdapter


async def get_adapter_impl(config: NutanixImplConfig, _deps):
    # The stack passes the provider config; validate it before constructing the adapter.
    assert isinstance(
        config, NutanixImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = NutanixInferenceAdapter(config)
    await impl.initialize()
    return impl
llama_stack/providers/adapters/inference/nutanix/config.py (new file, +22 lines)

@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class NutanixImplConfig(BaseModel):
    url: Optional[str] = Field(
        default=None,
        description="The URL of the Nutanix AI endpoint",
    )
    api_token: Optional[str] = Field(
        default=None,
        description="The API token of the Nutanix AI endpoint",
    )
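For reference, a minimal sketch of constructing this config directly; the values are placeholders, and in practice `llama stack configure` writes them into the run configuration.

```python
# Hypothetical values -- substitute your own endpoint URL and API token.
config = NutanixImplConfig(
    url="https://my-nai-endpoint.example.com/api/v1",
    api_token="nai-xxxxxxxxxxxx",
)
print(config)
```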
llama_stack/providers/adapters/inference/nutanix/nutanix.py (new file, +125 lines)

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from openai import OpenAI

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

from .config import NutanixImplConfig

# Maps stack model identifiers to the deployment names exposed by the Nutanix AI endpoint.
NUTANIX_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "vllm-llama-3-1",
}


class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: NutanixImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=NUTANIX_SUPPORTED_MODELS
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        # The Nutanix AI endpoint is OpenAI-compatible, so the stock OpenAI client is used.
        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.chat.completions.create(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> AsyncGenerator:
        params = self._get_params(request)

        async def _to_async_generator():
            s = client.chat.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        params = {
            "model": self.map_to_provider_model(request.model),
            # return_dict=True yields plain role/content dicts for the OpenAI-compatible API;
            # the stack model identifier is passed as the llama_model argument.
            "messages": chat_completion_request_to_messages(
                request, request.model, return_dict=True
            ),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }
        return params

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
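A minimal, hypothetical smoke test for the adapter, based only on the signatures shown above; the endpoint URL and token are placeholders and error handling is omitted.

```python
import asyncio

from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.adapters.inference.nutanix import get_adapter_impl
from llama_stack.providers.adapters.inference.nutanix.config import NutanixImplConfig


async def main():
    # Placeholder endpoint and token; replace with your Nutanix AI deployment values.
    impl = await get_adapter_impl(
        NutanixImplConfig(
            url="https://my-nai-endpoint.example.com/api/v1",
            api_token="nai-xxxxxxxxxxxx",
        ),
        _deps={},
    )
    # Non-streaming chat completion against the hosted endpoint.
    response = await impl.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Say hello in one sentence.")],
    )
    print(response)


asyncio.run(main())
```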
@@ -161,4 +161,15 @@ def available_providers() -> List[ProviderSpec]:
                config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="nutanix",
                pip_packages=[
                    "openai",
                ],
                module="llama_stack.providers.remote.inference.nutanix",
                config_class="llama_stack.providers.remote.inference.nutanix.NutanixImplConfig",
            ),
        ),
    ]
llama_stack/providers/utils/inference/openai_compat.py

@@ -43,6 +43,8 @@ def get_sampling_options(params: SamplingParams) -> dict:


def text_from_choice(choice) -> str:
    if hasattr(choice, "message") and choice.message:
        return choice.message.content
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content
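An illustrative check of the two branches, using stand-in objects rather than real OpenAI response types.

```python
from types import SimpleNamespace

# Non-streaming responses carry a populated .message; streaming chunks carry a .delta.
full_choice = SimpleNamespace(message=SimpleNamespace(content="Hello"), delta=None)
stream_choice = SimpleNamespace(message=None, delta=SimpleNamespace(content="Hel"))

assert text_from_choice(full_choice) == "Hello"
assert text_from_choice(stream_choice) == "Hel"
```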
@ -169,10 +169,12 @@ def chat_completion_request_to_model_input_info(
|
|||
def chat_completion_request_to_messages(
|
||||
request: ChatCompletionRequest,
|
||||
llama_model: str,
|
||||
) -> List[Message]:
|
||||
return_dict: bool = False,
|
||||
) -> Union[List[Message], List[Dict[str, str]]]:
|
||||
"""Reads chat completion request and augments the messages to handle tools.
|
||||
For eg. for llama_3_1, add system message with the appropriate tools or
|
||||
add user messsage for custom tools, etc.
|
||||
If return_dict is set, returns a list of the messages dictionaries instead of objects.
|
||||
"""
|
||||
model = resolve_model(llama_model)
|
||||
if model is None:
|
||||
|
@ -199,7 +201,10 @@ def chat_completion_request_to_messages(
|
|||
if fmt_prompt := response_format_prompt(request.response_format):
|
||||
messages.append(UserMessage(content=fmt_prompt))
|
||||
|
||||
return messages
|
||||
if return_dict:
|
||||
return [{'role': m.role, 'content': m.content} for m in messages]
|
||||
else:
|
||||
return messages
|
||||
|
||||
|
||||
def response_format_prompt(fmt: Optional[ResponseFormat]):
|
||||
|
|
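A small sketch of what `return_dict=True` produces, assuming `ChatCompletionRequest` and `UserMessage` are importable as in the adapter code above.

```python
from llama_models.llama3.api.datatypes import UserMessage
from llama_stack.apis.inference import ChatCompletionRequest

request = ChatCompletionRequest(
    model="Llama3.1-8B-Instruct",
    messages=[UserMessage(content="Hello")],
)
msgs = chat_completion_request_to_messages(request, "Llama3.1-8B-Instruct", return_dict=True)
# e.g. [{"role": "user", "content": "Hello"}, ...] -- plain dicts, ready for an OpenAI-compatible API
```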
llama_stack/templates/nutanix/build.yaml (new file, +9 lines)

@@ -0,0 +1,9 @@
name: nutanix
distribution_spec:
  description: Use Nutanix AI Endpoint for running LLM inference
  providers:
    inference: remote::nutanix
    memory: meta-reference
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference