diff --git a/llama_stack/providers/remote/inference/ssambanova/__init__.py b/llama_stack/providers/remote/inference/ssambanova/__init__.py
new file mode 100644
index 000000000..a8c5eb0d5
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ssambanova/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import SsambanovaImplConfig
+from .ssambanova import SsambanovaInferenceAdapter
+
+
+async def get_adapter_impl(config: SsambanovaImplConfig, _deps):
+    assert isinstance(
+        config, SsambanovaImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = SsambanovaInferenceAdapter(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/inference/ssambanova/config.py b/llama_stack/providers/remote/inference/ssambanova/config.py
new file mode 100644
index 000000000..532670d87
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ssambanova/config.py
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from typing import Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class SsambanovaImplConfig(BaseModel):
+    url: str = Field(
+        default="https://api.sambanova.ai/v1",
+        description="The URL for the SambaNova model serving endpoint",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="The SambaNova API token",
+    )
diff --git a/llama_stack/providers/remote/inference/ssambanova/ssambanova.py b/llama_stack/providers/remote/inference/ssambanova/ssambanova.py
new file mode 100644
index 000000000..c25e5ffe5
--- /dev/null
+++ b/llama_stack/providers/remote/inference/ssambanova/ssambanova.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator
+
+from llama_models.datatypes import CoreModelId
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from openai import OpenAI
+
+from llama_stack.apis.inference import *  # noqa: F403
+
+from llama_stack.providers.utils.inference.model_registry import (
+    build_model_alias,
+    ModelRegistryHelper,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
+)
+
+from .config import SsambanovaImplConfig
+
+
+model_aliases = [
+    build_model_alias(
+        "Meta-Llama-3.1-8B-Instruct",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_model_alias(
+        "Meta-Llama-3.1-70B-Instruct",
+        CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    build_model_alias(
+        "Meta-Llama-3.1-405B-Instruct",
+        CoreModelId.llama3_1_405b_instruct.value,
+    ),
+    build_model_alias(
+        "Meta-Llama-3.2-1B-Instruct",
+        CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    build_model_alias(
+        "Meta-Llama-3.2-3B-Instruct",
+        CoreModelId.llama3_2_3b_instruct.value,
+    ),
+]
+
+
+class SsambanovaInferenceAdapter(ModelRegistryHelper, Inference):
+    def __init__(self, config: SsambanovaImplConfig) -> None:
+        ModelRegistryHelper.__init__(
+            self,
+            model_aliases=model_aliases,
+        )
+        self.config = config
+        self.formatter = ChatFormat(Tokenizer.get_instance())
+
+    async def initialize(self) -> None:
+        return
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def completion(
+        self,
+        model: str,
+        content: InterleavedTextMedia,
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        raise NotImplementedError()
+
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
+        # The endpoint is OpenAI-compatible, so the standard OpenAI client is used.
+        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        if stream:
+            return self._stream_chat_completion(request, client)
+        else:
+            return await self._nonstream_chat_completion(request, client)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.completions.create(**params)
+        return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        # The OpenAI client streams synchronously; wrap it in an async generator.
+        async def _to_async_generator():
+            s = client.completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "model": request.model,
+            "prompt": chat_completion_request_to_prompt(
+                request, self.get_llama_model(request.model), self.formatter
+            ),
+            "stream": request.stream,
+            **get_sampling_options(request.sampling_params),
+        }
+
+    async def embeddings(
+        self,
+        model: str,
+        contents: List[InterleavedTextMedia],
+    ) -> EmbeddingsResponse:
+        raise NotImplementedError()
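
For reviewers, a minimal sketch of how this adapter could be exercised directly, outside a full stack build. The environment variable name and the model identifier below are illustrative assumptions rather than part of this change, and the identifier accepted by `chat_completion` ultimately depends on how the model is registered with the stack:

```python
# Illustrative only -- not part of this diff.
# Assumes a SambaNova API key in SAMBANOVA_API_KEY (hypothetical variable name)
# and that "Meta-Llama-3.1-8B-Instruct" resolves via the adapter's model aliases.
import asyncio
import os

from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.remote.inference.ssambanova import get_adapter_impl
from llama_stack.providers.remote.inference.ssambanova.config import (
    SsambanovaImplConfig,
)


async def main() -> None:
    config = SsambanovaImplConfig(api_token=os.environ["SAMBANOVA_API_KEY"])
    adapter = await get_adapter_impl(config, {})

    # Non-streaming chat completion; the adapter renders the messages into a
    # prompt and calls the OpenAI-compatible completions endpoint.
    response = await adapter.chat_completion(
        model="Meta-Llama-3.1-8B-Instruct",
        messages=[UserMessage(content="Hello, who are you?")],
        stream=False,
    )
    print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(main())
```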