Sambanova inference provider (#555)

# What does this PR do?

This PR adds SambaNova as one of the Provider

- Add SambaNova as a provider

## Test Plan
Test the functional command
```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py llama_stack/providers/tests/inference/test_prompt_adapter.py llama_stack/providers/tests/inference/test_text_inference.py llama_stack/providers/tests/inference/test_vision_inference.py --env SAMBANOVA_API_KEY=<sambanova-api-key>
```

Test the distribution template:
```
# Docker
LLAMA_STACK_PORT=5001
docker run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-sambanova \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY

# Conda
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

## Source
[SambaNova API Documentation](https://cloud.sambanova.ai/apis)

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [Y] Ran pre-commit to handle lint / formatting issues.
- [Y] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [Y] Updated relevant documentation.
- [Y ] Wrote necessary unit or integration tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
snova-edwardm 2025-01-23 12:20:28 -08:00 committed by GitHub
parent e2b5456e48
commit 22dc684da6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 870 additions and 2 deletions

View file

@ -0,0 +1,19 @@
version: '2'
name: sambanova
distribution_spec:
description: Use SambaNova.AI for running LLM inference
docker_image: null
providers:
inference:
- remote::sambanova
memory:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda

View file

@ -0,0 +1,16 @@
services:
llamastack:
image: llamastack/distribution-sambanova
network_mode: "host"
volumes:
- ~/.llama:/root/.llama
- ./run.yaml:/root/llamastack-run-sambanova.yaml
ports:
- "5000:5000"
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-sambanova.yaml"
deploy:
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s

View file

@ -0,0 +1,83 @@
version: '2'
image_name: sambanova
docker_image: null
conda_env: sambanova
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
inference:
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1/
api_key: ${env.SAMBANOVA_API_KEY}
memory:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-8B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-70B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-405B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.2-1B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.2-3B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: null
provider_model_id: Llama-3.2-11B-Vision-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: null
provider_model_id: Llama-3.2-90B-Vision-Instruct
shields:
- params: null
shield_id: meta-llama/Llama-Guard-3-8B
provider_id: null
provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []

View file

@ -24,7 +24,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
## API Providers ## API Providers
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Obvious examples for these include The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Obvious examples for these include
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, etc.), - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, SambaNova, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

View file

@ -0,0 +1,74 @@
---
orphan: true
---
# SambaNova Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-sambanova` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::sambanova` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct`
- `meta-llama/Llama-3.1-70B-Instruct`
- `meta-llama/Llama-3.1-405B-Instruct`
- `meta-llama/Llama-3.2-1B-Instruct`
- `meta-llama/Llama-3.2-3B-Instruct`
- `meta-llama/Llama-3.2-11B-Vision-Instruct`
- `meta-llama/Llama-3.2-90B-Vision-Instruct`
### Prerequisite: API Keys
Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaBova.ai](https://sambanova.ai/).
## Running Llama Stack with SambaNova
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Conda
```bash
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

View file

@ -40,6 +40,7 @@ A number of "adapters" are available for some popular Inference and Memory (Vect
| Fireworks | Hosted | Y | Y | Y | | | | Fireworks | Hosted | Y | Y | Y | | |
| AWS Bedrock | Hosted | | Y | | Y | | | AWS Bedrock | Hosted | | Y | | Y | |
| Together | Hosted | Y | Y | | Y | | | Together | Hosted | Y | Y | | Y | |
| SambaNova | Hosted | | Y | | | |
| Ollama | Single Node | | Y | | | | Ollama | Single Node | | Y | | |
| TGI | Hosted and Single Node | | Y | | | | TGI | Hosted and Single Node | | Y | | |
| NVIDIA NIM | Hosted and Single Node | | Y | | | | NVIDIA NIM | Hosted and Single Node | | Y | | |

View file

@ -18,6 +18,7 @@ class LlamaStackApi:
provider_data={ provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""), "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""), "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
}, },
) )

View file

@ -204,4 +204,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig", config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
), ),
), ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=[
"openai",
],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
),
),
] ]

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import SambaNovaImplConfig
from .sambanova import SambaNovaInferenceAdapter
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
assert isinstance(
config, SambaNovaImplConfig
), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class SambaNovaImplConfig(BaseModel):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: Optional[str] = Field(
default=None,
description="The SambaNova.ai API Key",
)
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": "${env.SAMBANOVA_API_KEY}",
}

View file

@ -0,0 +1,333 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import AsyncGenerator
from llama_models.datatypes import CoreModelId, SamplingStrategy
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
TextContentItem,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
)
from .config import SambaNovaImplConfig
MODEL_ALIASES = [
build_model_alias(
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.1-70B-Instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.1-405B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"Llama-3.2-11B-Vision-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
]
class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=MODEL_ALIASES,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
def _get_client(self) -> OpenAI:
return OpenAI(base_url=self.config.url, api_key=self.config.api_key)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
request_sambanova = await self.convert_chat_completion_request(request)
if stream:
return self._stream_chat_completion(request_sambanova)
else:
return await self._nonstream_chat_completion(request_sambanova)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
response = self._get_client().chat.completions.create(**request)
choice = response.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=self.convert_to_sambanova_finish_reason(
choice.finish_reason
),
tool_calls=self.convert_to_sambanova_tool_calls(
choice.message.tool_calls
),
),
logprobs=None,
)
return result
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _to_async_generator():
streaming = self._get_client().chat.completions.create(**request)
for chunk in streaming:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def convert_chat_completion_request(
self, request: ChatCompletionRequest
) -> dict:
compatible_request = self.convert_sampling_params(request.sampling_params)
compatible_request["model"] = request.model
compatible_request["messages"] = await self.convert_to_sambanova_messages(
request.messages
)
compatible_request["stream"] = request.stream
compatible_request["logprobs"] = False
compatible_request["extra_headers"] = {
b"User-Agent": b"llama-stack: sambanova-inference-adapter",
}
compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
return compatible_request
def convert_sampling_params(
self, sampling_params: SamplingParams, legacy: bool = False
) -> dict:
params = {}
if sampling_params:
params["frequency_penalty"] = sampling_params.repetition_penalty
if sampling_params.max_tokens:
if legacy:
params["max_tokens"] = sampling_params.max_tokens
else:
params["max_completion_tokens"] = sampling_params.max_tokens
if sampling_params.strategy == SamplingStrategy.top_p:
params["top_p"] = sampling_params.top_p
elif sampling_params.strategy == "top_k":
params["extra_body"]["top_k"] = sampling_params.top_k
elif sampling_params.strategy == "greedy":
params["temperature"] = sampling_params.temperature
return params
async def convert_to_sambanova_messages(
self, messages: List[Message]
) -> List[dict]:
conversation = []
for message in messages:
content = {}
content["content"] = await self.convert_to_sambanova_content(message)
if isinstance(message, UserMessage):
content["role"] = "user"
elif isinstance(message, CompletionMessage):
content["role"] = "assistant"
tools = []
for tool_call in message.tool_calls:
tools.append(
{
"id": tool_call.call_id,
"function": {
"name": tool_call.name,
"arguments": json.dumps(tool_call.arguments),
},
"type": "function",
}
)
content["tool_calls"] = tools
elif isinstance(message, ToolResponseMessage):
content["role"] = "tool"
content["tool_call_id"] = message.call_id
elif isinstance(message, SystemMessage):
content["role"] = "system"
conversation.append(content)
return conversation
async def convert_to_sambanova_content(self, message: Message) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
url = await convert_image_content_to_url(content, download=True)
# A fix to make sure the call sucess.
components = url.split(";base64")
url = f"{components[0].lower()};base64{components[1]}"
return {
"type": "image_url",
"image_url": {"url": url},
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {"type": "text", "text": text}
if isinstance(message.content, list):
# If it is a list, the text content should be wrapped in dict
content = [await _convert_content(c) for c in message.content]
else:
content = message.content
return content
def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
if tools is None:
return tools
compatiable_tools = []
for tool in tools:
properties = {}
compatiable_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
if tool_param.default:
properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compatiable_required.append(tool_key)
compatiable_tool = {
"type": "function",
"function": {
"name": tool.tool_name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": properties,
"required": compatiable_required,
},
},
}
compatiable_tools.append(compatiable_tool)
if len(compatiable_tools) > 0:
return compatiable_tools
return None
def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def convert_to_sambanova_tool_calls(
self,
tool_calls,
) -> List[ToolCall]:
if not tool_calls:
return []
for call in tool_calls:
call_function_arguments = json.loads(call.function.arguments)
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=call_function_arguments,
)
for call in tool_calls
]
return compitable_tool_calls

View file

@ -23,6 +23,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@ -232,6 +233,23 @@ def inference_tgi() -> ProviderFixture:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def inference_sambanova() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="sambanova",
provider_type="remote::sambanova",
config=SambaNovaImplConfig(
api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
).model_dump(),
)
],
provider_data=dict(
sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
),
)
def inference_sentence_transformers() -> ProviderFixture: def inference_sentence_transformers() -> ProviderFixture:
return ProviderFixture( return ProviderFixture(
providers=[ providers=[
@ -282,6 +300,7 @@ INFERENCE_FIXTURES = [
"cerebras", "cerebras",
"nvidia", "nvidia",
"tgi", "tgi",
"sambanova",
] ]

View file

@ -59,7 +59,7 @@ class TestModelRegistration:
}, },
) )
with pytest.raises(AssertionError) as exc_info: with pytest.raises(ValueError) as exc_info:
await models_impl.register_model( await models_impl.register_model(
model_id="custom-model-2", model_id="custom-model-2",
metadata={ metadata={

View file

@ -385,6 +385,12 @@ class TestInference:
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
if provider.__provider_spec__.provider_type == "remote::sambanova" and (
"-1B-" in inference_model or "-3B-" in inference_model
):
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/ 3B
pytest.skip("Sambanova's tool calling for lightweight models don't work")
messages = sample_messages + [ messages = sample_messages + [
UserMessage( UserMessage(
content="What's the weather like in San Francisco?", content="What's the weather like in San Francisco?",
@ -431,6 +437,9 @@ class TestInference:
): ):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well") pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
if provider.__provider_spec__.provider_type == "remote::sambanova":
# TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
pytest.skip("Sambanova's tool calling for streaming doesn't work")
messages = sample_messages + [ messages = sample_messages + [
UserMessage( UserMessage(

View file

@ -59,6 +59,7 @@ class TestVisionModelInference:
"remote::fireworks", "remote::fireworks",
"remote::ollama", "remote::ollama",
"remote::vllm", "remote::vllm",
"remote::sambanova",
): ):
pytest.skip( pytest.skip(
"Other inference providers don't support vision chat completion() yet" "Other inference providers don't support vision chat completion() yet"
@ -98,6 +99,7 @@ class TestVisionModelInference:
"remote::fireworks", "remote::fireworks",
"remote::ollama", "remote::ollama",
"remote::vllm", "remote::vllm",
"remote::sambanova",
): ):
pytest.skip( pytest.skip(
"Other inference providers don't support vision chat completion() yet" "Other inference providers don't support vision chat completion() yet"

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .sambanova import get_distribution_template # noqa: F401

View file

@ -0,0 +1,19 @@
version: '2'
name: sambanova
distribution_spec:
description: Use SambaNova.AI for running LLM inference
docker_image: null
providers:
inference:
- remote::sambanova
memory:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda

View file

@ -0,0 +1,68 @@
---
orphan: true
---
# SambaNova Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} ({{ model.provider_model_id }})`
{% endfor %}
{% endif %}
### Prerequisite: API Keys
Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaBova.ai](https://sambanova.ai/).
## Running Llama Stack with SambaNova
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Conda
```bash
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

View file

@ -0,0 +1,83 @@
version: '2'
image_name: sambanova
docker_image: null
conda_env: sambanova
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
inference:
- provider_id: sambanova
provider_type: remote::sambanova
config:
url: https://api.sambanova.ai/v1/
api_key: ${env.SAMBANOVA_API_KEY}
memory:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config: {}
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config: {}
metadata_store:
namespace: null
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
models:
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-8B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-70B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-405B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.2-1B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.2-3B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: null
provider_model_id: Llama-3.2-11B-Vision-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: null
provider_model_id: Llama-3.2-90B-Vision-Instruct
shields:
- params: null
shield_id: meta-llama/Llama-Guard-3-8B
provider_id: null
provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::sambanova"],
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
}
inference_provider = Provider(
provider_id="sambanova",
provider_type="remote::sambanova",
config=SambaNovaImplConfig.sample_run_config(),
)
core_model_to_hf_repo = {
m.descriptor(): m.huggingface_repo for m in all_registered_models()
}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
provider_model_id=m.provider_model_id,
)
for m in MODEL_ALIASES
]
return DistributionTemplate(
name="sambanova",
distro_type="self_hosted",
description="Use SambaNova.AI for running LLM inference",
docker_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
default_models=default_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=default_models,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
),
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"SAMBANOVA_API_KEY": (
"",
"SambaNova.AI API Key",
),
},
)