diff --git a/llama_stack/providers/adapters/inference/nutanix/__init__.py b/llama_stack/providers/remote/inference/nutanix/__init__.py
similarity index 99%
rename from llama_stack/providers/adapters/inference/nutanix/__init__.py
rename to llama_stack/providers/remote/inference/nutanix/__init__.py
index ef1dc10cf..013f7d560 100644
--- a/llama_stack/providers/adapters/inference/nutanix/__init__.py
+++ b/llama_stack/providers/remote/inference/nutanix/__init__.py
@@ -7,6 +7,7 @@
 from .config import NutanixImplConfig
 from .nutanix import NutanixInferenceAdapter
 
+
 async def get_adapter_impl(config: NutanixInferenceAdapter, _deps):
     assert isinstance(
         config, NutanixImplConfig
diff --git a/llama_stack/providers/adapters/inference/nutanix/config.py b/llama_stack/providers/remote/inference/nutanix/config.py
similarity index 95%
rename from llama_stack/providers/adapters/inference/nutanix/config.py
rename to llama_stack/providers/remote/inference/nutanix/config.py
index 901c13cab..6af6af3ff 100644
--- a/llama_stack/providers/adapters/inference/nutanix/config.py
+++ b/llama_stack/providers/remote/inference/nutanix/config.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
-
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field
 
diff --git a/llama_stack/providers/adapters/inference/nutanix/nutanix.py b/llama_stack/providers/remote/inference/nutanix/nutanix.py
similarity index 88%
rename from llama_stack/providers/adapters/inference/nutanix/nutanix.py
rename to llama_stack/providers/remote/inference/nutanix/nutanix.py
index 56dca43cd..8d9814659 100644
--- a/llama_stack/providers/adapters/inference/nutanix/nutanix.py
+++ b/llama_stack/providers/remote/inference/nutanix/nutanix.py
@@ -14,7 +14,10 @@
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -26,16 +29,18 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import NutanixImplConfig
 
-NUTANIX_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "vllm-llama-3-1",
-}
+
+model_aliases = [
+    build_model_alias(
+        "vllm-llama-3-1",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+]
 
 
 class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: NutanixImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self, stack_to_provider_models_map=NUTANIX_SUPPORTED_MODELS
-        )
+        ModelRegistryHelper.__init__(self, model_aliases)
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
@@ -47,7 +52,7 @@ class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -58,7 +63,7 @@
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -110,8 +115,10 @@ class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         params = {
-            "model": self.map_to_provider_model(request.model),
-            "messages": chat_completion_request_to_messages(request, return_dict=True),
+            "model": request.model,
+            "messages": chat_completion_request_to_messages(
+                request, self.get_llama_model(request.model), return_dict=True
+            ),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
@@ -119,7 +126,7 @@ class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
 
     async def embeddings(
        self,
-        model: str,
+        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
diff --git a/llama_stack/templates/nutanix/build.yaml b/llama_stack/templates/nutanix/build.yaml
index 16785c786..e3fcea06a 100644
--- a/llama_stack/templates/nutanix/build.yaml
+++ b/llama_stack/templates/nutanix/build.yaml
@@ -3,7 +3,7 @@ distribution_spec:
   description: Use Nutanix AI Endpoint for running LLM inference
   providers:
     inference: remote::nutanix
-    memory: meta-reference
-    safety: meta-reference
-    agents: meta-reference
-    telemetry: meta-reference
+    memory: inline::faiss
+    safety: inline::llama-guard
+    agents: inline::meta-reference
+    telemetry: inline::meta-reference
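
Note on the central change: the adapter moves from a hand-rolled NUTANIX_SUPPORTED_MODELS dict to the shared alias registry, so "model" sent to the endpoint is now the provider-side id, while get_llama_model recovers the core Llama descriptor that chat_completion_request_to_messages needs to format the prompt. Below is a minimal sketch of that pattern with stand-in types; the real build_model_alias and ModelRegistryHelper live in llama_stack.providers.utils.inference.model_registry, and their exact signatures may differ.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ModelAlias:
    provider_model_id: str  # id the Nutanix endpoint serves, e.g. "vllm-llama-3-1"
    llama_model: str        # core descriptor, e.g. "Llama3.1-8B-Instruct"


def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAlias:
    # Mirrors the call shape used in the diff:
    # build_model_alias("vllm-llama-3-1", CoreModelId.llama3_1_8b_instruct.value)
    return ModelAlias(provider_model_id=provider_model_id, llama_model=model_descriptor)


class ModelRegistryHelper:
    def __init__(self, model_aliases: List[ModelAlias]) -> None:
        # Index aliases by the provider-side id so lookups are O(1).
        self._llama_model_by_provider_id: Dict[str, str] = {
            alias.provider_model_id: alias.llama_model for alias in model_aliases
        }

    def get_llama_model(self, provider_model_id: str) -> str:
        # _get_params calls this to pick the right prompt format for the
        # underlying Llama model before sending the request.
        return self._llama_model_by_provider_id[provider_model_id]


# Registration and lookup, as the adapter's __init__ and _get_params would do
# (assuming CoreModelId.llama3_1_8b_instruct.value == "Llama3.1-8B-Instruct"):
registry = ModelRegistryHelper(
    [build_model_alias("vllm-llama-3-1", "Llama3.1-8B-Instruct")]
)
assert registry.get_llama_model("vllm-llama-3-1") == "Llama3.1-8B-Instruct"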