diff --git a/README.md b/README.md
index 8e57292c3..0dfb1306d 100644
--- a/README.md
+++ b/README.md
@@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| Cerebras | Single Node | | :heavy_check_mark: | | | |
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
@@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
|:----------------: |:------------------------------------------: |:-----------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
+| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
diff --git a/distributions/cerebras/build.yaml b/distributions/cerebras/build.yaml
new file mode 120000
index 000000000..bccbbcf60
--- /dev/null
+++ b/distributions/cerebras/build.yaml
@@ -0,0 +1 @@
+../../llama_stack/templates/cerebras/build.yaml
\ No newline at end of file
diff --git a/distributions/cerebras/compose.yaml b/distributions/cerebras/compose.yaml
new file mode 100644
index 000000000..f2e9a6f42
--- /dev/null
+++ b/distributions/cerebras/compose.yaml
@@ -0,0 +1,16 @@
+services:
+ llamastack:
+ image: llamastack/distribution-cerebras
+ network_mode: "host"
+ volumes:
+ - ~/.llama:/root/.llama
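+      # run.yaml in this directory is a symlink to llama_stack/templates/cerebras/run.yaml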
+ - ./run.yaml:/root/llamastack-run-cerebras.yaml
+ ports:
+ - "5000:5000"
+ entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml"
+ deploy:
+ restart_policy:
+ condition: on-failure
+ delay: 3s
+ max_attempts: 5
+ window: 60s
diff --git a/distributions/cerebras/run.yaml b/distributions/cerebras/run.yaml
new file mode 120000
index 000000000..9f9d20b4b
--- /dev/null
+++ b/distributions/cerebras/run.yaml
@@ -0,0 +1 @@
+../../llama_stack/templates/cerebras/run.yaml
\ No newline at end of file
diff --git a/distributions/dependencies.json b/distributions/dependencies.json
index eb306adc6..ae3608ed5 100644
--- a/distributions/dependencies.json
+++ b/distributions/dependencies.json
@@ -1,4 +1,163 @@
{
+ "tgi": [
+ "aiohttp",
+ "aiosqlite",
+ "autoevals",
+ "blobfile",
+ "chardet",
+ "chromadb-client",
+ "datasets",
+ "faiss-cpu",
+ "fastapi",
+ "fire",
+ "httpx",
+ "huggingface_hub",
+ "matplotlib",
+ "nltk",
+ "numpy",
+ "openai",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
+ "remote-vllm": [
+ "aiosqlite",
+ "autoevals",
+ "blobfile",
+ "chardet",
+ "chromadb-client",
+ "datasets",
+ "faiss-cpu",
+ "fastapi",
+ "fire",
+ "httpx",
+ "matplotlib",
+ "nltk",
+ "numpy",
+ "openai",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
+ "vllm-gpu": [
+ "aiosqlite",
+ "autoevals",
+ "blobfile",
+ "chardet",
+ "chromadb-client",
+ "datasets",
+ "faiss-cpu",
+ "fastapi",
+ "fire",
+ "httpx",
+ "matplotlib",
+ "nltk",
+ "numpy",
+ "openai",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "vllm",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
+ "meta-reference-quantized-gpu": [
+ "accelerate",
+ "aiosqlite",
+ "blobfile",
+ "chardet",
+ "chromadb-client",
+ "fairscale",
+ "faiss-cpu",
+ "fastapi",
+ "fbgemm-gpu",
+ "fire",
+ "httpx",
+ "lm-format-enforcer",
+ "matplotlib",
+ "nltk",
+ "numpy",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "torch",
+ "torchao==0.5.0",
+ "torchvision",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "zmq",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
+ "meta-reference-gpu": [
+ "accelerate",
+ "aiosqlite",
+ "autoevals",
+ "blobfile",
+ "chardet",
+ "chromadb-client",
+ "datasets",
+ "fairscale",
+ "faiss-cpu",
+ "fastapi",
+ "fire",
+ "httpx",
+ "lm-format-enforcer",
+ "matplotlib",
+ "nltk",
+ "numpy",
+ "openai",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "torch",
+ "torchvision",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "zmq",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
"hf-serverless": [
"aiohttp",
"aiosqlite",
@@ -60,94 +219,7 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
- "vllm-gpu": [
- "aiosqlite",
- "autoevals",
- "blobfile",
- "chardet",
- "chromadb-client",
- "datasets",
- "faiss-cpu",
- "fastapi",
- "fire",
- "httpx",
- "matplotlib",
- "nltk",
- "numpy",
- "openai",
- "pandas",
- "pillow",
- "psycopg2-binary",
- "pypdf",
- "redis",
- "scikit-learn",
- "scipy",
- "sentencepiece",
- "tqdm",
- "transformers",
- "uvicorn",
- "vllm",
- "sentence-transformers --no-deps",
- "torch --index-url https://download.pytorch.org/whl/cpu"
- ],
- "remote-vllm": [
- "aiosqlite",
- "blobfile",
- "chardet",
- "chromadb-client",
- "faiss-cpu",
- "fastapi",
- "fire",
- "httpx",
- "matplotlib",
- "nltk",
- "numpy",
- "openai",
- "pandas",
- "pillow",
- "psycopg2-binary",
- "pypdf",
- "redis",
- "scikit-learn",
- "scipy",
- "sentencepiece",
- "tqdm",
- "transformers",
- "uvicorn",
- "sentence-transformers --no-deps",
- "torch --index-url https://download.pytorch.org/whl/cpu"
- ],
- "fireworks": [
- "aiosqlite",
- "autoevals",
- "blobfile",
- "chardet",
- "chromadb-client",
- "datasets",
- "faiss-cpu",
- "fastapi",
- "fire",
- "fireworks-ai",
- "httpx",
- "matplotlib",
- "nltk",
- "numpy",
- "openai",
- "pandas",
- "pillow",
- "psycopg2-binary",
- "pypdf",
- "redis",
- "scikit-learn",
- "scipy",
- "sentencepiece",
- "tqdm",
- "transformers",
- "uvicorn",
- "sentence-transformers --no-deps",
- "torch --index-url https://download.pytorch.org/whl/cpu"
- ],
- "tgi": [
+ "ollama": [
"aiohttp",
"aiosqlite",
"autoevals",
@@ -155,15 +227,16 @@
"chardet",
"chromadb-client",
"datasets",
+ "fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
- "huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"openai",
+ "ollama",
"pandas",
"pillow",
"psycopg2-binary",
@@ -186,42 +259,11 @@
"chardet",
"chromadb-client",
"datasets",
- "faiss-cpu",
- "fastapi",
- "fire",
- "httpx",
- "matplotlib",
- "nltk",
- "numpy",
- "openai",
- "pandas",
- "pillow",
- "psycopg2-binary",
- "pypdf",
- "redis",
- "scikit-learn",
- "scipy",
- "sentencepiece",
- "tqdm",
- "transformers",
- "uvicorn",
- "sentence-transformers --no-deps",
- "torch --index-url https://download.pytorch.org/whl/cpu"
- ],
- "meta-reference-gpu": [
- "accelerate",
- "aiosqlite",
- "autoevals",
- "blobfile",
- "chardet",
- "chromadb-client",
- "datasets",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
- "lm-format-enforcer",
"matplotlib",
"nltk",
"numpy",
@@ -234,77 +276,6 @@
"scikit-learn",
"scipy",
"sentencepiece",
- "torch",
- "torchvision",
- "tqdm",
- "transformers",
- "uvicorn",
- "zmq",
- "sentence-transformers --no-deps",
- "torch --index-url https://download.pytorch.org/whl/cpu"
- ],
- "meta-reference-quantized-gpu": [
- "accelerate",
- "aiosqlite",
- "autoevals",
- "blobfile",
- "chardet",
- "chromadb-client",
- "datasets",
- "fairscale",
- "faiss-cpu",
- "fastapi",
- "fbgemm-gpu",
- "fire",
- "httpx",
- "lm-format-enforcer",
- "matplotlib",
- "nltk",
- "numpy",
- "openai",
- "pandas",
- "pillow",
- "psycopg2-binary",
- "pypdf",
- "redis",
- "scikit-learn",
- "scipy",
- "sentencepiece",
- "torch",
- "torchao==0.5.0",
- "torchvision",
- "tqdm",
- "transformers",
- "uvicorn",
- "zmq",
- "sentence-transformers --no-deps",
- "torch --index-url https://download.pytorch.org/whl/cpu"
- ],
- "ollama": [
- "aiohttp",
- "aiosqlite",
- "autoevals",
- "blobfile",
- "chardet",
- "chromadb-client",
- "datasets",
- "faiss-cpu",
- "fastapi",
- "fire",
- "httpx",
- "matplotlib",
- "nltk",
- "numpy",
- "ollama",
- "openai",
- "pandas",
- "pillow",
- "psycopg2-binary",
- "pypdf",
- "redis",
- "scikit-learn",
- "scipy",
- "sentencepiece",
"tqdm",
"transformers",
"uvicorn",
@@ -327,6 +298,63 @@
"matplotlib",
"nltk",
"numpy",
+ "ollama",
+ "openai",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
+ "fireworks": [
+ "aiosqlite",
+ "autoevals",
+ "blobfile",
+ "chardet",
+ "chromadb-client",
+ "datasets",
+ "faiss-cpu",
+ "fastapi",
+ "fire",
+ "fireworks-ai",
+ "httpx",
+ "matplotlib",
+ "nltk",
+ "numpy",
+ "pandas",
+ "pillow",
+ "psycopg2-binary",
+ "pypdf",
+ "redis",
+ "scikit-learn",
+ "scipy",
+ "sentencepiece",
+ "tqdm",
+ "transformers",
+ "uvicorn",
+ "sentence-transformers --no-deps",
+ "torch --index-url https://download.pytorch.org/whl/cpu"
+ ],
+ "cerebras": [
+ "aiosqlite",
+ "blobfile",
+ "cerebras_cloud_sdk",
+ "chardet",
+ "faiss-cpu",
+ "fastapi",
+ "fire",
+ "httpx",
+ "matplotlib",
+ "nltk",
+ "numpy",
"openai",
"pandas",
"pillow",
diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 090253804..4f220ea1e 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -2291,6 +2291,39 @@
"required": true
}
}
+ },
+ "/alpha/datasets/unregister": {
+ "post": {
+ "responses": {
+ "200": {
+ "description": "OK"
+ }
+ },
+ "tags": [
+ "Datasets"
+ ],
+ "parameters": [
+ {
+ "name": "X-LlamaStack-ProviderData",
+ "in": "header",
+ "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+ "required": false,
+ "schema": {
+ "type": "string"
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/UnregisterDatasetRequest"
+ }
+ }
+ },
+ "required": true
+ }
+ }
}
},
"jsonSchemaDialect": "https://json-schema.org/draft/2020-12/schema",
@@ -7917,6 +7950,18 @@
"required": [
"model_id"
]
+ },
+ "UnregisterDatasetRequest": {
+ "type": "object",
+ "properties": {
+ "dataset_id": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "dataset_id"
+ ]
}
},
"responses": {}
@@ -8529,6 +8574,10 @@
"name": "UnregisterModelRequest",
"description": ""
},
+ {
+ "name": "UnregisterDatasetRequest",
+ "description": ""
+ },
{
"name": "UnstructuredLogEvent",
"description": ""
@@ -8718,6 +8767,7 @@
"URL",
"UnregisterMemoryBankRequest",
"UnregisterModelRequest",
+ "UnregisterDatasetRequest",
"UnstructuredLogEvent",
"UserMessage",
"VectorMemoryBank",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 8ffd9fdef..6564ddf3f 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -3253,6 +3253,14 @@ components:
required:
- model_id
type: object
+ UnregisterDatasetRequest:
+ additionalProperties: false
+ properties:
+ dataset_id:
+ type: string
+ required:
+ - dataset_id
+ type: object
UnstructuredLogEvent:
additionalProperties: false
properties:
@@ -3789,6 +3797,27 @@ paths:
description: OK
tags:
- Datasets
+ /alpha/datasets/unregister:
+ post:
+ parameters:
+ - description: JSON-encoded provider data which will be made available to the
+ adapter servicing the API
+ in: header
+ name: X-LlamaStack-ProviderData
+ required: false
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/UnregisterDatasetRequest'
+ required: true
+ responses:
+ '200':
+ description: OK
+ tags:
+ - Datasets
/alpha/eval-tasks/get:
get:
parameters:
@@ -5242,6 +5271,9 @@ tags:
- description:
name: UnregisterModelRequest
+- description:
+ name: UnregisterDatasetRequest
- description:
name: UnstructuredLogEvent
@@ -5418,6 +5450,7 @@ x-tagGroups:
- URL
- UnregisterMemoryBankRequest
- UnregisterModelRequest
+ - UnregisterDatasetRequest
- UnstructuredLogEvent
- UserMessage
- VectorMemoryBank
diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md
index a45d07ebf..67d39159c 100644
--- a/docs/source/distributions/building_distro.md
+++ b/docs/source/distributions/building_distro.md
@@ -66,121 +66,247 @@ llama stack build --list-templates
```
```
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| Template Name | Providers | Description |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM |
-| | "inference": "remote::hf::serverless", | inference. |
-| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| together | { | Use Together.ai for running LLM inference |
-| | "inference": "remote::together", | |
-| | "memory": [ | |
-| | "meta-reference", | |
-| | "remote::weaviate" | |
-| | ], | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| fireworks | { | Use Fireworks.ai for running LLM inference |
-| | "inference": "remote::fireworks", | |
-| | "memory": [ | |
-| | "meta-reference", | |
-| | "remote::weaviate", | |
-| | "remote::chromadb", | |
-| | "remote::pgvector" | |
-| | ], | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| databricks | { | Use Databricks for running LLM inference |
-| | "inference": "remote::databricks", | |
-| | "memory": "meta-reference", | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| vllm | { | Like local, but use vLLM for running LLM inference |
-| | "inference": "vllm", | |
-| | "memory": "meta-reference", | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| tgi | { | Use TGI for running LLM inference |
-| | "inference": "remote::tgi", | |
-| | "memory": [ | |
-| | "meta-reference", | |
-| | "remote::chromadb", | |
-| | "remote::pgvector" | |
-| | ], | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| bedrock | { | Use Amazon Bedrock APIs. |
-| | "inference": "remote::bedrock", | |
-| | "memory": "meta-reference", | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
-| | "inference": "meta-reference", | |
-| | "memory": [ | |
-| | "meta-reference", | |
-| | "remote::chromadb", | |
-| | "remote::pgvector" | |
-| | ], | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
-| | "inference": "meta-reference-quantized", | |
-| | "memory": [ | |
-| | "meta-reference", | |
-| | "remote::chromadb", | |
-| | "remote::pgvector" | |
-| | ], | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| ollama | { | Use ollama for running LLM inference |
-| | "inference": "remote::ollama", | |
-| | "memory": [ | |
-| | "meta-reference", | |
-| | "remote::chromadb", | |
-| | "remote::pgvector" | |
-| | ], | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
-| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. |
-| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. |
-| | "memory": "meta-reference", | |
-| | "safety": "meta-reference", | |
-| | "agents": "meta-reference", | |
-| | "telemetry": "meta-reference" | |
-| | } | |
-+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| Template Name | Providers | Description |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| tgi | { | Use (an external) TGI server for running LLM inference |
+| | "inference": [ | |
+| | "remote::tgi" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| remote-vllm | { | Use (an external) vLLM server for running LLM inference |
+| | "inference": [ | |
+| | "remote::vllm" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference |
+| | "inference": [ | |
+| | "inline::vllm" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference |
+| | "inference": [ | |
+| | "inline::meta-reference-quantized" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| meta-reference-gpu | { | Use Meta Reference for running LLM inference |
+| | "inference": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
+| | "inference": [ | |
+| | "remote::hf::serverless" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| together | { | Use Together.AI for running LLM inference |
+| | "inference": [ | |
+| | "remote::together" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| ollama | { | Use (an external) Ollama server for running LLM inference |
+| | "inference": [ | |
+| | "remote::ollama" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| bedrock | { | Use AWS Bedrock for running LLM inference and safety |
+| | "inference": [ | |
+| | "remote::bedrock" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "remote::bedrock" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
+| | "inference": [ | |
+| | "remote::hf::endpoint" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| fireworks | { | Use Fireworks.AI for running LLM inference |
+| | "inference": [ | |
+| | "remote::fireworks" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::faiss", | |
+| | "remote::chromadb", | |
+| | "remote::pgvector" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
+| cerebras | { | Use Cerebras for running LLM inference |
+| | "inference": [ | |
+| | "remote::cerebras" | |
+| | ], | |
+| | "safety": [ | |
+| | "inline::llama-guard" | |
+| | ], | |
+| | "memory": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "agents": [ | |
+| | "inline::meta-reference" | |
+| | ], | |
+| | "telemetry": [ | |
+| | "inline::meta-reference" | |
+| | ] | |
+| | } | |
++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
```
You may then pick a template to build your distribution with providers fitted to your liking.
diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md
index 815660fd4..7e15062df 100644
--- a/docs/source/distributions/importing_as_library.md
+++ b/docs/source/distributions/importing_as_library.md
@@ -21,7 +21,7 @@ print(response)
```python
response = await client.inference.chat_completion(
messages=[UserMessage(content="What is the capital of France?", role="user")],
- model="Llama3.1-8B-Instruct",
+ model_id="Llama3.1-8B-Instruct",
stream=False,
)
print("\nChat completion response:")
diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md
new file mode 100644
index 000000000..08b35809a
--- /dev/null
+++ b/docs/source/distributions/self_hosted_distro/cerebras.md
@@ -0,0 +1,61 @@
+# Cerebras Distribution
+
+The `llamastack/distribution-cerebras` distribution consists of the following provider configurations.
+
+| API | Provider(s) |
+|-----|-------------|
+| agents | `inline::meta-reference` |
+| inference | `remote::cerebras` |
+| memory | `inline::meta-reference` |
+| safety | `inline::llama-guard` |
+| telemetry | `inline::meta-reference` |
+
+
+### Environment Variables
+
+The following environment variables can be configured:
+
+- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
+
+### Models
+
+The following models are available by default:
+
+- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)`
+- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)`
+
+
+### Prerequisite: API Keys
+
+Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
+
+
+## Running Llama Stack with Cerebras
+
+You can do this via Conda (build the code yourself) or Docker, which uses a pre-built image.
+
+### Via Docker
+
+This method allows you to get started quickly without having to build the distribution code.
+
+```bash
+LLAMA_STACK_PORT=5001
+docker run \
+ -it \
+ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+ -v ./run.yaml:/root/my-run.yaml \
+ llamastack/distribution-cerebras \
+ --yaml-config /root/my-run.yaml \
+ --port $LLAMA_STACK_PORT \
+ --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
+```
+
+### Via Conda
+
+```bash
+llama stack build --template cerebras --image-type conda
+llama stack run ./run.yaml \
+ --port 5001 \
+ --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
+```
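+
+### Verifying the Setup
+
+Once the server is running, you can send a quick chat completion request to confirm the Cerebras adapter is wired up. The snippet below is a minimal sketch: it assumes the `llama-stack-client` Python package is installed and that the server is listening on port 5001; adjust the base URL and model id to match your run configuration.
+
+```python
+from llama_stack_client import LlamaStackClient
+
+# Point the client at the running distribution (port 5001 assumed here).
+client = LlamaStackClient(base_url="http://localhost:5001")
+
+# Query one of the default Cerebras-served models.
+response = client.inference.chat_completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[{"role": "user", "content": "What is the capital of France?"}],
+)
+print(response)
+```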
diff --git a/docs/source/index.md b/docs/source/index.md
index 291237843..abfaf51b4 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere
| **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | Y | Y | Y | Y | Y |
+| Cerebras | Single Node | | Y | | | |
| Fireworks | Hosted | Y | Y | Y | | |
| AWS Bedrock | Hosted | | Y | | Y | |
| Together | Hosted | Y | Y | | Y | |
diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py
index 9e5891e74..c379a49fb 100644
--- a/llama_stack/apis/datasets/client.py
+++ b/llama_stack/apis/datasets/client.py
@@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
return [DatasetDefWithProvider(**x) for x in response.json()]
+ async def unregister_dataset(
+ self,
+ dataset_id: str,
+ ) -> None:
+ async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self.base_url}/datasets/unregister",
+                json={
+                    "dataset_id": dataset_id,
+                },
+ headers={"Content-Type": "application/json"},
+ timeout=60,
+ )
+ response.raise_for_status()
+
async def run_main(host: str, port: int):
client = DatasetsClient(f"http://{host}:{port}")
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index 2ab958782..e1ac4af21 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -64,3 +64,9 @@ class Datasets(Protocol):
@webmethod(route="/datasets/list", method="GET")
async def list_datasets(self) -> List[Dataset]: ...
+
+ @webmethod(route="/datasets/unregister", method="POST")
+ async def unregister_dataset(
+ self,
+ dataset_id: str,
+ ) -> None: ...
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 4df693b26..2fb5a5e1c 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_memory_bank(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
+ elif api == Api.datasetio:
+ return await p.unregister_dataset(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@@ -354,6 +356,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
)
await self.register_object(dataset)
+ async def unregister_dataset(self, dataset_id: str) -> None:
+ dataset = await self.get_dataset(dataset_id)
+ if dataset is None:
+ raise ValueError(f"Dataset {dataset_id} not found")
+ await self.unregister_object(dataset)
+
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 080204e45..8e89bcc72 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
+ async def unregister_dataset(self, dataset_id: str) -> None: ...
+
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...
diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index 4de1850ae..010610056 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -97,6 +97,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
+ async def unregister_dataset(self, dataset_id: str) -> None:
+ del self.dataset_infos[dataset_id]
+
async def get_rows_paginated(
self,
dataset_id: str,
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index c8d061f6c..13d463ad8 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
),
),
+ remote_provider_spec(
+ api=Api.inference,
+ adapter=AdapterSpec(
+ adapter_type="cerebras",
+ pip_packages=[
+ "cerebras_cloud_sdk",
+ ],
+ module="llama_stack.providers.remote.inference.cerebras",
+ config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
+ ),
+ ),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index c2e4506bf..cdd5d9cd3 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
self.dataset_infos[dataset_def.identifier] = dataset_def
+ async def unregister_dataset(self, dataset_id: str) -> None:
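+        # Remove both the persisted kvstore entry and the in-memory record.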
+ key = f"{DATASETS_PREFIX}{dataset_id}"
+ await self.kvstore.delete(key=key)
+ del self.dataset_infos[dataset_id]
+
async def get_rows_paginated(
self,
dataset_id: str,
diff --git a/llama_stack/providers/remote/inference/cerebras/__init__.py b/llama_stack/providers/remote/inference/cerebras/__init__.py
new file mode 100644
index 000000000..a24bb2c70
--- /dev/null
+++ b/llama_stack/providers/remote/inference/cerebras/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import CerebrasImplConfig
+
+
+async def get_adapter_impl(config: CerebrasImplConfig, _deps):
+ from .cerebras import CerebrasInferenceAdapter
+
+ assert isinstance(
+ config, CerebrasImplConfig
+ ), f"Unexpected config type: {type(config)}"
+
+ impl = CerebrasInferenceAdapter(config)
+
+ await impl.initialize()
+
+ return impl
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
new file mode 100644
index 000000000..65022f85e
--- /dev/null
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -0,0 +1,191 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import AsyncGenerator
+
+from cerebras.cloud.sdk import AsyncCerebras
+
+from llama_models.llama3.api.chat_format import ChatFormat
+
+from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.tokenizer import Tokenizer
+
+from llama_stack.apis.inference import * # noqa: F403
+
+from llama_models.datatypes import CoreModelId
+
+from llama_stack.providers.utils.inference.model_registry import (
+ build_model_alias,
+ ModelRegistryHelper,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+ get_sampling_options,
+ process_chat_completion_response,
+ process_chat_completion_stream_response,
+ process_completion_response,
+ process_completion_stream_response,
+)
+from llama_stack.providers.utils.inference.prompt_adapter import (
+ chat_completion_request_to_prompt,
+ completion_request_to_prompt,
+)
+
+from .config import CerebrasImplConfig
+
+
+model_aliases = [
+ build_model_alias(
+ "llama3.1-8b",
+ CoreModelId.llama3_1_8b_instruct.value,
+ ),
+ build_model_alias(
+ "llama3.1-70b",
+ CoreModelId.llama3_1_70b_instruct.value,
+ ),
+]
+
+
+class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
+ def __init__(self, config: CerebrasImplConfig) -> None:
+ ModelRegistryHelper.__init__(
+ self,
+ model_aliases=model_aliases,
+ )
+ self.config = config
+ self.formatter = ChatFormat(Tokenizer.get_instance())
+
+ self.client = AsyncCerebras(
+ base_url=self.config.base_url, api_key=self.config.api_key
+ )
+
+ async def initialize(self) -> None:
+ return
+
+ async def shutdown(self) -> None:
+ pass
+
+ async def completion(
+ self,
+ model_id: str,
+ content: InterleavedTextMedia,
+ sampling_params: Optional[SamplingParams] = SamplingParams(),
+ response_format: Optional[ResponseFormat] = None,
+ stream: Optional[bool] = False,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
+ request = CompletionRequest(
+ model=model.provider_resource_id,
+ content=content,
+ sampling_params=sampling_params,
+ response_format=response_format,
+ stream=stream,
+ logprobs=logprobs,
+ )
+ if stream:
+ return self._stream_completion(
+ request,
+ )
+ else:
+ return await self._nonstream_completion(request)
+
+ async def _nonstream_completion(
+ self, request: CompletionRequest
+ ) -> CompletionResponse:
+ params = self._get_params(request)
+
+ r = await self.client.completions.create(**params)
+
+ return process_completion_response(r, self.formatter)
+
+ async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+ params = self._get_params(request)
+
+ stream = await self.client.completions.create(**params)
+
+ async for chunk in process_completion_stream_response(stream, self.formatter):
+ yield chunk
+
+ async def chat_completion(
+ self,
+ model_id: str,
+ messages: List[Message],
+ sampling_params: Optional[SamplingParams] = SamplingParams(),
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+ tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ response_format: Optional[ResponseFormat] = None,
+ stream: Optional[bool] = False,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> AsyncGenerator:
+ model = await self.model_store.get_model(model_id)
+ request = ChatCompletionRequest(
+ model=model.provider_resource_id,
+ messages=messages,
+ sampling_params=sampling_params,
+ tools=tools or [],
+ tool_choice=tool_choice,
+ tool_prompt_format=tool_prompt_format,
+ response_format=response_format,
+ stream=stream,
+ logprobs=logprobs,
+ )
+
+ if stream:
+ return self._stream_chat_completion(request)
+ else:
+ return await self._nonstream_chat_completion(request)
+
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+ params = self._get_params(request)
+
+ r = await self.client.completions.create(**params)
+
+ return process_chat_completion_response(r, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+ params = self._get_params(request)
+
+ stream = await self.client.completions.create(**params)
+
+ async for chunk in process_chat_completion_stream_response(
+ stream, self.formatter
+ ):
+ yield chunk
+
+ def _get_params(
+ self, request: Union[ChatCompletionRequest, CompletionRequest]
+ ) -> dict:
+ if request.sampling_params and request.sampling_params.top_k:
+ raise ValueError("`top_k` not supported by Cerebras")
+
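+        # Both chat and plain completion requests are rendered down to a raw
+        # prompt string here and sent through the Cerebras completions endpoint.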
+ prompt = ""
+ if type(request) == ChatCompletionRequest:
+ prompt = chat_completion_request_to_prompt(
+ request, self.get_llama_model(request.model), self.formatter
+ )
+ elif type(request) == CompletionRequest:
+ prompt = completion_request_to_prompt(request, self.formatter)
+ else:
+ raise ValueError(f"Unknown request type {type(request)}")
+
+ return {
+ "model": request.model,
+ "prompt": prompt,
+ "stream": request.stream,
+ **get_sampling_options(request.sampling_params),
+ }
+
+ async def embeddings(
+ self,
+ model_id: str,
+ contents: List[InterleavedTextMedia],
+ ) -> EmbeddingsResponse:
+ raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py
new file mode 100644
index 000000000..9bae6ca4d
--- /dev/null
+++ b/llama_stack/providers/remote/inference/cerebras/config.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+from typing import Any, Dict, Optional
+
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel, Field
+
+DEFAULT_BASE_URL = "https://api.cerebras.ai"
+
+
+@json_schema_type
+class CerebrasImplConfig(BaseModel):
+ base_url: str = Field(
+ default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
+ description="Base URL for the Cerebras API",
+ )
+ api_key: Optional[str] = Field(
+ default=os.environ.get("CEREBRAS_API_KEY"),
+ description="Cerebras API Key",
+ )
+
+ @classmethod
+ def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+ return {
+ "base_url": DEFAULT_BASE_URL,
+ "api_key": "${env.CEREBRAS_API_KEY}",
+ }
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 74c0b8601..f89629afc 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -180,7 +180,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
r = await self.client.generate(**params)
- assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py
index dd2cbd019..7d88b6115 100644
--- a/llama_stack/providers/tests/datasetio/test_datasetio.py
+++ b/llama_stack/providers/tests/datasetio/test_datasetio.py
@@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1
assert response[0].identifier == "test_dataset"
+ with pytest.raises(Exception) as exc_info:
+ # unregister a dataset that does not exist
+ await datasets_impl.unregister_dataset("test_dataset2")
+
+ await datasets_impl.unregister_dataset("test_dataset")
+ response = await datasets_impl.list_datasets()
+ assert isinstance(response, list)
+ assert len(response) == 0
+
+ with pytest.raises(Exception) as exc_info:
+ await datasets_impl.unregister_dataset("test_dataset")
+
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index a427eef12..21e122149 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
)
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
+from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
)
+@pytest.fixture(scope="session")
+def inference_cerebras() -> ProviderFixture:
+ return ProviderFixture(
+ providers=[
+ Provider(
+ provider_id="cerebras",
+ provider_type="remote::cerebras",
+ config=CerebrasImplConfig(
+ api_key=get_env_or_fail("CEREBRAS_API_KEY"),
+ ).model_dump(),
+ )
+ ],
+ )
+
+
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
@@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [
"vllm_remote",
"remote",
"bedrock",
+ "cerebras",
"nvidia",
"tgi",
]
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 9e5c67375..aa2f0b413 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -94,6 +94,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
+ "remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
@@ -139,6 +140,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
+ "remote::cerebras",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
diff --git a/llama_stack/templates/cerebras/__init__.py b/llama_stack/templates/cerebras/__init__.py
new file mode 100644
index 000000000..9f9929b52
--- /dev/null
+++ b/llama_stack/templates/cerebras/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .cerebras import get_distribution_template # noqa: F401
diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml
new file mode 100644
index 000000000..a1fe93099
--- /dev/null
+++ b/llama_stack/templates/cerebras/build.yaml
@@ -0,0 +1,17 @@
+version: '2'
+name: cerebras
+distribution_spec:
+ description: Use Cerebras for running LLM inference
+ docker_image: null
+ providers:
+ inference:
+ - remote::cerebras
+ safety:
+ - inline::llama-guard
+ memory:
+ - inline::meta-reference
+ agents:
+ - inline::meta-reference
+ telemetry:
+ - inline::meta-reference
+image_type: conda
diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py
new file mode 100644
index 000000000..58e05adf8
--- /dev/null
+++ b/llama_stack/templates/cerebras/cerebras.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+from llama_models.sku_list import all_registered_models
+
+from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
+from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
+
+from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
+
+
+def get_distribution_template() -> DistributionTemplate:
+ providers = {
+ "inference": ["remote::cerebras"],
+ "safety": ["inline::llama-guard"],
+ "memory": ["inline::meta-reference"],
+ "agents": ["inline::meta-reference"],
+ "telemetry": ["inline::meta-reference"],
+ }
+
+ inference_provider = Provider(
+ provider_id="cerebras",
+ provider_type="remote::cerebras",
+ config=CerebrasImplConfig.sample_run_config(),
+ )
+
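+    # Map each Cerebras model alias to its Hugging Face repo id so the
+    # distribution registers models under their canonical meta-llama/... names.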
+ core_model_to_hf_repo = {
+ m.descriptor(): m.huggingface_repo for m in all_registered_models()
+ }
+ default_models = [
+ ModelInput(
+ model_id=core_model_to_hf_repo[m.llama_model],
+ provider_model_id=m.provider_model_id,
+ )
+ for m in model_aliases
+ ]
+
+ return DistributionTemplate(
+ name="cerebras",
+ distro_type="self_hosted",
+ description="Use Cerebras for running LLM inference",
+ docker_image=None,
+ template_path=Path(__file__).parent / "doc_template.md",
+ providers=providers,
+ default_models=default_models,
+ run_configs={
+ "run.yaml": RunConfigSettings(
+ provider_overrides={
+ "inference": [inference_provider],
+ },
+ default_models=default_models,
+ default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
+ ),
+ },
+ run_config_env_vars={
+ "LLAMASTACK_PORT": (
+ "5001",
+ "Port for the Llama Stack distribution server",
+ ),
+ "CEREBRAS_API_KEY": (
+ "",
+ "Cerebras API Key",
+ ),
+ },
+ )
diff --git a/llama_stack/templates/cerebras/doc_template.md b/llama_stack/templates/cerebras/doc_template.md
new file mode 100644
index 000000000..77fc6f478
--- /dev/null
+++ b/llama_stack/templates/cerebras/doc_template.md
@@ -0,0 +1,60 @@
+# Cerebras Distribution
+
+The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
+
+{{ providers_table }}
+
+{% if run_config_env_vars %}
+### Environment Variables
+
+The following environment variables can be configured:
+
+{% for var, (default_value, description) in run_config_env_vars.items() %}
+- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
+{% endfor %}
+{% endif %}
+
+{% if default_models %}
+### Models
+
+The following models are available by default:
+
+{% for model in default_models %}
+- `{{ model.model_id }} ({{ model.provider_model_id }})`
+{% endfor %}
+{% endif %}
+
+
+### Prerequisite: API Keys
+
+Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
+
+
+## Running Llama Stack with Cerebras
+
+You can do this via Conda (build the code yourself) or Docker, which uses a pre-built image.
+
+### Via Docker
+
+This method allows you to get started quickly without having to build the distribution code.
+
+```bash
+LLAMA_STACK_PORT=5001
+docker run \
+ -it \
+ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+ -v ./run.yaml:/root/my-run.yaml \
+ llamastack/distribution-{{ name }} \
+ --yaml-config /root/my-run.yaml \
+ --port $LLAMA_STACK_PORT \
+ --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
+```
+
+### Via Conda
+
+```bash
+llama stack build --template cerebras --image-type conda
+llama stack run ./run.yaml \
+ --port 5001 \
+ --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
+```
diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml
new file mode 100644
index 000000000..0b41f5b76
--- /dev/null
+++ b/llama_stack/templates/cerebras/run.yaml
@@ -0,0 +1,63 @@
+version: '2'
+image_name: cerebras
+docker_image: null
+conda_env: cerebras
+apis:
+- agents
+- inference
+- memory
+- safety
+- telemetry
+providers:
+ inference:
+ - provider_id: cerebras
+ provider_type: remote::cerebras
+ config:
+ base_url: https://api.cerebras.ai
+ api_key: ${env.CEREBRAS_API_KEY}
+ safety:
+ - provider_id: llama-guard
+ provider_type: inline::llama-guard
+ config: {}
+ memory:
+ - provider_id: meta-reference
+ provider_type: inline::meta-reference
+ config:
+ kvstore:
+ type: sqlite
+ namespace: null
+ db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db
+ agents:
+ - provider_id: meta-reference
+ provider_type: inline::meta-reference
+ config:
+ persistence_store:
+ type: sqlite
+ namespace: null
+ db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db
+ telemetry:
+ - provider_id: meta-reference
+ provider_type: inline::meta-reference
+ config: {}
+metadata_store:
+ namespace: null
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
+models:
+- metadata: {}
+ model_id: meta-llama/Llama-3.1-8B-Instruct
+ provider_id: null
+ provider_model_id: llama3.1-8b
+- metadata: {}
+ model_id: meta-llama/Llama-3.1-70B-Instruct
+ provider_id: null
+ provider_model_id: llama3.1-70b
+shields:
+- params: null
+ shield_id: meta-llama/Llama-Guard-3-8B
+ provider_id: null
+ provider_shield_id: null
+memory_banks: []
+datasets: []
+scoring_fns: []
+eval_tasks: []