mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 18:49:48 +00:00
Implement SambaNova as new remote API Provider.
This commit is contained in:
parent
4e6c984c26
commit
b6a79d6291
8 changed files with 485 additions and 0 deletions
|
|
@ -161,4 +161,16 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.inference,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="sambanova",
|
||||||
|
pip_packages=[
|
||||||
|
"openai",
|
||||||
|
],
|
||||||
|
module="llama_stack.providers.remote.inference.sambanova",
|
||||||
|
config_class="llama_stack.providers.remote.inference.sambanova.SambanovaImplConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.remote.inference.sambanova.SambanovaProviderDataValidator",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
17
llama_stack/providers/remote/inference/sambanova/__init__.py
Normal file
17
llama_stack/providers/remote/inference/sambanova/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .config import SambanovaImplConfig
|
||||||
|
|
||||||
|
class SambanovaProviderDataValidator(BaseModel):
|
||||||
|
sambanova_api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: SambanovaImplConfig, _deps):
|
||||||
|
from .sambanova import SambanovaInferenceAdapter
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
config, SambanovaImplConfig
|
||||||
|
), f"Unexpected config type: {type(config)}"
|
||||||
|
impl = SambanovaInferenceAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
23
llama_stack/providers/remote/inference/sambanova/config.py
Normal file
23
llama_stack/providers/remote/inference/sambanova/config.py
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class SambanovaImplConfig(BaseModel):
|
||||||
|
url: str = Field(
|
||||||
|
default="https://api.sambanova.ai/v1",
|
||||||
|
description="The URL for the SambaNova API server",
|
||||||
|
)
|
||||||
|
api_key: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The SambaNova API Key",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"url": "https://api.sambanova.ai/v1",
|
||||||
|
"api_key": "${env.SAMBANOVA_API_KEY}",
|
||||||
|
}
|
||||||
290
llama_stack/providers/remote/inference/sambanova/sambanova.py
Normal file
290
llama_stack/providers/remote/inference/sambanova/sambanova.py
Normal file
|
|
@ -0,0 +1,290 @@
|
||||||
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from llama_models.datatypes import CoreModelId
|
||||||
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
from llama_models.llama3.api.datatypes import Message
|
||||||
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import *
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
ModelRegistryHelper,
|
||||||
|
build_model_alias,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
OpenAICompatCompletionChoice,
|
||||||
|
OpenAICompatCompletionResponse,
|
||||||
|
get_sampling_options,
|
||||||
|
process_chat_completion_response,
|
||||||
|
process_chat_completion_stream_response,
|
||||||
|
process_completion_response,
|
||||||
|
process_completion_stream_response,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
completion_request_to_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import SambanovaImplConfig
|
||||||
|
|
||||||
|
# Simplified model aliases - focus on core models
|
||||||
|
MODEL_ALIASES = [
|
||||||
|
build_model_alias(
|
||||||
|
"Meta-Llama-3.1-8B-Instruct",
|
||||||
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SambanovaInferenceAdapter(
|
||||||
|
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||||
|
):
|
||||||
|
def __init__(self, config: SambanovaImplConfig) -> None:
|
||||||
|
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
|
||||||
|
self.config = config
|
||||||
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
self.client = httpx.AsyncClient(
|
||||||
|
base_url=self.config.url,
|
||||||
|
timeout=httpx.Timeout(timeout=300.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
await self.client.aclose()
|
||||||
|
|
||||||
|
def _get_api_key(self) -> str:
|
||||||
|
if self.config.api_key is not None:
|
||||||
|
return self.config.api_key
|
||||||
|
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.sambanova_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass SambaNova API Key in the header X-LlamaStack-ProviderData as { "sambanova_api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
return provider_data.sambanova_api_key
|
||||||
|
|
||||||
|
def _convert_messages_to_api_format(self, messages: List[Message]) -> List[dict]:
|
||||||
|
"""Convert our Message objects to SambaNova API format."""
|
||||||
|
return [
|
||||||
|
{"role": message.role, "content": message.content} for message in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_sampling_params(self, params: Optional[SamplingParams]) -> dict:
|
||||||
|
"""Convert our SamplingParams to SambaNova API parameters."""
|
||||||
|
if not params:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
api_params = {}
|
||||||
|
if params.max_tokens:
|
||||||
|
api_params["max_tokens"] = params.max_tokens
|
||||||
|
if params.temperature is not None:
|
||||||
|
api_params["temperature"] = params.temperature
|
||||||
|
if params.top_p is not None:
|
||||||
|
api_params["top_p"] = params.top_p
|
||||||
|
if params.top_k is not None:
|
||||||
|
api_params["top_k"] = params.top_k
|
||||||
|
if params.stop_sequences:
|
||||||
|
api_params["stop"] = params.stop_sequences
|
||||||
|
|
||||||
|
return api_params
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content: InterleavedTextMedia,
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = CompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
content=content,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._stream_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_completion(request)
|
||||||
|
|
||||||
|
async def _get_params(
|
||||||
|
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||||
|
) -> dict:
|
||||||
|
sampling_options = get_sampling_options(request.sampling_params)
|
||||||
|
|
||||||
|
input_dict = {}
|
||||||
|
if isinstance(request, ChatCompletionRequest):
|
||||||
|
if isinstance(request.messages[0].content, list):
|
||||||
|
raise NotImplementedError("Media content not supported for SambaNova")
|
||||||
|
input_dict["messages"] = self._convert_messages_to_api_format(
|
||||||
|
request.messages
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": request.model,
|
||||||
|
**input_dict,
|
||||||
|
**sampling_options,
|
||||||
|
"stream": request.stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
try:
|
||||||
|
response = await self.client.post(
|
||||||
|
"/completions",
|
||||||
|
json=params,
|
||||||
|
headers={"Authorization": f"Bearer {self._get_api_key()}"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=data.get("choices", [{}])[0].get("finish_reason"),
|
||||||
|
text=data.get("choices", [{}])[0].get("text", ""),
|
||||||
|
)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
return process_completion_response(response, self.formatter)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
await self._handle_api_error(e)
|
||||||
|
|
||||||
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
|
||||||
|
async def _to_async_generator():
|
||||||
|
try:
|
||||||
|
async with self.client.stream(
|
||||||
|
"POST",
|
||||||
|
"/completions",
|
||||||
|
json=params,
|
||||||
|
headers={"Authorization": f"Bearer {self._get_api_key()}"},
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line:
|
||||||
|
data = httpx.loads(line)
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=data.get("choices", [{}])[0].get(
|
||||||
|
"finish_reason"
|
||||||
|
),
|
||||||
|
text=data.get("choices", [{}])[0].get("text", ""),
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(choices=[choice])
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
await self._handle_api_error(e)
|
||||||
|
|
||||||
|
stream = _to_async_generator()
|
||||||
|
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools or [],
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_chat_completion(request)
|
||||||
|
|
||||||
|
async def _nonstream_chat_completion(
|
||||||
|
self, request: ChatCompletionRequest
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
try:
|
||||||
|
response = await self.client.post(
|
||||||
|
"/chat/completions",
|
||||||
|
json=params,
|
||||||
|
headers={"Authorization": f"Bearer {self._get_api_key()}"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=data.get("choices", [{}])[0].get("finish_reason"),
|
||||||
|
text=data.get("choices", [{}])[0].get("message", {}).get("content", ""),
|
||||||
|
)
|
||||||
|
response = OpenAICompatCompletionResponse(choices=[choice])
|
||||||
|
return process_chat_completion_response(response, self.formatter)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
await self._handle_api_error(e)
|
||||||
|
|
||||||
|
async def _stream_chat_completion(
|
||||||
|
self, request: ChatCompletionRequest
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
|
||||||
|
async def _to_async_generator():
|
||||||
|
try:
|
||||||
|
async with self.client.stream(
|
||||||
|
"POST",
|
||||||
|
"/chat/completions",
|
||||||
|
json=params,
|
||||||
|
headers={"Authorization": f"Bearer {self._get_api_key()}"},
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line:
|
||||||
|
data = httpx.loads(line)
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=data.get("choices", [{}])[0].get(
|
||||||
|
"finish_reason"
|
||||||
|
),
|
||||||
|
text=data.get("choices", [{}])[0]
|
||||||
|
.get("message", {})
|
||||||
|
.get("content", ""),
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(choices=[choice])
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
await self._handle_api_error(e)
|
||||||
|
|
||||||
|
stream = _to_async_generator()
|
||||||
|
async for chunk in process_chat_completion_stream_response(
|
||||||
|
stream, self.formatter
|
||||||
|
):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _handle_api_error(self, e: httpx.HTTPError) -> None:
|
||||||
|
if e.response.status_code in (401, 403):
|
||||||
|
raise ValueError("Invalid API key or unauthorized access") from e
|
||||||
|
elif e.response.status_code == 429:
|
||||||
|
raise ValueError("Rate limit exceeded") from e
|
||||||
|
elif e.response.status_code == 400:
|
||||||
|
error_data = e.response.json()
|
||||||
|
raise ValueError(
|
||||||
|
f"Bad request: {error_data.get('error', {}).get('message', 'Unknown error')}"
|
||||||
|
) from e
|
||||||
|
raise RuntimeError(f"SambaNova API error: {str(e)}") from e
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
contents: List[InterleavedTextMedia],
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError("Embeddings not supported for SambaNova")
|
||||||
1
llama_stack/templates/sambanova/__init__.py
Normal file
1
llama_stack/templates/sambanova/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from .sambanova import get_distribution_template # noqa: F401
|
||||||
19
llama_stack/templates/sambanova/build.yaml
Normal file
19
llama_stack/templates/sambanova/build.yaml
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
version: '2'
|
||||||
|
name: sambanova
|
||||||
|
distribution_spec:
|
||||||
|
description: Use SambaNova for running LLM inference
|
||||||
|
docker_image: null
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- remote::sambanova
|
||||||
|
memory:
|
||||||
|
- inline::faiss
|
||||||
|
- remote::chromadb
|
||||||
|
- remote::pgvector
|
||||||
|
safety:
|
||||||
|
- inline::llama-guard
|
||||||
|
agents:
|
||||||
|
- inline::meta-reference
|
||||||
|
telemetry:
|
||||||
|
- inline::meta-reference
|
||||||
|
image_type: conda
|
||||||
59
llama_stack/templates/sambanova/run.yaml
Normal file
59
llama_stack/templates/sambanova/run.yaml
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
version: '2'
|
||||||
|
image_name: sambanova
|
||||||
|
docker_image: null
|
||||||
|
conda_env: sambanova
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- inference
|
||||||
|
- memory
|
||||||
|
- safety
|
||||||
|
- telemetry
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: sambanova
|
||||||
|
provider_type: remote::sambanova
|
||||||
|
config:
|
||||||
|
url: https://api.sambanova.ai/v1
|
||||||
|
api_key: ${env.SAMBANOVA_API_KEY}
|
||||||
|
memory:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: inline::faiss
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/faiss_store.db
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config: {}
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config: {}
|
||||||
|
metadata_store:
|
||||||
|
namespace: null
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/registry.db
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: Meta-Llama-3.1-8B-Instruct
|
||||||
|
provider_id: null
|
||||||
|
provider_model_id: Meta-Llama-3.1-8B-Instruct
|
||||||
|
shields:
|
||||||
|
- params: null
|
||||||
|
shield_id: meta-llama/Llama-Guard-3-8B
|
||||||
|
provider_id: null
|
||||||
|
provider_shield_id: null
|
||||||
|
memory_banks: []
|
||||||
|
datasets: []
|
||||||
|
scoring_fns: []
|
||||||
|
eval_tasks: []
|
||||||
64
llama_stack/templates/sambanova/sambanova.py
Normal file
64
llama_stack/templates/sambanova/sambanova.py
Normal file
|
|
@ -0,0 +1,64 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
|
||||||
|
from llama_stack.providers.remote.inference.sambanova import SambanovaImplConfig
|
||||||
|
from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES
|
||||||
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||||
|
|
||||||
|
|
||||||
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
|
providers = {
|
||||||
|
"inference": ["remote::sambanova"],
|
||||||
|
"memory": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
|
||||||
|
"safety": ["inline::llama-guard"],
|
||||||
|
"agents": ["inline::meta-reference"],
|
||||||
|
"telemetry": ["inline::meta-reference"],
|
||||||
|
}
|
||||||
|
|
||||||
|
inference_provider = Provider(
|
||||||
|
provider_id="sambanova",
|
||||||
|
provider_type="remote::sambanova",
|
||||||
|
config=SambanovaImplConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
core_model_to_hf_repo = {
|
||||||
|
m.descriptor(): m.huggingface_repo for m in all_registered_models()
|
||||||
|
}
|
||||||
|
default_models = [
|
||||||
|
ModelInput(
|
||||||
|
model_id=core_model_to_hf_repo[m.llama_model],
|
||||||
|
provider_model_id=m.provider_model_id,
|
||||||
|
)
|
||||||
|
for m in MODEL_ALIASES
|
||||||
|
]
|
||||||
|
|
||||||
|
return DistributionTemplate(
|
||||||
|
name="sambanova",
|
||||||
|
distro_type="self_hosted",
|
||||||
|
description="Use SambaNova for running LLM inference",
|
||||||
|
docker_image=None,
|
||||||
|
template_path=Path(__file__).parent / "doc_template.md",
|
||||||
|
providers=providers,
|
||||||
|
default_models=default_models,
|
||||||
|
run_configs={
|
||||||
|
"run.yaml": RunConfigSettings(
|
||||||
|
provider_overrides={
|
||||||
|
"inference": [inference_provider],
|
||||||
|
},
|
||||||
|
default_models=default_models,
|
||||||
|
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||||
|
),
|
||||||
|
},
|
||||||
|
run_config_env_vars={
|
||||||
|
"LLAMASTACK_PORT": (
|
||||||
|
"5001",
|
||||||
|
"Port for the Llama Stack distribution server",
|
||||||
|
),
|
||||||
|
"SAMBANOVA_API_KEY": (
|
||||||
|
"",
|
||||||
|
"SambaNova API Key",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue