Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 23:39:48 +00:00)

Cerebras Integration

parent 34be07e0df / commit 3838bd1704
16 changed files with 515 additions and 65 deletions
README.md
@@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
 | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
 | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| Cerebras | Hosted | | :heavy_check_mark: | | | |
 | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
 | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
 | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
@@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
 |:----------------: |:------------------------------------------: |:-----------------------: |
 | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-gpu.html) |
 | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
+| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
 | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) |
 | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) |
 | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/together.html) |
distributions/cerebras/build.yaml (new symbolic link)
@@ -0,0 +1 @@
+../../llama_stack/templates/cerebras/build.yaml
distributions/cerebras/compose.yaml (new file, 16 lines)

services:
  llamastack:
    image: llamastack/distribution-cerebras
    network_mode: "host"
    volumes:
      - ~/.llama:/root/.llama
      - ./run.yaml:/root/llamastack-run-cerebras.yaml
    ports:
      - "5000:5000"
    entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml"
    deploy:
      restart_policy:
        condition: on-failure
        delay: 3s
        max_attempts: 5
        window: 60s
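Once this compose file is running (`docker compose up`), a plain TCP check against the published port confirms the server is listening. Note that with `network_mode: "host"` the `ports:` mapping is effectively redundant on Linux; the server ends up on host port 5000 either way. A minimal sketch, assuming nothing beyond the port shown above:

```python
import socket

# Reachability check for the llamastack service on port 5000.
# Assumes only the "5000:5000" mapping from compose.yaml; no HTTP routes.
with socket.create_connection(("localhost", 5000), timeout=5):
    print("llama-stack server is accepting connections on port 5000")
```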
distributions/cerebras/run.yaml (new symbolic link)
@@ -0,0 +1 @@
+../../llama_stack/templates/cerebras/run.yaml
distributions/dependencies.json
@@ -36,6 +36,7 @@
     "fastapi",
     "fire",
     "httpx",
+    "huggingface_hub",
     "matplotlib",
     "nltk",
     "numpy",
@@ -47,7 +48,6 @@
     "scikit-learn",
     "scipy",
     "sentencepiece",
-    "together",
     "tqdm",
     "transformers",
     "uvicorn",
@@ -163,33 +163,6 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "bedrock": [
-    "aiosqlite",
-    "blobfile",
-    "boto3",
-    "chardet",
-    "chromadb-client",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
   "meta-reference-gpu": [
     "accelerate",
     "aiosqlite",
@@ -222,19 +195,15 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "meta-reference-quantized-gpu": [
-    "accelerate",
+  "together": [
     "aiosqlite",
     "blobfile",
     "chardet",
     "chromadb-client",
-    "fairscale",
     "faiss-cpu",
     "fastapi",
-    "fbgemm-gpu",
     "fire",
     "httpx",
-    "lm-format-enforcer",
     "matplotlib",
     "nltk",
     "numpy",
@@ -246,13 +215,10 @@
     "scikit-learn",
     "scipy",
     "sentencepiece",
-    "torch",
-    "torchao==0.5.0",
-    "torchvision",
+    "together",
     "tqdm",
     "transformers",
     "uvicorn",
-    "zmq",
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
@@ -283,33 +249,5 @@
     "uvicorn",
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "hf-endpoint": [
-    "aiohttp",
-    "aiosqlite",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "huggingface_hub",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
   ]
 }
llama_stack/providers/registry/inference.py
@@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="cerebras",
+            pip_packages=[
+                "cerebras_cloud_sdk",
+            ],
+            module="llama_stack.providers.remote.inference.cerebras",
+            config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
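For context, the `module` and `config_class` strings in the `AdapterSpec` above are how the stack locates the adapter at runtime: the module is imported and its `get_adapter_impl()` hook (defined in the new cerebras package just below) is awaited with the parsed config. A minimal sketch of that resolution; the loader shape here is an assumption, not llama-stack's exact resolver:

```python
import importlib
from typing import Any, Dict

async def resolve_remote_adapter(module: str, config: Any, deps: Dict) -> Any:
    # Import the package named by AdapterSpec.module and hand the parsed
    # config to its get_adapter_impl() hook.
    mod = importlib.import_module(module)
    return await mod.get_adapter_impl(config, deps)
```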
llama_stack/providers/remote/inference/cerebras/__init__.py (new file, 21 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import CerebrasImplConfig


async def get_adapter_impl(config: CerebrasImplConfig, _deps):
    from .cerebras import CerebrasInferenceAdapter

    assert isinstance(
        config, CerebrasImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = CerebrasInferenceAdapter(config)

    await impl.initialize()

    return impl
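Taken together with the config class below, the hook can be exercised directly. A hedged usage sketch (the API key value is a placeholder, and `cerebras_cloud_sdk` must be installed for the import inside `get_adapter_impl` to succeed):

```python
import asyncio

from llama_stack.providers.remote.inference.cerebras import (
    CerebrasImplConfig,
    get_adapter_impl,
)

async def main() -> None:
    config = CerebrasImplConfig(api_key="csk-placeholder")  # placeholder key
    adapter = await get_adapter_impl(config, {})
    print(type(adapter).__name__)  # CerebrasInferenceAdapter

asyncio.run(main())
```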
llama_stack/providers/remote/inference/cerebras/cerebras.py (new file, 191 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from cerebras.cloud.sdk import AsyncCerebras

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403

from llama_models.datatypes import CoreModelId

from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
)

from .config import CerebrasImplConfig


model_aliases = [
    build_model_alias(
        "llama3.1-8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3.1-70b",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
]


class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: CerebrasImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=model_aliases,
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

        self.client = AsyncCerebras(
            base_url=self.config.base_url, api_key=self.config.api_key
        )

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(
                request,
            )
        else:
            return await self._nonstream_completion(request)

    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        params = self._get_params(request)

        r = await self.client.completions.create(**params)

        return process_completion_response(r, self.formatter)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        params = self._get_params(request)

        r = await self.client.completions.create(**params)

        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        params = self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        if request.sampling_params and request.sampling_params.top_k:
            raise ValueError("`top_k` not supported by Cerebras")

        prompt = ""
        if type(request) == ChatCompletionRequest:
            prompt = chat_completion_request_to_prompt(
                request, self.get_llama_model(request.model), self.formatter
            )
        elif type(request) == CompletionRequest:
            prompt = completion_request_to_prompt(request, self.formatter)
        else:
            raise ValueError(f"Unknown request type {type(request)}")

        return {
            "model": request.model,
            "prompt": prompt,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
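Note that the adapter never calls a chat endpoint: both paths flatten the request to a raw prompt and go through `completions.create` with the dict built by `_get_params`. A standalone sketch of that same call shape against the Cerebras SDK; the model name, prompt, and `max_tokens` here are illustrative, and the client picks up `CEREBRAS_API_KEY` from the environment:

```python
import asyncio

from cerebras.cloud.sdk import AsyncCerebras

async def main() -> None:
    client = AsyncCerebras()  # reads CEREBRAS_API_KEY from the environment
    # Same keys _get_params() produces: model, prompt, stream, plus sampling options.
    r = await client.completions.create(
        model="llama3.1-8b",
        prompt="Once upon a time,",
        stream=False,
        max_tokens=32,
    )
    print(r)

asyncio.run(main())
```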
llama_stack/providers/remote/inference/cerebras/config.py (new file, 32 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field

DEFAULT_BASE_URL = "https://api.cerebras.ai"


@json_schema_type
class CerebrasImplConfig(BaseModel):
    base_url: str = Field(
        default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
        description="Base URL for the Cerebras API",
    )
    api_key: Optional[str] = Field(
        default=os.environ.get("CEREBRAS_API_KEY"),
        description="Cerebras API Key",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "base_url": DEFAULT_BASE_URL,
            "api_key": "${env.CEREBRAS_API_KEY}",
        }
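One subtlety: both `Field` defaults above are evaluated when the module is imported, so the environment variables must be set before the import, not after. A small sketch with a placeholder key:

```python
import os

# Set before importing: the Field defaults capture the environment at import time.
os.environ.setdefault("CEREBRAS_API_KEY", "csk-placeholder")  # illustrative value

from llama_stack.providers.remote.inference.cerebras.config import CerebrasImplConfig

cfg = CerebrasImplConfig()
print(cfg.base_url)  # https://api.cerebras.ai unless CEREBRAS_BASE_URL was set
print(cfg.sample_run_config())  # {'base_url': ..., 'api_key': '${env.CEREBRAS_API_KEY}'}
```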
llama_stack/providers/tests/inference/fixtures.py
@@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
 )
 from llama_stack.providers.remote.inference.bedrock import BedrockConfig

+from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@@ -63,6 +64,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     )


+@pytest.fixture(scope="session")
+def inference_cerebras() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="cerebras",
+                provider_type="remote::cerebras",
+                config=CerebrasImplConfig(
+                    api_key=get_env_or_fail("CEREBRAS_API_KEY"),
+                ).model_dump(),
+            )
+        ],
+    )
+
+
 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
@@ -189,6 +205,7 @@ INFERENCE_FIXTURES = [
     "vllm_remote",
     "remote",
     "bedrock",
+    "cerebras",
     "nvidia",
 ]
llama_stack/providers/tests/inference/test_text_inference.py
@@ -94,6 +94,7 @@ class TestInference:
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
+            "remote::cerebras",
         ):
             pytest.skip("Other inference providers don't support completion() yet")
@@ -139,6 +140,7 @@ class TestInference:
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
+            "remote::cerebras",
         ):
             pytest.skip(
                 "Other inference providers don't support structured output in completions yet"
llama_stack/templates/cerebras/__init__.py (new file, 7 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .cerebras import get_distribution_template  # noqa: F401
llama_stack/templates/cerebras/build.yaml (new file, 17 lines)

version: '2'
name: cerebras
distribution_spec:
  description: Use Cerebras for running LLM inference
  docker_image: null
  providers:
    inference:
    - remote::cerebras
    safety:
    - inline::llama-guard
    memory:
    - inline::meta-reference
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
image_type: conda
llama_stack/templates/cerebras/cerebras.py (new file, 71 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_models.sku_list import all_registered_models

from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases

from llama_stack.templates.template import DistributionTemplate, RunConfigSettings


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::cerebras"],
        "safety": ["inline::llama-guard"],
        "memory": ["inline::meta-reference"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
    }

    inference_provider = Provider(
        provider_id="cerebras",
        provider_type="remote::cerebras",
        config=CerebrasImplConfig.sample_run_config(),
    )

    core_model_to_hf_repo = {
        m.descriptor(): m.huggingface_repo for m in all_registered_models()
    }
    default_models = [
        ModelInput(
            model_id=core_model_to_hf_repo[m.llama_model],
            provider_model_id=m.provider_model_id,
        )
        for m in model_aliases
    ]

    return DistributionTemplate(
        name="cerebras",
        distro_type="self_hosted",
        description="Use Cerebras for running LLM inference",
        docker_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        default_models=default_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                },
                default_models=default_models,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "CEREBRAS_API_KEY": (
                "",
                "Cerebras API Key",
            ),
        },
    )
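A quick way to see what this template expands to is to call it directly and print the default model mappings; assuming the package is importable in your environment, the expected output matches the `models:` section of the generated run.yaml further below.

```python
from llama_stack.templates.cerebras.cerebras import get_distribution_template

template = get_distribution_template()
for model in template.default_models:
    # HF-style model_id mapped to the Cerebras-side provider_model_id
    print(f"{model.model_id} -> {model.provider_model_id}")
# meta-llama/Llama-3.1-8B-Instruct -> llama3.1-8b
# meta-llama/Llama-3.1-70B-Instruct -> llama3.1-70b
```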
llama_stack/templates/cerebras/doc_template.md (new file, 60 lines)

# Cerebras Distribution

The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

{{ providers_table }}

{% if run_config_env_vars %}
### Environment Variables

The following environment variables can be configured:

{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}

{% if default_models %}
### Models

The following models are available by default:

{% for model in default_models %}
- `{{ model.model_id }} ({{ model.provider_model_id }})`
{% endfor %}
{% endif %}


### Prerequisite: API Keys

Make sure you have access to a Cerebras API key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).


## Running Llama Stack with Cerebras

You can do this via Conda (build the code yourself) or Docker, which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-{{ name }} \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

### Via Conda

```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
  --port 5001 \
  --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```
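With a server started by either method, inference can be exercised from Python. This is a hedged sketch using the separate `llama_stack_client` package, which is not part of this commit; the parameter names follow its chat-completion call and may differ across client versions:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

# One of the two models registered by this distribution's run.yaml.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response)
```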
llama_stack/templates/cerebras/run.yaml (new file, 63 lines)

version: '2'
image_name: cerebras
docker_image: null
conda_env: cerebras
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: cerebras
    provider_type: remote::cerebras
    config:
      base_url: https://api.cerebras.ai
      api_key: ${env.CEREBRAS_API_KEY}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  memory:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: null
  provider_model_id: llama3.1-8b
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: null
  provider_model_id: llama3.1-70b
shields:
- params: null
  shield_id: meta-llama/Llama-Guard-3-8B
  provider_id: null
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
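The `${env.VAR}` and `${env.VAR:default}` placeholders above are resolved from the environment when the config is loaded. A rough stand-in for that substitution; this is an approximation for illustration, not llama-stack's actual resolver:

```python
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::([^}]*))?\}")

def substitute_env(value: str) -> str:
    # Replace ${env.VAR} with the variable's value, and ${env.VAR:default}
    # with the default when the variable is unset.
    return _ENV_PATTERN.sub(
        lambda m: os.environ.get(m.group(1), m.group(2) or ""), value
    )

print(substitute_env("${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db"))
# -> ~/.llama/distributions/cerebras/registry.db when SQLITE_STORE_DIR is unset
```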