Implement additional functionality supported by Sambanova.

This commit is contained in:
swanhtet1992 2024-11-24 01:55:36 -06:00
parent b6a79d6291
commit 8920c4216f
9 changed files with 565 additions and 203 deletions

View file

@ -1 +1,7 @@
from .sambanova import get_distribution_template # noqa: F401
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .sambanova import get_distribution_template # noqa: F401

View file

@ -16,4 +16,4 @@ distribution_spec:
- inline::meta-reference
telemetry:
- inline::meta-reference
image_type: conda
image_type: conda

View file

@ -0,0 +1,66 @@
---
orphan: true
---
# SambaNova Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }}`
{% endfor %}
{% endif %}
### Prerequisite: API Keys
Make sure you have access to a SambaNova API key; it is passed to the stack through the `SAMBANOVA_API_KEY` environment variable. You can get a key by contacting SambaNova Systems.
## Running Llama Stack with SambaNova
You can do this either via Conda, which builds the distribution code locally, or via Docker, which uses a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
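### Via Conda
The commands below expect `LLAMA_STACK_PORT` to be set in your shell (for example, `export LLAMA_STACK_PORT=5001`, as in the Docker example above).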
```bash
llama stack build --template {{ name }} --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

View file

@ -45,15 +45,35 @@ metadata_store:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
models:
- metadata: {}
model_id: Meta-Llama-3.1-8B-Instruct
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.2-1B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.2-3B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: null
provider_model_id: Llama-3.2-11B-Vision-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: null
provider_model_id: Llama-3.2-90B-Vision-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-8B-Instruct
shields:
- params: null
shield_id: meta-llama/Llama-Guard-3-8B
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: null
provider_shield_id: null
provider_model_id: Meta-Llama-3.1-70B-Instruct
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: null
provider_model_id: Meta-Llama-3.1-405B-Instruct
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
eval_tasks: []

View file

@ -1,11 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.sambanova import SambanovaImplConfig
from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.remote.inference.sambanova.sambanova import (
MODEL_ALIASES,
)
from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
)
def get_distribution_template() -> DistributionTemplate:
@ -26,6 +37,7 @@ def get_distribution_template() -> DistributionTemplate:
core_model_to_hf_repo = {
m.descriptor(): m.huggingface_repo for m in all_registered_models()
}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
@ -48,7 +60,9 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [inference_provider],
},
default_models=default_models,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_shields=[
ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")
],
),
},
run_config_env_vars={
@ -58,7 +72,7 @@ def get_distribution_template() -> DistributionTemplate:
),
"SAMBANOVA_API_KEY": (
"",
"SambaNova API Key",
"SambaNova API Key for authentication",
),
},
)