diff --git a/README.md b/README.md
index 8e57292c3..b1d8e5e87 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
 | Chroma | Single Node | | | :heavy_check_mark: | | |
 | PG Vector | Single Node | | | :heavy_check_mark: | | |
 | PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | |
+| Nutanix AI | Hosted | | :heavy_check_mark: | | | |
 
 ### Distributions
 
diff --git a/distributions/nutanix/README.md b/distributions/nutanix/README.md
new file mode 100644
index 000000000..fdd9e3106
--- /dev/null
+++ b/distributions/nutanix/README.md
@@ -0,0 +1,40 @@
+# Nutanix Distribution
+
+The `llamastack/distribution-nutanix` distribution consists of the following provider configurations.
+
+
+| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
+|----------------- |--------------- |---------------- |-------------------------------------------------- |---------------- |---------------- |
+| **Provider(s)** | remote::nutanix | meta-reference | meta-reference | meta-reference | meta-reference |
+
+
+### Start the Distribution (Hosted remote)
+
+> [!NOTE]
+> This assumes you have a hosted Nutanix AI endpoint and an API key.
+
+1. Clone the repo
+```
+git clone git@github.com:meta-llama/llama-stack.git
+cd llama-stack
+```
+
+2. Configure the model name
+
+Adjust the `NUTANIX_SUPPORTED_MODELS` mapping at line 29 of `llama_stack/providers/adapters/inference/nutanix/nutanix.py` to match the models served by your deployment.
+
+3. Build the distribution
+```
+pip install -e .
+llama stack build --template nutanix --name ntnx --image-type conda
+```
+
+4. Set the endpoint URL and API key
+```
+llama stack configure ntnx
+```
+
+5. Serve and enjoy!
+```
+llama stack run ntnx --port 174
+```
diff --git a/distributions/nutanix/build.yaml b/distributions/nutanix/build.yaml
new file mode 100644
index 000000000..e6ad2e304
--- /dev/null
+++ b/distributions/nutanix/build.yaml
@@ -0,0 +1 @@
+../../llama_stack/templates/nutanix/build.yaml
diff --git a/llama_stack/providers/adapters/inference/nutanix/__init__.py b/llama_stack/providers/adapters/inference/nutanix/__init__.py
new file mode 100644
index 000000000..ef1dc10cf
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/nutanix/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import NutanixImplConfig
+from .nutanix import NutanixInferenceAdapter
+
+async def get_adapter_impl(config: NutanixImplConfig, _deps):
+    assert isinstance(
+        config, NutanixImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = NutanixInferenceAdapter(config)
+    await impl.initialize()
+    return impl
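As a quick sanity check, the factory above can be driven by hand. The following is a minimal sketch, not part of the PR: the endpoint URL and token are placeholders, and in a real deployment the stack's provider resolver performs this wiring from the run configuration.

```python
# Minimal sketch of constructing the adapter outside the stack. The URL and API
# token below are placeholders; `_deps` is unused by this adapter's factory.
import asyncio

from llama_stack.providers.adapters.inference.nutanix import get_adapter_impl
from llama_stack.providers.adapters.inference.nutanix.config import NutanixImplConfig


async def main() -> None:
    config = NutanixImplConfig(
        url="https://my-nutanix-ai-endpoint.example/v1",  # placeholder endpoint
        api_token="nai-...",  # placeholder API key
    )
    adapter = await get_adapter_impl(config, _deps={})
    print(type(adapter).__name__)  # NutanixInferenceAdapter


asyncio.run(main())
```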
diff --git a/llama_stack/providers/adapters/inference/nutanix/config.py b/llama_stack/providers/adapters/inference/nutanix/config.py
new file mode 100644
index 000000000..901c13cab
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/nutanix/config.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class NutanixImplConfig(BaseModel):
+    url: str = Field(
+        default=None,
+        description="The URL of the Nutanix AI endpoint",
+    )
+    api_token: str = Field(
+        default=None,
+        description="The API token of the Nutanix AI endpoint",
+    )
diff --git a/llama_stack/providers/adapters/inference/nutanix/nutanix.py b/llama_stack/providers/adapters/inference/nutanix/nutanix.py
new file mode 100644
index 000000000..56dca43cd
--- /dev/null
+++ b/llama_stack/providers/adapters/inference/nutanix/nutanix.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator
+
+from openai import OpenAI
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_messages,
+)
+
+from .config import NutanixImplConfig
+
+NUTANIX_SUPPORTED_MODELS = {
+    "Llama3.1-8B-Instruct": "vllm-llama-3-1",
+}
+
+
+class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
+    def __init__(self, config: NutanixImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=NUTANIX_SUPPORTED_MODELS
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+    async def initialize(self) -> None:
+        return
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        if stream:
+            return self._stream_chat_completion(request, client)
+        else:
+            return await self._nonstream_chat_completion(request, client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.chat.completions.create(**params)
+        return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        async def _to_async_generator():
+            s = client.chat.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        params = {
+            "model": self.map_to_provider_model(request.model),
+            "messages": chat_completion_request_to_messages(request, request.model, return_dict=True),
+            "stream": request.stream,
+            **get_sampling_options(request.sampling_params),
+        }
+        return params
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
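Step 2 of the distribution README refers to the `NUTANIX_SUPPORTED_MODELS` mapping above, which translates stack model identifiers into the names your endpoint serves. As an illustration only, adding a second deployment could look like the sketch below; the provider-side values are hypothetical and must match what your Nutanix AI endpoint actually exposes.

```python
# Illustrative extension of the mapping in nutanix.py. Keys are llama-stack model
# identifiers; values are the model names served by your Nutanix AI deployment.
NUTANIX_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "vllm-llama-3-1",
    "Llama3.1-70B-Instruct": "vllm-llama-3-1-70b",  # hypothetical deployment name
}
```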
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index c8d061f6c..2ba6d1958 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -161,4 +161,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="nutanix",
+                pip_packages=[
+                    "openai",
+                ],
+                module="llama_stack.providers.remote.inference.nutanix",
+                config_class="llama_stack.providers.remote.inference.nutanix.NutanixImplConfig",
+            ),
+        ),
     ]
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index cc3e7a2ce..f7b9a5370 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -43,6 +43,8 @@ def get_sampling_options(params: SamplingParams) -> dict:
 
 
 def text_from_choice(choice) -> str:
+    if hasattr(choice, "message") and choice.message:
+        return choice.message.content
     if hasattr(choice, "delta") and choice.delta:
         return choice.delta.content
 
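The `text_from_choice` change accounts for the two shapes of an OpenAI-compatible choice: a full (non-streaming) response carries its text on `choice.message.content`, while streaming chunks carry increments on `choice.delta.content`. The sketch below is a simplified, self-contained stand-in that mirrors only the two branches shown above, with `SimpleNamespace` standing in for the client's response objects.

```python
# Simplified stand-in for text_from_choice, covering only the branches shown above.
from types import SimpleNamespace


def text_from_choice_sketch(choice) -> str:
    if getattr(choice, "message", None):  # non-streaming chat completion
        return choice.message.content
    if getattr(choice, "delta", None):  # one chunk of a streaming response
        return choice.delta.content
    raise ValueError("unrecognized choice shape")


non_streaming = SimpleNamespace(message=SimpleNamespace(content="Hello"), delta=None)
streaming = SimpleNamespace(message=None, delta=SimpleNamespace(content="He"))

assert text_from_choice_sketch(non_streaming) == "Hello"
assert text_from_choice_sketch(streaming) == "He"
```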
""" model = resolve_model(llama_model) if model is None: @@ -199,7 +201,10 @@ def chat_completion_request_to_messages( if fmt_prompt := response_format_prompt(request.response_format): messages.append(UserMessage(content=fmt_prompt)) - return messages + if return_dict: + return [{'role': m.role, 'content': m.content} for m in messages] + else: + return messages def response_format_prompt(fmt: Optional[ResponseFormat]): diff --git a/llama_stack/templates/nutanix/build.yaml b/llama_stack/templates/nutanix/build.yaml new file mode 100644 index 000000000..16785c786 --- /dev/null +++ b/llama_stack/templates/nutanix/build.yaml @@ -0,0 +1,9 @@ +name: nutanix +distribution_spec: + description: Use Nutanix AI Endpoint for running LLM inference + providers: + inference: remote::nutanix + memory: meta-reference + safety: meta-reference + agents: meta-reference + telemetry: meta-reference