From 22dc684da6501e1007043fbcc18765216613fb77 Mon Sep 17 00:00:00 2001 From: snova-edwardm Date: Thu, 23 Jan 2025 12:20:28 -0800 Subject: [PATCH] Sambanova inference provider (#555) # What does this PR do? This PR adds SambaNova as one of the Provider - Add SambaNova as a provider ## Test Plan Test the functional command ``` pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py llama_stack/providers/tests/inference/test_prompt_adapter.py llama_stack/providers/tests/inference/test_text_inference.py llama_stack/providers/tests/inference/test_vision_inference.py --env SAMBANOVA_API_KEY= ``` Test the distribution template: ``` # Docker LLAMA_STACK_PORT=5001 docker run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ llamastack/distribution-sambanova \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY # Conda llama stack build --template sambanova --image-type conda llama stack run ./run.yaml \ --port $LLAMA_STACK_PORT \ --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY ``` ## Source [SambaNova API Documentation](https://cloud.sambanova.ai/apis) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [Y] Ran pre-commit to handle lint / formatting issues. - [Y] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [Y] Updated relevant documentation. - [Y ] Wrote necessary unit or integration tests. --------- Co-authored-by: Ashwin Bharambe --- distributions/sambanova/build.yaml | 19 + distributions/sambanova/compose.yaml | 16 + distributions/sambanova/run.yaml | 83 +++++ docs/source/concepts/index.md | 2 +- .../self_hosted_distro/sambanova.md | 74 ++++ docs/source/index.md | 1 + llama_stack/distribution/ui/modules/api.py | 1 + llama_stack/providers/registry/inference.py | 11 + .../remote/inference/sambanova/__init__.py | 23 ++ .../remote/inference/sambanova/config.py | 29 ++ .../remote/inference/sambanova/sambanova.py | 333 ++++++++++++++++++ .../providers/tests/inference/fixtures.py | 19 + .../inference/test_model_registration.py | 2 +- .../tests/inference/test_text_inference.py | 9 + .../tests/inference/test_vision_inference.py | 2 + llama_stack/templates/sambanova/__init__.py | 7 + llama_stack/templates/sambanova/build.yaml | 19 + .../templates/sambanova/doc_template.md | 68 ++++ llama_stack/templates/sambanova/run.yaml | 83 +++++ llama_stack/templates/sambanova/sambanova.py | 71 ++++ 20 files changed, 870 insertions(+), 2 deletions(-) create mode 100644 distributions/sambanova/build.yaml create mode 100644 distributions/sambanova/compose.yaml create mode 100644 distributions/sambanova/run.yaml create mode 100644 docs/source/distributions/self_hosted_distro/sambanova.md create mode 100644 llama_stack/providers/remote/inference/sambanova/__init__.py create mode 100644 llama_stack/providers/remote/inference/sambanova/config.py create mode 100644 llama_stack/providers/remote/inference/sambanova/sambanova.py create mode 100644 llama_stack/templates/sambanova/__init__.py create mode 100644 llama_stack/templates/sambanova/build.yaml create mode 100644 llama_stack/templates/sambanova/doc_template.md create mode 100644 llama_stack/templates/sambanova/run.yaml create mode 100644 llama_stack/templates/sambanova/sambanova.py diff --git a/distributions/sambanova/build.yaml b/distributions/sambanova/build.yaml new file mode 100644 index 000000000..d6da478d1 --- /dev/null +++ b/distributions/sambanova/build.yaml @@ -0,0 +1,19 @@ +version: '2' +name: sambanova +distribution_spec: + description: Use SambaNova.AI for running LLM inference + docker_image: null + providers: + inference: + - remote::sambanova + memory: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference +image_type: conda diff --git a/distributions/sambanova/compose.yaml b/distributions/sambanova/compose.yaml new file mode 100644 index 000000000..58b9fb1ef --- /dev/null +++ b/distributions/sambanova/compose.yaml @@ -0,0 +1,16 @@ +services: + llamastack: + image: llamastack/distribution-sambanova + network_mode: "host" + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-sambanova.yaml + ports: + - "5000:5000" + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-sambanova.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/sambanova/run.yaml b/distributions/sambanova/run.yaml new file mode 100644 index 000000000..03c8ea44f --- /dev/null +++ b/distributions/sambanova/run.yaml @@ -0,0 +1,83 @@ +version: '2' +image_name: sambanova +docker_image: null +conda_env: sambanova +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: sambanova + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1/ + api_key: ${env.SAMBANOVA_API_KEY} + memory: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.1-8B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.1-70B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.1-405B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.2-1B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.2-3B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: null + provider_model_id: Llama-3.2-11B-Vision-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: null + provider_model_id: Llama-3.2-90B-Vision-Instruct +shields: +- params: null + shield_id: meta-llama/Llama-Guard-3-8B + provider_id: null + provider_shield_id: null +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index 02e54d839..f638ba8d0 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -24,7 +24,7 @@ We are working on adding a few more APIs to complete the application lifecycle. ## API Providers The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Obvious examples for these include -- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, etc.), +- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, SambaNova, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, etc.), - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) diff --git a/docs/source/distributions/self_hosted_distro/sambanova.md b/docs/source/distributions/self_hosted_distro/sambanova.md new file mode 100644 index 000000000..52d1cd962 --- /dev/null +++ b/docs/source/distributions/self_hosted_distro/sambanova.md @@ -0,0 +1,74 @@ +--- +orphan: true +--- +# SambaNova Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-sambanova` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| inference | `remote::sambanova` | +| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | +| safety | `inline::llama-guard` | +| telemetry | `inline::meta-reference` | + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``) + +### Models + +The following models are available by default: + +- `meta-llama/Llama-3.1-8B-Instruct` +- `meta-llama/Llama-3.1-70B-Instruct` +- `meta-llama/Llama-3.1-405B-Instruct` +- `meta-llama/Llama-3.2-1B-Instruct` +- `meta-llama/Llama-3.2-3B-Instruct` +- `meta-llama/Llama-3.2-11B-Vision-Instruct` +- `meta-llama/Llama-3.2-90B-Vision-Instruct` + + +### Prerequisite: API Keys + +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaBova.ai](https://sambanova.ai/). + + +## Running Llama Stack with SambaNova + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + llamastack/distribution-sambanova \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template sambanova --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` diff --git a/docs/source/index.md b/docs/source/index.md index bc4666be3..77afd9d22 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -40,6 +40,7 @@ A number of "adapters" are available for some popular Inference and Memory (Vect | Fireworks | Hosted | Y | Y | Y | | | | AWS Bedrock | Hosted | | Y | | Y | | | Together | Hosted | Y | Y | | Y | | +| SambaNova | Hosted | | Y | | | | | Ollama | Single Node | | Y | | | | TGI | Hosted and Single Node | | Y | | | | NVIDIA NIM | Hosted and Single Node | | Y | | | diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index 70c7a0898..7d3367ba5 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -18,6 +18,7 @@ class LlamaStackApi: provider_data={ "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), "together_api_key": os.environ.get("TOGETHER_API_KEY", ""), + "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""), "openai_api_key": os.environ.get("OPENAI_API_KEY", ""), }, ) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 320c649d2..af2cb8e65 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -204,4 +204,15 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="sambanova", + pip_packages=[ + "openai", + ], + module="llama_stack.providers.remote.inference.sambanova", + config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", + ), + ), ] diff --git a/llama_stack/providers/remote/inference/sambanova/__init__.py b/llama_stack/providers/remote/inference/sambanova/__init__.py new file mode 100644 index 000000000..ab442066a --- /dev/null +++ b/llama_stack/providers/remote/inference/sambanova/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + +from .config import SambaNovaImplConfig +from .sambanova import SambaNovaInferenceAdapter + + +class SambaNovaProviderDataValidator(BaseModel): + sambanova_api_key: str + + +async def get_adapter_impl(config: SambaNovaImplConfig, _deps): + assert isinstance( + config, SambaNovaImplConfig + ), f"Unexpected config type: {type(config)}" + impl = SambaNovaInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py new file mode 100644 index 000000000..e7454404b --- /dev/null +++ b/llama_stack/providers/remote/inference/sambanova/config.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class SambaNovaImplConfig(BaseModel): + url: str = Field( + default="https://api.sambanova.ai/v1", + description="The URL for the SambaNova AI server", + ) + api_key: Optional[str] = Field( + default=None, + description="The SambaNova.ai API Key", + ) + + @classmethod + def sample_run_config(cls) -> Dict[str, Any]: + return { + "url": "https://api.sambanova.ai/v1", + "api_key": "${env.SAMBANOVA_API_KEY}", + } diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py new file mode 100644 index 000000000..9c203a8d0 --- /dev/null +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -0,0 +1,333 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from typing import AsyncGenerator + +from llama_models.datatypes import CoreModelId, SamplingStrategy +from llama_models.llama3.api.chat_format import ChatFormat +from llama_models.llama3.api.tokenizer import Tokenizer +from openai import OpenAI + +from llama_stack.apis.common.content_types import ( + ImageContentItem, + InterleavedContent, + TextContentItem, +) +from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + process_chat_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + convert_image_content_to_url, +) + +from .config import SambaNovaImplConfig + +MODEL_ALIASES = [ + build_model_alias( + "Meta-Llama-3.1-8B-Instruct", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "Meta-Llama-3.1-70B-Instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_model_alias( + "Meta-Llama-3.1-405B-Instruct", + CoreModelId.llama3_1_405b_instruct.value, + ), + build_model_alias( + "Meta-Llama-3.2-1B-Instruct", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_model_alias( + "Meta-Llama-3.2-3B-Instruct", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_model_alias( + "Llama-3.2-11B-Vision-Instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_model_alias( + "Llama-3.2-90B-Vision-Instruct", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), +] + + +class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): + def __init__(self, config: SambaNovaImplConfig) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=MODEL_ALIASES, + ) + + self.config = config + self.formatter = ChatFormat(Tokenizer.get_instance()) + + async def initialize(self) -> None: + return + + async def shutdown(self) -> None: + pass + + def _get_client(self) -> OpenAI: + return OpenAI(base_url=self.config.url, api_key=self.config.api_key) + + async def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + raise NotImplementedError() + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ) + request_sambanova = await self.convert_chat_completion_request(request) + + if stream: + return self._stream_chat_completion(request_sambanova) + else: + return await self._nonstream_chat_completion(request_sambanova) + + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: + response = self._get_client().chat.completions.create(**request) + + choice = response.choices[0] + + result = ChatCompletionResponse( + completion_message=CompletionMessage( + content=choice.message.content or "", + stop_reason=self.convert_to_sambanova_finish_reason( + choice.finish_reason + ), + tool_calls=self.convert_to_sambanova_tool_calls( + choice.message.tool_calls + ), + ), + logprobs=None, + ) + + return result + + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: + async def _to_async_generator(): + streaming = self._get_client().chat.completions.create(**request) + for chunk in streaming: + yield chunk + + stream = _to_async_generator() + async for chunk in process_chat_completion_stream_response( + stream, self.formatter + ): + yield chunk + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedContent], + ) -> EmbeddingsResponse: + raise NotImplementedError() + + async def convert_chat_completion_request( + self, request: ChatCompletionRequest + ) -> dict: + compatible_request = self.convert_sampling_params(request.sampling_params) + compatible_request["model"] = request.model + compatible_request["messages"] = await self.convert_to_sambanova_messages( + request.messages + ) + compatible_request["stream"] = request.stream + compatible_request["logprobs"] = False + compatible_request["extra_headers"] = { + b"User-Agent": b"llama-stack: sambanova-inference-adapter", + } + compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools) + return compatible_request + + def convert_sampling_params( + self, sampling_params: SamplingParams, legacy: bool = False + ) -> dict: + params = {} + + if sampling_params: + params["frequency_penalty"] = sampling_params.repetition_penalty + + if sampling_params.max_tokens: + if legacy: + params["max_tokens"] = sampling_params.max_tokens + else: + params["max_completion_tokens"] = sampling_params.max_tokens + + if sampling_params.strategy == SamplingStrategy.top_p: + params["top_p"] = sampling_params.top_p + elif sampling_params.strategy == "top_k": + params["extra_body"]["top_k"] = sampling_params.top_k + elif sampling_params.strategy == "greedy": + params["temperature"] = sampling_params.temperature + + return params + + async def convert_to_sambanova_messages( + self, messages: List[Message] + ) -> List[dict]: + conversation = [] + for message in messages: + content = {} + + content["content"] = await self.convert_to_sambanova_content(message) + + if isinstance(message, UserMessage): + content["role"] = "user" + elif isinstance(message, CompletionMessage): + content["role"] = "assistant" + tools = [] + for tool_call in message.tool_calls: + tools.append( + { + "id": tool_call.call_id, + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.arguments), + }, + "type": "function", + } + ) + content["tool_calls"] = tools + elif isinstance(message, ToolResponseMessage): + content["role"] = "tool" + content["tool_call_id"] = message.call_id + elif isinstance(message, SystemMessage): + content["role"] = "system" + + conversation.append(content) + + return conversation + + async def convert_to_sambanova_content(self, message: Message) -> dict: + async def _convert_content(content) -> dict: + if isinstance(content, ImageContentItem): + url = await convert_image_content_to_url(content, download=True) + # A fix to make sure the call sucess. + components = url.split(";base64") + url = f"{components[0].lower()};base64{components[1]}" + return { + "type": "image_url", + "image_url": {"url": url}, + } + else: + text = content.text if isinstance(content, TextContentItem) else content + assert isinstance(text, str) + return {"type": "text", "text": text} + + if isinstance(message.content, list): + # If it is a list, the text content should be wrapped in dict + content = [await _convert_content(c) for c in message.content] + else: + content = message.content + + return content + + def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]: + if tools is None: + return tools + + compatiable_tools = [] + + for tool in tools: + properties = {} + compatiable_required = [] + if tool.parameters: + for tool_key, tool_param in tool.parameters.items(): + properties[tool_key] = {"type": tool_param.param_type} + if tool_param.description: + properties[tool_key]["description"] = tool_param.description + if tool_param.default: + properties[tool_key]["default"] = tool_param.default + if tool_param.required: + compatiable_required.append(tool_key) + + compatiable_tool = { + "type": "function", + "function": { + "name": tool.tool_name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": properties, + "required": compatiable_required, + }, + }, + } + + compatiable_tools.append(compatiable_tool) + + if len(compatiable_tools) > 0: + return compatiable_tools + return None + + def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason: + return { + "stop": StopReason.end_of_turn, + "length": StopReason.out_of_tokens, + "tool_calls": StopReason.end_of_message, + }.get(finish_reason, StopReason.end_of_turn) + + def convert_to_sambanova_tool_calls( + self, + tool_calls, + ) -> List[ToolCall]: + if not tool_calls: + return [] + + for call in tool_calls: + call_function_arguments = json.loads(call.function.arguments) + + compitable_tool_calls = [ + ToolCall( + call_id=call.id, + tool_name=call.function.name, + arguments=call_function_arguments, + ) + for call in tool_calls + ] + + return compitable_tool_calls diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 0767e940f..331898a7f 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -23,6 +23,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.groq import GroqConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig +from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig @@ -232,6 +233,23 @@ def inference_tgi() -> ProviderFixture: @pytest.fixture(scope="session") +def inference_sambanova() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="sambanova", + provider_type="remote::sambanova", + config=SambaNovaImplConfig( + api_key=get_env_or_fail("SAMBANOVA_API_KEY"), + ).model_dump(), + ) + ], + provider_data=dict( + sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"), + ), + ) + + def inference_sentence_transformers() -> ProviderFixture: return ProviderFixture( providers=[ @@ -282,6 +300,7 @@ INFERENCE_FIXTURES = [ "cerebras", "nvidia", "tgi", + "sambanova", ] diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 3cd7b2496..96a34ec0e 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -59,7 +59,7 @@ class TestModelRegistration: }, ) - with pytest.raises(AssertionError) as exc_info: + with pytest.raises(ValueError) as exc_info: await models_impl.register_model( model_id="custom-model-2", metadata={ diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index e1052c289..7201fdc4a 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -385,6 +385,12 @@ class TestInference: # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") + if provider.__provider_spec__.provider_type == "remote::sambanova" and ( + "-1B-" in inference_model or "-3B-" in inference_model + ): + # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/ 3B + pytest.skip("Sambanova's tool calling for lightweight models don't work") + messages = sample_messages + [ UserMessage( content="What's the weather like in San Francisco?", @@ -431,6 +437,9 @@ class TestInference: ): # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") + if provider.__provider_spec__.provider_type == "remote::sambanova": + # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it) + pytest.skip("Sambanova's tool calling for streaming doesn't work") messages = sample_messages + [ UserMessage( diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index 100a70236..fba7cefde 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -59,6 +59,7 @@ class TestVisionModelInference: "remote::fireworks", "remote::ollama", "remote::vllm", + "remote::sambanova", ): pytest.skip( "Other inference providers don't support vision chat completion() yet" @@ -98,6 +99,7 @@ class TestVisionModelInference: "remote::fireworks", "remote::ollama", "remote::vllm", + "remote::sambanova", ): pytest.skip( "Other inference providers don't support vision chat completion() yet" diff --git a/llama_stack/templates/sambanova/__init__.py b/llama_stack/templates/sambanova/__init__.py new file mode 100644 index 000000000..30209fb7f --- /dev/null +++ b/llama_stack/templates/sambanova/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .sambanova import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/sambanova/build.yaml b/llama_stack/templates/sambanova/build.yaml new file mode 100644 index 000000000..d6da478d1 --- /dev/null +++ b/llama_stack/templates/sambanova/build.yaml @@ -0,0 +1,19 @@ +version: '2' +name: sambanova +distribution_spec: + description: Use SambaNova.AI for running LLM inference + docker_image: null + providers: + inference: + - remote::sambanova + memory: + - inline::faiss + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference +image_type: conda diff --git a/llama_stack/templates/sambanova/doc_template.md b/llama_stack/templates/sambanova/doc_template.md new file mode 100644 index 000000000..4af4718e5 --- /dev/null +++ b/llama_stack/templates/sambanova/doc_template.md @@ -0,0 +1,68 @@ +--- +orphan: true +--- +# SambaNova Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. + +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} ({{ model.provider_model_id }})` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaBova.ai](https://sambanova.ai/). + + +## Running Llama Stack with SambaNova + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template sambanova --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY +``` diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml new file mode 100644 index 000000000..03c8ea44f --- /dev/null +++ b/llama_stack/templates/sambanova/run.yaml @@ -0,0 +1,83 @@ +version: '2' +image_name: sambanova +docker_image: null +conda_env: sambanova +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: sambanova + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1/ + api_key: ${env.SAMBANOVA_API_KEY} + memory: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.1-8B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.1-70B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.1-405B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.2-1B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: null + provider_model_id: Meta-Llama-3.2-3B-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: null + provider_model_id: Llama-3.2-11B-Vision-Instruct +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: null + provider_model_id: Llama-3.2-90B-Vision-Instruct +shields: +- params: null + shield_id: meta-llama/Llama-Guard-3-8B + provider_id: null + provider_shield_id: null +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: [] diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py new file mode 100644 index 000000000..8c231617b --- /dev/null +++ b/llama_stack/templates/sambanova/sambanova.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_models.sku_list import all_registered_models + +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig +from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES + +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::sambanova"], + "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + } + + inference_provider = Provider( + provider_id="sambanova", + provider_type="remote::sambanova", + config=SambaNovaImplConfig.sample_run_config(), + ) + + core_model_to_hf_repo = { + m.descriptor(): m.huggingface_repo for m in all_registered_models() + } + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model], + provider_model_id=m.provider_model_id, + ) + for m in MODEL_ALIASES + ] + + return DistributionTemplate( + name="sambanova", + distro_type="self_hosted", + description="Use SambaNova.AI for running LLM inference", + docker_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + default_models=default_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=default_models, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "SAMBANOVA_API_KEY": ( + "", + "SambaNova.AI API Key", + ), + }, + )