SambaNova inference provider (#555)

# What does this PR do?

This PR adds SambaNova as an inference provider.

- Add SambaNova as a remote inference provider (`remote::sambanova`)

## Test Plan
Run the provider test suite:
```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py llama_stack/providers/tests/inference/test_prompt_adapter.py llama_stack/providers/tests/inference/test_text_inference.py llama_stack/providers/tests/inference/test_vision_inference.py --env SAMBANOVA_API_KEY=<sambanova-api-key>
```

Test the distribution template:
```
# Docker
LLAMA_STACK_PORT=5001
docker run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-sambanova \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY

# Conda
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
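
Once the server is running, a quick end-to-end check can be done from Python. This is a minimal sketch, assuming the `llama-stack-client` package is installed and the server is listening on the port chosen above (the client package is not part of this PR):

```python
# Minimal smoke test against a running sambanova distribution (illustrative).
# Assumes: `pip install llama-stack-client` and a server on localhost:5001.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```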

## Source
[SambaNova API Documentation](https://cloud.sambanova.ai/apis)
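
For context, the adapter talks to SambaNova's OpenAI-compatible endpoint documented above (see `SambaNovaImplConfig` and `_get_client()` in the diff). A minimal sketch of the equivalent direct call, assuming only the `openai` package that the provider already depends on:

```python
# Direct call to SambaNova's OpenAI-compatible API (illustrative sketch).
# Uses the same base URL as SambaNovaImplConfig's default and a provider-side
# model id from MODEL_ALIASES.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api.sambanova.ai/v1",
    api_key=os.environ["SAMBANOVA_API_KEY"],
)

response = client.chat.completions.create(
    model="Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```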

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [Y] Ran pre-commit to handle lint / formatting issues.
- [Y] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [Y] Updated relevant documentation.
- [Y] Wrote necessary unit or integration tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
snova-edwardm 2025-01-23 12:20:28 -08:00 committed by GitHub
parent e2b5456e48
commit 22dc684da6
20 changed files with 870 additions and 2 deletions


@ -0,0 +1,19 @@
version: '2'
name: sambanova
distribution_spec:
  description: Use SambaNova.AI for running LLM inference
  docker_image: null
  providers:
    inference:
    - remote::sambanova
    memory:
    - inline::faiss
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
image_type: conda


@ -0,0 +1,16 @@
services:
  llamastack:
    image: llamastack/distribution-sambanova
    network_mode: "host"
    volumes:
      - ~/.llama:/root/.llama
      - ./run.yaml:/root/llamastack-run-sambanova.yaml
    ports:
      - "5000:5000"
    entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-sambanova.yaml"
    deploy:
      restart_policy:
        condition: on-failure
        delay: 3s
        max_attempts: 5
        window: 60s


@ -0,0 +1,83 @@
version: '2'
image_name: sambanova
docker_image: null
conda_env: sambanova
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: sambanova
    provider_type: remote::sambanova
    config:
      url: https://api.sambanova.ai/v1/
      api_key: ${env.SAMBANOVA_API_KEY}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.1-8B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.1-70B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.1-405B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.2-1B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.2-3B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: null
  provider_model_id: Llama-3.2-11B-Vision-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: null
  provider_model_id: Llama-3.2-90B-Vision-Instruct
shields:
- params: null
  shield_id: meta-llama/Llama-Guard-3-8B
  provider_id: null
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []


@ -24,7 +24,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
## API Providers
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Obvious examples for these include
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, etc.),
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, SambaNova, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)


@ -0,0 +1,74 @@
---
orphan: true
---
# SambaNova Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-sambanova` distribution consists of the following provider configurations.

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::sambanova` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |

### Environment Variables

The following environment variables can be configured:

- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)

### Models

The following models are available by default:

- `meta-llama/Llama-3.1-8B-Instruct`
- `meta-llama/Llama-3.1-70B-Instruct`
- `meta-llama/Llama-3.1-405B-Instruct`
- `meta-llama/Llama-3.2-1B-Instruct`
- `meta-llama/Llama-3.2-3B-Instruct`
- `meta-llama/Llama-3.2-11B-Vision-Instruct`
- `meta-llama/Llama-3.2-90B-Vision-Instruct`

### Prerequisite: API Keys

Make sure you have access to a SambaNova API key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/).

## Running Llama Stack with SambaNova

You can run Llama Stack with SambaNova via Conda (build the code) or Docker (pre-built image).
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Conda
```bash
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```


@ -40,6 +40,7 @@ A number of "adapters" are available for some popular Inference and Memory (Vect
| Fireworks | Hosted | Y | Y | Y | | |
| AWS Bedrock | Hosted | | Y | | Y | |
| Together | Hosted | Y | Y | | Y | |
| SambaNova | Hosted | | Y | | | |
| Ollama | Single Node | | Y | | |
| TGI | Hosted and Single Node | | Y | | |
| NVIDIA NIM | Hosted and Single Node | | Y | | |


@ -18,6 +18,7 @@ class LlamaStackApi:
            provider_data={
                "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
                "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
                "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
                "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
            },
        )


@ -204,4 +204,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=[
"openai",
],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
),
),
]


@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .config import SambaNovaImplConfig
from .sambanova import SambaNovaInferenceAdapter


class SambaNovaProviderDataValidator(BaseModel):
    sambanova_api_key: str


async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
    assert isinstance(
        config, SambaNovaImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = SambaNovaInferenceAdapter(config)
    await impl.initialize()
    return impl


@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class SambaNovaImplConfig(BaseModel):
    url: str = Field(
        default="https://api.sambanova.ai/v1",
        description="The URL for the SambaNova AI server",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The SambaNova.ai API Key",
    )

    @classmethod
    def sample_run_config(cls) -> Dict[str, Any]:
        return {
            "url": "https://api.sambanova.ai/v1",
            "api_key": "${env.SAMBANOVA_API_KEY}",
        }


@ -0,0 +1,333 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import AsyncGenerator

from llama_models.datatypes import CoreModelId, SamplingStrategy
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI

from llama_stack.apis.common.content_types import (
    ImageContentItem,
    InterleavedContent,
    TextContentItem,
)
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_image_content_to_url,
)

from .config import SambaNovaImplConfig
MODEL_ALIASES = [
    build_model_alias(
        "Meta-Llama-3.1-8B-Instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.1-70B-Instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.1-405B-Instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.2-1B-Instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_model_alias(
        "Meta-Llama-3.2-3B-Instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_model_alias(
        "Llama-3.2-11B-Vision-Instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_model_alias(
        "Llama-3.2-90B-Vision-Instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
]

class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: SambaNovaImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=MODEL_ALIASES,
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    def _get_client(self) -> OpenAI:
        return OpenAI(base_url=self.config.url, api_key=self.config.api_key)

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)

        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )
        request_sambanova = await self.convert_chat_completion_request(request)

        if stream:
            return self._stream_chat_completion(request_sambanova)
        else:
            return await self._nonstream_chat_completion(request_sambanova)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        response = self._get_client().chat.completions.create(**request)

        choice = response.choices[0]

        result = ChatCompletionResponse(
            completion_message=CompletionMessage(
                content=choice.message.content or "",
                stop_reason=self.convert_to_sambanova_finish_reason(
                    choice.finish_reason
                ),
                tool_calls=self.convert_to_sambanova_tool_calls(
                    choice.message.tool_calls
                ),
            ),
            logprobs=None,
        )

        return result

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        async def _to_async_generator():
            streaming = self._get_client().chat.completions.create(**request)
            for chunk in streaming:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

    async def convert_chat_completion_request(
        self, request: ChatCompletionRequest
    ) -> dict:
        compatible_request = self.convert_sampling_params(request.sampling_params)
        compatible_request["model"] = request.model
        compatible_request["messages"] = await self.convert_to_sambanova_messages(
            request.messages
        )
        compatible_request["stream"] = request.stream
        compatible_request["logprobs"] = False
        compatible_request["extra_headers"] = {
            b"User-Agent": b"llama-stack: sambanova-inference-adapter",
        }
        compatible_request["tools"] = self.convert_to_sambanova_tool(request.tools)
        return compatible_request

    def convert_sampling_params(
        self, sampling_params: SamplingParams, legacy: bool = False
    ) -> dict:
        params = {}

        if sampling_params:
            params["frequency_penalty"] = sampling_params.repetition_penalty

            if sampling_params.max_tokens:
                if legacy:
                    params["max_tokens"] = sampling_params.max_tokens
                else:
                    params["max_completion_tokens"] = sampling_params.max_tokens

            if sampling_params.strategy == SamplingStrategy.top_p:
                params["top_p"] = sampling_params.top_p
            elif sampling_params.strategy == SamplingStrategy.top_k:
                # top_k is not a standard OpenAI parameter; pass it via extra_body
                params["extra_body"] = {"top_k": sampling_params.top_k}
            elif sampling_params.strategy == SamplingStrategy.greedy:
                params["temperature"] = sampling_params.temperature

        return params

    async def convert_to_sambanova_messages(
        self, messages: List[Message]
    ) -> List[dict]:
        conversation = []
        for message in messages:
            content = {}

            content["content"] = await self.convert_to_sambanova_content(message)

            if isinstance(message, UserMessage):
                content["role"] = "user"
            elif isinstance(message, CompletionMessage):
                content["role"] = "assistant"
                tools = []
                for tool_call in message.tool_calls:
                    tools.append(
                        {
                            "id": tool_call.call_id,
                            "function": {
                                "name": tool_call.name,
                                "arguments": json.dumps(tool_call.arguments),
                            },
                            "type": "function",
                        }
                    )
                content["tool_calls"] = tools
            elif isinstance(message, ToolResponseMessage):
                content["role"] = "tool"
                content["tool_call_id"] = message.call_id
            elif isinstance(message, SystemMessage):
                content["role"] = "system"

            conversation.append(content)

        return conversation

    async def convert_to_sambanova_content(self, message: Message) -> dict:
        async def _convert_content(content) -> dict:
            if isinstance(content, ImageContentItem):
                url = await convert_image_content_to_url(content, download=True)
                # A fix to make sure the call succeeds: lower-case the data-URL
                # prefix before sending it to the API.
                components = url.split(";base64")
                url = f"{components[0].lower()};base64{components[1]}"
                return {
                    "type": "image_url",
                    "image_url": {"url": url},
                }
            else:
                text = content.text if isinstance(content, TextContentItem) else content
                assert isinstance(text, str)
                return {"type": "text", "text": text}

        if isinstance(message.content, list):
            # If it is a list, the text content should be wrapped in a dict
            content = [await _convert_content(c) for c in message.content]
        else:
            content = message.content

        return content

    def convert_to_sambanova_tool(self, tools: List[ToolDefinition]) -> List[dict]:
        if tools is None:
            return tools

        compatible_tools = []
        for tool in tools:
            properties = {}
            compatible_required = []
            if tool.parameters:
                for tool_key, tool_param in tool.parameters.items():
                    properties[tool_key] = {"type": tool_param.param_type}
                    if tool_param.description:
                        properties[tool_key]["description"] = tool_param.description
                    if tool_param.default:
                        properties[tool_key]["default"] = tool_param.default
                    if tool_param.required:
                        compatible_required.append(tool_key)

            compatible_tool = {
                "type": "function",
                "function": {
                    "name": tool.tool_name,
                    "description": tool.description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": compatible_required,
                    },
                },
            }

            compatible_tools.append(compatible_tool)

        if len(compatible_tools) > 0:
            return compatible_tools
        return None

    def convert_to_sambanova_finish_reason(self, finish_reason: str) -> StopReason:
        return {
            "stop": StopReason.end_of_turn,
            "length": StopReason.out_of_tokens,
            "tool_calls": StopReason.end_of_message,
        }.get(finish_reason, StopReason.end_of_turn)

    def convert_to_sambanova_tool_calls(
        self,
        tool_calls,
    ) -> List[ToolCall]:
        if not tool_calls:
            return []

        # Parse each call's JSON arguments individually.
        compatible_tool_calls = [
            ToolCall(
                call_id=call.id,
                tool_name=call.function.name,
                arguments=json.loads(call.function.arguments),
            )
            for call in tool_calls
        ]

        return compatible_tool_calls


@ -23,6 +23,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
@ -232,6 +233,23 @@ def inference_tgi() -> ProviderFixture:
@pytest.fixture(scope="session")
def inference_sambanova() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="sambanova",
                provider_type="remote::sambanova",
                config=SambaNovaImplConfig(
                    api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
                ).model_dump(),
            )
        ],
        provider_data=dict(
            sambanova_api_key=get_env_or_fail("SAMBANOVA_API_KEY"),
        ),
    )


def inference_sentence_transformers() -> ProviderFixture:
    return ProviderFixture(
        providers=[
@ -282,6 +300,7 @@ INFERENCE_FIXTURES = [
"cerebras",
"nvidia",
"tgi",
"sambanova",
]


@ -59,7 +59,7 @@ class TestModelRegistration:
            },
        )

        with pytest.raises(AssertionError) as exc_info:
        with pytest.raises(ValueError) as exc_info:
            await models_impl.register_model(
                model_id="custom-model-2",
                metadata={


@ -385,6 +385,12 @@ class TestInference:
            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
        if provider.__provider_spec__.provider_type == "remote::sambanova" and (
            "-1B-" in inference_model or "-3B-" in inference_model
        ):
            # TODO(snova-edwardm): Remove this skip once SambaNova's tool calling for 1B/3B models works
            pytest.skip("SambaNova's tool calling for lightweight models doesn't work")

        messages = sample_messages + [
            UserMessage(
                content="What's the weather like in San Francisco?",
@ -431,6 +437,9 @@ class TestInference:
        ):
            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
        if provider.__provider_spec__.provider_type == "remote::sambanova":
            # TODO(snova-edwardm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
            pytest.skip("SambaNova's tool calling for streaming doesn't work")

        messages = sample_messages + [
            UserMessage(


@ -59,6 +59,7 @@ class TestVisionModelInference:
"remote::fireworks",
"remote::ollama",
"remote::vllm",
"remote::sambanova",
):
pytest.skip(
"Other inference providers don't support vision chat completion() yet"
@ -98,6 +99,7 @@ class TestVisionModelInference:
"remote::fireworks",
"remote::ollama",
"remote::vllm",
"remote::sambanova",
):
pytest.skip(
"Other inference providers don't support vision chat completion() yet"


@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .sambanova import get_distribution_template  # noqa: F401


@ -0,0 +1,19 @@
version: '2'
name: sambanova
distribution_spec:
  description: Use SambaNova.AI for running LLM inference
  docker_image: null
  providers:
    inference:
    - remote::sambanova
    memory:
    - inline::faiss
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
image_type: conda


@ -0,0 +1,68 @@
---
orphan: true
---
# SambaNova Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

{{ providers_table }}

{% if run_config_env_vars %}
### Environment Variables

The following environment variables can be configured:

{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}

{% if default_models %}
### Models

The following models are available by default:

{% for model in default_models %}
- `{{ model.model_id }} ({{ model.provider_model_id }})`
{% endfor %}
{% endif %}

### Prerequisite: API Keys

Make sure you have access to a SambaNova API key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/).

## Running Llama Stack with SambaNova

You can run Llama Stack with SambaNova via Conda (build the code) or Docker (pre-built image).
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Conda
```bash
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```


@ -0,0 +1,83 @@
version: '2'
image_name: sambanova
docker_image: null
conda_env: sambanova
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: sambanova
    provider_type: remote::sambanova
    config:
      url: https://api.sambanova.ai/v1/
      api_key: ${env.SAMBANOVA_API_KEY}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.1-8B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.1-70B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.1-405B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.1-405B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.2-1B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: null
  provider_model_id: Meta-Llama-3.2-3B-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: null
  provider_model_id: Llama-3.2-11B-Vision-Instruct
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: null
  provider_model_id: Llama-3.2-90B-Vision-Instruct
shields:
- params: null
  shield_id: meta-llama/Llama-Guard-3-8B
  provider_id: null
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []


@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_models.sku_list import all_registered_models

from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::sambanova"],
        "memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
    }

    inference_provider = Provider(
        provider_id="sambanova",
        provider_type="remote::sambanova",
        config=SambaNovaImplConfig.sample_run_config(),
    )

    core_model_to_hf_repo = {
        m.descriptor(): m.huggingface_repo for m in all_registered_models()
    }
    default_models = [
        ModelInput(
            model_id=core_model_to_hf_repo[m.llama_model],
            provider_model_id=m.provider_model_id,
        )
        for m in MODEL_ALIASES
    ]

    return DistributionTemplate(
        name="sambanova",
        distro_type="self_hosted",
        description="Use SambaNova.AI for running LLM inference",
        docker_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        default_models=default_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                },
                default_models=default_models,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "SAMBANOVA_API_KEY": (
                "",
                "SambaNova.AI API Key",
            ),
        },
    )