From 8920c4216fddf2000c54d3fd870e9b6832316c3c Mon Sep 17 00:00:00 2001
From: swanhtet1992
Date: Sun, 24 Nov 2024 01:55:36 -0600
Subject: [PATCH] Implement additional functionality supported by Sambanova.

---
 .../remote/inference/sambanova/__init__.py   |   7 +
 .../remote/inference/sambanova/config.py     |   6 +
 .../remote/inference/sambanova/sambanova.py  | 608 ++++++++++++------
 .../providers/tests/inference/fixtures.py    |  17 +
 llama_stack/templates/sambanova/__init__.py  |   8 +-
 llama_stack/templates/sambanova/build.yaml   |   2 +-
 .../templates/sambanova/doc_template.md      |  66 ++
 llama_stack/templates/sambanova/run.yaml     |  32 +-
 llama_stack/templates/sambanova/sambanova.py |  22 +-
 9 files changed, 565 insertions(+), 203 deletions(-)
 create mode 100644 llama_stack/templates/sambanova/doc_template.md

diff --git a/llama_stack/providers/remote/inference/sambanova/__init__.py b/llama_stack/providers/remote/inference/sambanova/__init__.py
index 05cbd5d0e..fa5f13d45 100644
--- a/llama_stack/providers/remote/inference/sambanova/__init__.py
+++ b/llama_stack/providers/remote/inference/sambanova/__init__.py
@@ -1,7 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from pydantic import BaseModel

 from .config import SambanovaImplConfig

+
 class SambanovaProviderDataValidator(BaseModel):
     sambanova_api_key: str

diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py
index 3f4f9cc19..eb1586218 100644
--- a/llama_stack/providers/remote/inference/sambanova/config.py
+++ b/llama_stack/providers/remote/inference/sambanova/config.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from typing import Any, Dict, Optional

 from llama_models.schema_utils import json_schema_type
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 6c3fa60ee..0648374dc 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -1,60 +1,173 @@
-from typing import AsyncGenerator, List, Optional, Union
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+ +from enum import Enum +from typing import AsyncGenerator, Dict, List, Optional, Union -import httpx from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer +from openai import AsyncOpenAI -from llama_stack.apis.inference import * +from llama_stack.apis.inference import ( + AsyncIterator, + ChatCompletionRequest, + CompletionRequest, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + InterleavedTextMedia, + LogProbConfig, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( + ChatCompletionResponseStreamChunk, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, - process_completion_response, - process_completion_stream_response, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - completion_request_to_prompt, ) from .config import SambanovaImplConfig -# Simplified model aliases - focus on core models + +class SambanovaErrorCode(str, Enum): + INVALID_AUTH = "invalid_authentication" + REQUEST_TIMEOUT = "request_timeout" + INSUFFICIENT_QUOTA = "insufficient_quota" + CONTEXT_LENGTH_EXCEEDED = "context_length_exceeded" + INVALID_TYPE = "invalid_type" + MODEL_NOT_FOUND = "model_not_found" + VALUE_ABOVE_MAX = "decimal_above_max_value" + VALUE_BELOW_MIN = "decimal_below_min_value" + INTEGER_ABOVE_MAX = "integer_above_max_value" + + MODEL_ALIASES = [ + build_model_alias( + "Meta-Llama-3.2-1B-Instruct", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_model_alias( + "Meta-Llama-3.2-3B-Instruct", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_alias( + "Llama-3.2-11B-Vision-Instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_alias( + "Llama-3.2-90B-Vision-Instruct", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), build_model_alias( "Meta-Llama-3.1-8B-Instruct", CoreModelId.llama3_1_8b_instruct.value, ), + build_model_alias( + "Meta-Llama-3.1-70B-Instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "Meta-Llama-3.1-405B-Instruct", + CoreModelId.llama3_1_405b_instruct.value, + ), ] +FUNCTION_CALLING_MODELS = { + "Meta-Llama-3.1-8B-Instruct", + "Meta-Llama-3.1-70B-Instruct", + "Meta-Llama-3.1-405B-Instruct", +} + +UNSUPPORTED_PARAMS = { + "logprobs", + "top_logprobs", + "n", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "parallel_tool_calls", + "seed", + "response_format", +} + class SambanovaInferenceAdapter( ModelRegistryHelper, Inference, NeedsRequestProviderData ): + """SambaNova inference adapter using OpenAI client compatibility layer. + + This adapter provides access to SambaNova's AI models through their OpenAI-compatible API. + It handles authentication, request formatting, and response processing while managing + unsupported features gracefully. 
+ + Note: Some OpenAI features are not supported: + - logprobs, top_logprobs, n + - presence_penalty, frequency_penalty + - logit_bias + - tools and tool_choice (function calling) + - parallel_tool_calls, seed + - stream_options + - response_format (JSON mode) + """ + def __init__(self, config: SambanovaImplConfig) -> None: + """Initialize the SambaNova inference adapter. + + Args: + config: Configuration for the SambaNova implementation + """ ModelRegistryHelper.__init__(self, MODEL_ALIASES) self.config = config self.formatter = ChatFormat(Tokenizer.get_instance()) - self.client = httpx.AsyncClient( - base_url=self.config.url, - timeout=httpx.Timeout(timeout=300.0), - ) + self._client: Optional[AsyncOpenAI] = None + + @property + def client(self) -> AsyncOpenAI: + """Get or create the OpenAI client instance. + + Returns: + AsyncOpenAI: The configured client instance + """ + if self._client is None: + self._client = AsyncOpenAI( + base_url="https://api.sambanova.ai/v1", + api_key=self._get_api_key(), + timeout=60.0, + ) + return self._client async def initialize(self) -> None: pass async def shutdown(self) -> None: - await self.client.aclose() + pass def _get_api_key(self) -> str: + """Get the API key from config or request headers. + + Returns: + str: The API key to use + + Raises: + ValueError: If no API key is available + """ if self.config.api_key is not None: return self.config.api_key @@ -65,32 +178,261 @@ class SambanovaInferenceAdapter( ) return provider_data.sambanova_api_key - def _convert_messages_to_api_format(self, messages: List[Message]) -> List[dict]: - """Convert our Message objects to SambaNova API format.""" - return [ - {"role": message.role, "content": message.content} for message in messages - ] + def _filter_unsupported_params(self, params: Dict) -> Dict: + """Remove parameters not supported by SambaNova API. - def _get_sampling_params(self, params: Optional[SamplingParams]) -> dict: - """Convert our SamplingParams to SambaNova API parameters.""" - if not params: - return {} + Args: + params: Original parameters dictionary - api_params = {} - if params.max_tokens: - api_params["max_tokens"] = params.max_tokens - if params.temperature is not None: - api_params["temperature"] = params.temperature - if params.top_p is not None: - api_params["top_p"] = params.top_p - if params.top_k is not None: - api_params["top_k"] = params.top_k - if params.stop_sequences: - api_params["stop"] = params.stop_sequences + Returns: + Dict: Filtered parameters dictionary + """ + return {k: v for k, v in params.items() if k not in UNSUPPORTED_PARAMS} - return api_params + async def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + """Prepare parameters for the API request. 
- async def completion( + Args: + request: The completion request + + Returns: + dict: Prepared parameters for the API call + """ + # Get and process sampling options + sampling_options = get_sampling_options(request.sampling_params) + filtered_options = self._filter_unsupported_params(sampling_options) + + if "temperature" in filtered_options: + filtered_options["temperature"] = min( + max(filtered_options["temperature"], 0), 1 + ) + + input_dict = {} + if isinstance(request, ChatCompletionRequest): + input_dict["messages"] = [ + {"role": message.role, "content": message.content} + for message in request.messages + ] + + if request.tools and self._supports_function_calling(request.model): + input_dict["tools"] = [ + self._convert_tool_to_function(tool) + for tool in request.tools + ] + + if request.tool_choice == ToolChoice.auto: + input_dict["tool_choice"] = "auto" + elif request.tool_choice == ToolChoice.required: + input_dict["tool_choice"] = "required" + elif isinstance(request.tool_choice, str): + input_dict["tool_choice"] = { + "type": "function", + "function": {"name": request.tool_choice}, + } + else: + input_dict["prompt"] = request.content + + return { + "model": request.model, + **input_dict, + **filtered_options, + "stream": request.stream, + } + + async def _handle_sambanova_error(self, e: Exception) -> None: + """Handle SambaNova specific API errors with detailed messages. + + Args: + e: The exception to handle + + Raises: + ValueError: For client errors + RuntimeError: For server errors + """ + error_msg = str(e) + error_data = {} + + try: + if hasattr(e, "response"): + error_data = e.response.json().get("error", {}) + except Exception: + pass + + error_code = error_data.get("code", "") + error_message = error_data.get("message", error_msg) + error_param = error_data.get("param", "") + + if "401" in error_msg or error_code == SambanovaErrorCode.INVALID_AUTH: + raise ValueError("Invalid API key or unauthorized access") from e + + elif ( + "408" in error_msg + or error_code == SambanovaErrorCode.REQUEST_TIMEOUT + ): + raise ValueError( + "Request timed out. Consider upgrading to a higher tier offering" + ) from e + + elif ( + "429" in error_msg + or error_code == SambanovaErrorCode.INSUFFICIENT_QUOTA + ): + raise ValueError( + "Rate limit exceeded. Consider upgrading to a higher tier offering" + ) from e + + elif "400" in error_msg: + if error_code == SambanovaErrorCode.CONTEXT_LENGTH_EXCEEDED: + raise ValueError( + "Total number of input and output tokens exceeds model's context length" + ) from e + + elif error_code == SambanovaErrorCode.INVALID_TYPE: + raise ValueError( + f"Invalid parameter type for {error_param}: {error_message}" + ) from e + + elif error_code in ( + SambanovaErrorCode.VALUE_ABOVE_MAX, + SambanovaErrorCode.VALUE_BELOW_MIN, + SambanovaErrorCode.INTEGER_ABOVE_MAX, + ): + raise ValueError( + f"Invalid value for {error_param}: {error_message}" + ) from e + + elif error_code == SambanovaErrorCode.MODEL_NOT_FOUND: + raise ValueError(f"Model not found: {error_message}") from e + + else: + raise ValueError(f"Bad request: {error_message}") from e + + raise RuntimeError(f"SambaNova API error: {error_message}") from e + + def _supports_function_calling(self, model: str) -> bool: + """Check if the model supports function calling. 
+ + Args: + model: Model name to check + + Returns: + bool: True if model supports function calling + """ + return any( + model.startswith(supported) for supported in FUNCTION_CALLING_MODELS + ) + + def _convert_tool_to_function(self, tool: ToolDefinition) -> dict: + """Convert a ToolDefinition to SambaNova function format. + + Args: + tool: Tool definition to convert + + Returns: + dict: Function definition in SambaNova format + """ + return { + "type": "function", + "function": { + "name": tool.tool_name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": { + name: { + "type": param.param_type, + "description": param.description, + } + for name, param in tool.parameters.items() + }, + "required": list(tool.parameters.keys()), + }, + }, + } + + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: + try: + params = await self._get_params(request) + response = await self.client.chat.completions.create(**params) + + if ( + self._supports_function_calling(request.model) + and response.choices[0].message.tool_calls + ): + tool_call = response.choices[0].message.tool_calls[0] + choice = OpenAICompatCompletionChoice( + finish_reason=response.choices[0].finish_reason, + text="", + tool_calls=[ + { + "tool_name": tool_call.function.name, + "arguments": tool_call.function.arguments or "", + } + ], + ) + else: + choice = OpenAICompatCompletionChoice( + finish_reason=response.choices[0].finish_reason, + text=response.choices[0].message.content or "", + tool_calls=[], + ) + + compat_response = OpenAICompatCompletionResponse(choices=[choice]) + return process_chat_completion_response( + compat_response, self.formatter + ) + + except Exception as e: + await self._handle_sambanova_error(e) + + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncIterator[ChatCompletionResponseStreamChunk]: + try: + params = await self._get_params(request) + stream = await self.client.chat.completions.create(**params) + + async def _to_async_generator(): + async for chunk in stream: + if ( + self._supports_function_calling(request.model) + and chunk.choices[0].delta.tool_calls + ): + tool_call = chunk.choices[0].delta.tool_calls[0] + choice = OpenAICompatCompletionChoice( + finish_reason=chunk.choices[0].finish_reason, + text="", + tool_calls=[ + { + "tool_name": tool_call.function.name, + "arguments": tool_call.function.arguments + or "", + } + ] + if tool_call.function + else None, + ) + else: + choice = OpenAICompatCompletionChoice( + finish_reason=chunk.choices[0].finish_reason, + text=chunk.choices[0].delta.content or "", + tool_calls=[], + ) + yield OpenAICompatCompletionResponse(choices=[choice]) + + async for chunk in process_chat_completion_stream_response( + _to_async_generator(), self.formatter + ): + yield chunk + + except Exception as e: + await self._handle_sambanova_error(e) + + def completion( self, model_id: str, content: InterleavedTextMedia, @@ -98,92 +440,10 @@ class SambanovaInferenceAdapter( response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> AsyncGenerator: - model = await self.model_store.get_model(model_id) - request = CompletionRequest( - model=model.provider_resource_id, - content=content, - sampling_params=sampling_params, - stream=stream, - logprobs=logprobs, - ) - if stream: - return self._stream_completion(request) - else: - return await self._nonstream_completion(request) - - async def 
_get_params( - self, request: Union[ChatCompletionRequest, CompletionRequest] - ) -> dict: - sampling_options = get_sampling_options(request.sampling_params) - - input_dict = {} - if isinstance(request, ChatCompletionRequest): - if isinstance(request.messages[0].content, list): - raise NotImplementedError("Media content not supported for SambaNova") - input_dict["messages"] = self._convert_messages_to_api_format( - request.messages - ) - else: - input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) - - return { - "model": request.model, - **input_dict, - **sampling_options, - "stream": request.stream, - } - - async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: - params = await self._get_params(request) - try: - response = await self.client.post( - "/completions", - json=params, - headers={"Authorization": f"Bearer {self._get_api_key()}"}, - ) - response.raise_for_status() - data = response.json() - - choice = OpenAICompatCompletionChoice( - finish_reason=data.get("choices", [{}])[0].get("finish_reason"), - text=data.get("choices", [{}])[0].get("text", ""), - ) - response = OpenAICompatCompletionResponse( - choices=[choice], - ) - return process_completion_response(response, self.formatter) - except httpx.HTTPError as e: - await self._handle_api_error(e) - - async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: - params = await self._get_params(request) - - async def _to_async_generator(): - try: - async with self.client.stream( - "POST", - "/completions", - json=params, - headers={"Authorization": f"Bearer {self._get_api_key()}"}, - ) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if line: - data = httpx.loads(line) - choice = OpenAICompatCompletionChoice( - finish_reason=data.get("choices", [{}])[0].get( - "finish_reason" - ), - text=data.get("choices", [{}])[0].get("text", ""), - ) - yield OpenAICompatCompletionResponse(choices=[choice]) - except httpx.HTTPError as e: - await self._handle_api_error(e) - - stream = _to_async_generator() - async for chunk in process_completion_stream_response(stream, self.formatter): - yield chunk + ) -> Union[ + CompletionResponse, AsyncIterator[CompletionResponseStreamChunk] + ]: + raise NotImplementedError("SambaNova does not support text completion") async def chat_completion( self, @@ -197,7 +457,37 @@ class SambanovaInferenceAdapter( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + """Handle chat completion requests. + + Args: + model_id: The model identifier + messages: List of chat messages + sampling_params: Parameters for text generation + tools: Tool definitions (supported only for specific models) + tool_choice: Tool choice option + tool_prompt_format: Tool prompt format + response_format: Response format (not supported) + stream: Whether to stream the response + logprobs: Log probability config (not supported) + + Returns: + AsyncGenerator: The completion response + + Raises: + ValueError: If function calling is requested for unsupported model + """ model = await self.model_store.get_model(model_id) + + # Raise error for tool usage with unsupported models + if tools and not self._supports_function_calling( + model.provider_resource_id + ): + raise ValueError( + f"Function calling is not supported for model {model.provider_resource_id}. 
" + f"Only the following models support function calling: " + f"{', '.join(FUNCTION_CALLING_MODELS)}" + ) + request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, @@ -213,78 +503,14 @@ class SambanovaInferenceAdapter( else: return await self._nonstream_chat_completion(request) - async def _nonstream_chat_completion( - self, request: ChatCompletionRequest - ) -> AsyncGenerator: - params = await self._get_params(request) - try: - response = await self.client.post( - "/chat/completions", - json=params, - headers={"Authorization": f"Bearer {self._get_api_key()}"}, - ) - response.raise_for_status() - data = response.json() - - choice = OpenAICompatCompletionChoice( - finish_reason=data.get("choices", [{}])[0].get("finish_reason"), - text=data.get("choices", [{}])[0].get("message", {}).get("content", ""), - ) - response = OpenAICompatCompletionResponse(choices=[choice]) - return process_chat_completion_response(response, self.formatter) - except httpx.HTTPError as e: - await self._handle_api_error(e) - - async def _stream_chat_completion( - self, request: ChatCompletionRequest - ) -> AsyncGenerator: - params = await self._get_params(request) - - async def _to_async_generator(): - try: - async with self.client.stream( - "POST", - "/chat/completions", - json=params, - headers={"Authorization": f"Bearer {self._get_api_key()}"}, - ) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if line: - data = httpx.loads(line) - choice = OpenAICompatCompletionChoice( - finish_reason=data.get("choices", [{}])[0].get( - "finish_reason" - ), - text=data.get("choices", [{}])[0] - .get("message", {}) - .get("content", ""), - ) - yield OpenAICompatCompletionResponse(choices=[choice]) - except httpx.HTTPError as e: - await self._handle_api_error(e) - - stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response( - stream, self.formatter - ): - yield chunk - - async def _handle_api_error(self, e: httpx.HTTPError) -> None: - if e.response.status_code in (401, 403): - raise ValueError("Invalid API key or unauthorized access") from e - elif e.response.status_code == 429: - raise ValueError("Rate limit exceeded") from e - elif e.response.status_code == 400: - error_data = e.response.json() - raise ValueError( - f"Bad request: {error_data.get('error', {}).get('message', 'Unknown error')}" - ) from e - raise RuntimeError(f"SambaNova API error: {str(e)}") from e - async def embeddings( self, model_id: str, contents: List[InterleavedTextMedia], ) -> EmbeddingsResponse: + """Embeddings are not supported. 
+ + Raises: + NotImplementedError: Always raised as this feature is not supported + """ raise NotImplementedError("Embeddings not supported for SambaNova") diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 2007818e5..ee0ff2c93 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -22,6 +22,7 @@ from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig +from llama_stack.providers.remote.inference.sambanova import SambanovaImplConfig from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -155,6 +156,21 @@ def inference_nvidia() -> ProviderFixture: ], ) +@pytest.fixture(scope="session") +def inference_sambanova() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="sambanova", + provider_type="remote::sambanova", + config=SambanovaImplConfig().model_dump(), + ) + ], + provider_data=dict( + sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"), + ), + ) + def get_model_short_name(model_name: str) -> str: """Convert model name to a short test identifier. @@ -190,6 +206,7 @@ INFERENCE_FIXTURES = [ "remote", "bedrock", "nvidia", + "sambanova", ] diff --git a/llama_stack/templates/sambanova/__init__.py b/llama_stack/templates/sambanova/__init__.py index 3c48ebc85..30209fb7f 100644 --- a/llama_stack/templates/sambanova/__init__.py +++ b/llama_stack/templates/sambanova/__init__.py @@ -1 +1,7 @@ -from .sambanova import get_distribution_template # noqa: F401 \ No newline at end of file +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .sambanova import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml index 30f10e0ed..347cbf34a 100644 --- a/llama_stack/templates/sambanova/build.yaml +++ b/llama_stack/templates/sambanova/build.yaml @@ -16,4 +16,4 @@ distribution_spec: - inline::meta-reference telemetry: - inline::meta-reference -image_type: conda \ No newline at end of file +image_type: conda diff --git a/llama_stack/templates/sambanova/doc_template.md b/llama_stack/templates/sambanova/doc_template.md new file mode 100644 index 000000000..00d9f19c8 --- /dev/null +++ b/llama_stack/templates/sambanova/doc_template.md @@ -0,0 +1,66 @@ +--- +orphan: true +--- +# SambaNova Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. 
+
+{{ providers_table }}
+
+{% if run_config_env_vars %}
+### Environment Variables
+
+The following environment variables can be configured:
+
+{% for var, (default_value, description) in run_config_env_vars.items() %}
+- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
+{% endfor %}
+{% endif %}
+
+{% if default_models %}
+### Models
+
+The following models are available by default:
+
+{% for model in default_models %}
+- `{{ model.model_id }}`
+{% endfor %}
+{% endif %}
+
+### Prerequisite: API Keys
+
+Make sure you have access to a SambaNova API Key. You can get one by contacting SambaNova Systems.
+
+## Running Llama Stack with SambaNova
+
+You can do this via Conda (build code) or Docker which has a pre-built image.
+
+### Via Docker
+
+This method allows you to get started quickly without having to build the distribution code.
+
+```bash
+LLAMA_STACK_PORT=5001
+docker run \
+  -it \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  llamastack/distribution-{{ name }} \
+  --port $LLAMA_STACK_PORT \
+  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
+```
+
+### Via Conda
+
+```bash
+llama stack build --template {{ name }} --image-type conda
+llama stack run ./run.yaml \
+  --port $LLAMA_STACK_PORT \
+  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
+```
\ No newline at end of file
diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml
index 3c43368f0..e24deb376 100644
--- a/llama_stack/templates/sambanova/run.yaml
+++ b/llama_stack/templates/sambanova/run.yaml
@@ -45,15 +45,35 @@ metadata_store:
   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
 models:
 - metadata: {}
-  model_id: Meta-Llama-3.1-8B-Instruct
+  model_id: meta-llama/Llama-3.2-1B-Instruct
+  provider_id: null
+  provider_model_id: Meta-Llama-3.2-1B-Instruct
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-3B-Instruct
+  provider_id: null
+  provider_model_id: Meta-Llama-3.2-3B-Instruct
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
+  provider_id: null
+  provider_model_id: Llama-3.2-11B-Vision-Instruct
+- metadata: {}
+  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
+  provider_id: null
+  provider_model_id: Llama-3.2-90B-Vision-Instruct
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-8B-Instruct
   provider_id: null
   provider_model_id: Meta-Llama-3.1-8B-Instruct
-shields:
-- params: null
-  shield_id: meta-llama/Llama-Guard-3-8B
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-70B-Instruct
   provider_id: null
-  provider_shield_id: null
+  provider_model_id: Meta-Llama-3.1-70B-Instruct
+- metadata: {}
+  model_id: meta-llama/Llama-3.1-405B-Instruct
+  provider_id: null
+  provider_model_id: Meta-Llama-3.1-405B-Instruct
+shields: []
 memory_banks: []
 datasets: []
 scoring_fns: []
-eval_tasks: []
\ No newline at end of file
+eval_tasks: []
diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py
index e93aee514..b0c76be68 100644
--- a/llama_stack/templates/sambanova/sambanova.py
+++ b/llama_stack/templates/sambanova/sambanova.py
@@ -1,11 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
 from pathlib import Path

 from llama_models.sku_list import all_registered_models

 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
 from llama_stack.providers.remote.inference.sambanova import SambanovaImplConfig
-from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+from llama_stack.providers.remote.inference.sambanova.sambanova import (
+    MODEL_ALIASES,
+)
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+)


 def get_distribution_template() -> DistributionTemplate:
@@ -26,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
     core_model_to_hf_repo = {
         m.descriptor(): m.huggingface_repo for m in all_registered_models()
     }
+
     default_models = [
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
@@ -48,7 +60,9 @@ def get_distribution_template() -> DistributionTemplate:
                     "inference": [inference_provider],
                 },
                 default_models=default_models,
-                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+                default_shields=[
+                    ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
+                ],
             ),
         },
         run_config_env_vars={
@@ -58,7 +72,7 @@ def get_distribution_template() -> DistributionTemplate:
             ),
             "SAMBANOVA_API_KEY": (
                 "",
-                "SambaNova API Key",
+                "SambaNova API Key for authentication",
             ),
         },
     )
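
Reviewer note (not part of the patch): the adapter above points `AsyncOpenAI` at `https://api.sambanova.ai/v1` and registers provider-side model names such as `Meta-Llama-3.1-8B-Instruct`, so the upstream endpoint it wraps can be sanity-checked on its own with the plain OpenAI client. The sketch below is illustrative only; it assumes a valid `SAMBANOVA_API_KEY` in the environment and uses only values taken from the diff.

```python
# Minimal standalone check of the SambaNova OpenAI-compatible endpoint that the
# adapter targets. Not part of the patch; assumes SAMBANOVA_API_KEY is set.
import asyncio
import os

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://api.sambanova.ai/v1",  # same base_url as the adapter
        api_key=os.environ["SAMBANOVA_API_KEY"],
    )
    response = await client.chat.completions.create(
        model="Meta-Llama-3.1-8B-Instruct",  # provider model name from MODEL_ALIASES
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        temperature=0.7,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```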