Cerebras Inference Integration (#265)

This PR adds Cerebras Inference as an API provider.

## Testing

### Conda
```
$ llama stack build --template cerebras --image-type conda
$ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml
...
Listening on ['::', '0.0.0.0']:5000
INFO:     Started server process [12443]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit)
```
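
Before issuing requests, it can help to wait until the server is actually accepting connections. A minimal sketch of such a check (hypothetical helper, not part of this PR; assumes the default port 5000 shown in the log above):

```python
# wait_for_server.py - poll until the Llama Stack server accepts TCP connections.
import socket
import time


def wait_for_server(host: str = "localhost", port: int = 5000, timeout: float = 60.0) -> None:
    """Block until host:port accepts a TCP connection or the timeout elapses."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # If the connect succeeds, the server is listening.
            with socket.create_connection((host, port), timeout=2):
                return
        except OSError:
            time.sleep(1)  # not up yet, retry
    raise TimeoutError(f"Server at {host}:{port} did not come up within {timeout}s")


if __name__ == "__main__":
    wait_for_server()
    print("Llama Stack server is accepting connections")
```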

### Chat Completion
```
$ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the temperature in Seattle right now?"
        }
    ],
    "stream": false,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 100
    },                   
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tools": [                   
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {                                              
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degress celsius.",
                    "required": true                                                                       
                }                   
            }    
        }    
    ]    
}' 
```

#### Non-Streaming Response
```
{
  "completion_message": {
    "role": "assistant",
    "content": "",
    "stop_reason": "end_of_message",
    "tool_calls": [
      {
        "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678",
        "tool_name": "getTemperature",
        "arguments": {
          "location": "Seattle"
        }
      }
    ]
  },
  "logprobs": null
}
```
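
The `tool_calls` in the response are what a client is expected to dispatch to its own functions. A minimal sketch of that dispatch, assuming a hypothetical local `get_temperature` implementation (only the response shape above comes from this PR):

```python
# dispatch_tool_calls.py - map tool calls from a non-streaming chat-completion
# response onto local Python functions.
from typing import Any, Dict


def get_temperature(location: str) -> float:
    """Hypothetical local implementation of the `getTemperature` tool."""
    return 12.5  # placeholder reading in degrees Celsius


TOOL_REGISTRY = {"getTemperature": get_temperature}


def dispatch_tool_calls(response: Dict[str, Any]) -> Dict[str, Any]:
    """Execute each tool call and return results keyed by call_id."""
    results = {}
    for call in response["completion_message"].get("tool_calls", []):
        fn = TOOL_REGISTRY[call["tool_name"]]
        results[call["call_id"]] = fn(**call["arguments"])
    return results
```

The results would then be fed back to the model as tool-response messages to finish the exchange.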

#### Streaming Response
```
data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}}
data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}}
```
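
For reference, this is roughly how the SSE stream can be consumed from Python; a minimal sketch assuming `httpx` is installed (the endpoint and payload are the ones from the curl above):

```python
# stream_chat_completion.py - consume the SSE stream from /alpha/inference/chat-completion.
import json

import httpx

URL = "http://localhost:5000/alpha/inference/chat-completion"
payload = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "What is the temperature in Seattle right now?"}],
    "stream": True,
}

with httpx.stream("POST", URL, json=payload, timeout=None) as response:
    for line in response.iter_lines():
        if not line.startswith("data: "):
            continue  # skip blank keep-alive lines
        event = json.loads(line[len("data: "):])["event"]
        if event["event_type"] == "progress":
            print(event["delta"])  # delta dict: content chunk plus parse_status
        elif event["event_type"] == "complete":
            break
```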

### Completion
```
$ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "content": "1,2,3,",
    "stream": true,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 10
    },                   
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tools": [                   
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {                                              
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degress celsius.",
                    "required": true                                                                       
                }                   
            }    
        }    
    ]    
}'
```

#### Non-Streaming Response
```
{
  "content": "4,5,6,7,8,",
  "stop_reason": "out_of_tokens",
  "logprobs": null
}
```

#### Streaming Response
```
data: {"delta":"4","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"5","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"6","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"7","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"8","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"","stop_reason":null,"logprobs":null}
data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null}
```
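
The completion stream is simpler: each `data:` line carries a plain `delta` string, so reassembling the output is just concatenation. A minimal sketch assuming `requests` (payload as in the curl above, minus the tool fields, which plain completion does not need):

```python
# stream_completion.py - reassemble the text produced by /alpha/inference/completion.
import json

import requests

URL = "http://localhost:5000/alpha/inference/completion"
payload = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "content": "1,2,3,",
    "stream": True,
    "sampling_params": {"strategy": "top_p", "temperature": 0.5, "max_tokens": 10},
}

text = ""
with requests.post(URL, json=payload, stream=True) as response:
    for raw in response.iter_lines():
        if not raw or not raw.startswith(b"data: "):
            continue
        chunk = json.loads(raw[len(b"data: "):])
        text += chunk["delta"]
        if chunk["stop_reason"] is not None:
            break

print(text)  # e.g. "4,5,6,7,8," as in the run above
```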

### Pre-Commit Checks
```
trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Passed
```

### Testing with `test_text_inference.py`
```
$ export CEREBRAS_API_KEY=<insert API key here>
$ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b" 
/net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================================== test session starts ===================================================
platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 128 items / 120 deselected / 8 selected                                                                         

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers
 inner-inference => cerebras
 models => __routing_table__
 inference => __autorouted__
 inspect => __builtin__

Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras

PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED

================================ 6 passed, 2 skipped, 120 deselected, 6 warnings in 3.95s =================================
```

I ran `python llama_stack/scripts/distro_codegen.py` to regenerate the distribution templates and docs.

@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| Cerebras | Single Node | | :heavy_check_mark: | | | |
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
|:----------------: |:------------------------------------------: |:-----------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |

@ -0,0 +1 @@
../../llama_stack/templates/cerebras/build.yaml

@ -0,0 +1,16 @@
services:
  llamastack:
    image: llamastack/distribution-cerebras
    network_mode: "host"
    volumes:
      - ~/.llama:/root/.llama
      - ./run.yaml:/root/llamastack-run-cerebras.yaml
    ports:
      - "5000:5000"
    entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml"
    deploy:
      restart_policy:
        condition: on-failure
        delay: 3s
        max_attempts: 5
        window: 60s

@ -0,0 +1 @@
../../llama_stack/templates/cerebras/run.yaml

@ -1,4 +1,152 @@
{
"tgi": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fire",
"httpx",
"lm-format-enforcer",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"torch",
"torchao==0.5.0",
"torchvision",
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"lm-format-enforcer",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"torch",
"torchvision",
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"hf-serverless": [
"aiohttp",
"aiosqlite",
@ -54,88 +202,7 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"ollama": [
"aiohttp",
"aiosqlite",
"blobfile",
@ -145,10 +212,10 @@
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"ollama",
"pandas",
"pillow",
"psycopg2-binary",
@ -190,100 +257,6 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"lm-format-enforcer",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"torch",
"torchvision",
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fire",
"httpx",
"lm-format-enforcer",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"torch",
"torchao==0.5.0",
"torchvision",
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"ollama": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"ollama",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"hf-endpoint": [
"aiohttp",
"aiosqlite",
@ -311,5 +284,58 @@
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
]
}

@ -66,121 +66,247 @@ llama stack build --list-templates
```
```
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| Template Name | Providers | Description |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM |
| | "inference": "remote::hf::serverless", | inference. |
| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| together | { | Use Together.ai for running LLM inference |
| | "inference": "remote::together", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::weaviate" | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| tgi | { | Use (an external) TGI server for running LLM inference |
| | "inference": [ | |
| | "remote::tgi" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| fireworks | { | Use Fireworks.ai for running LLM inference |
| | "inference": "remote::fireworks", | |
| | "memory": [ | |
| | "meta-reference", | |
| | "remote::weaviate", | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| databricks | { | Use Databricks for running LLM inference |
| | "inference": "remote::databricks", | |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| vllm | { | Like local, but use vLLM for running LLM inference |
| | "inference": "vllm", | |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| tgi | { | Use TGI for running LLM inference |
| | "inference": "remote::tgi", | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| remote-vllm | { | Use (an external) vLLM server for running LLM inference |
| | "inference": [ | |
| | "remote::vllm" | |
| | ], | |
| | "memory": [ | |
| | "meta-reference", | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| bedrock | { | Use Amazon Bedrock APIs. |
| | "inference": "remote::bedrock", | |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
| | "inference": "meta-reference", | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference |
| | "inference": [ | |
| | "inline::vllm" | |
| | ], | |
| | "memory": [ | |
| | "meta-reference", | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
| | "inference": "meta-reference-quantized", | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference |
| | "inference": [ | |
| | "inline::meta-reference-quantized" | |
| | ], | |
| | "memory": [ | |
| | "meta-reference", | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| ollama | { | Use ollama for running LLM inference |
| | "inference": "remote::ollama", | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| meta-reference-gpu | { | Use Meta Reference for running LLM inference |
| | "inference": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "memory": [ | |
| | "meta-reference", | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. |
| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. |
| | "memory": "meta-reference", | |
| | "safety": "meta-reference", | |
| | "agents": "meta-reference", | |
| | "telemetry": "meta-reference" | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
| | "inference": [ | |
| | "remote::hf::serverless" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| together | { | Use Together.AI for running LLM inference |
| | "inference": [ | |
| | "remote::together" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| ollama | { | Use (an external) Ollama server for running LLM inference |
| | "inference": [ | |
| | "remote::ollama" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| bedrock | { | Use AWS Bedrock for running LLM inference and safety |
| | "inference": [ | |
| | "remote::bedrock" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "remote::bedrock" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
| | "inference": [ | |
| | "remote::hf::endpoint" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| fireworks | { | Use Fireworks.AI for running LLM inference |
| | "inference": [ | |
| | "remote::fireworks" | |
| | ], | |
| | "memory": [ | |
| | "inline::faiss", | |
| | "remote::chromadb", | |
| | "remote::pgvector" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
| cerebras | { | Use Cerebras for running LLM inference |
| | "inference": [ | |
| | "remote::cerebras" | |
| | ], | |
| | "safety": [ | |
| | "inline::llama-guard" | |
| | ], | |
| | "memory": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "agents": [ | |
| | "inline::meta-reference" | |
| | ], | |
| | "telemetry": [ | |
| | "inline::meta-reference" | |
| | ] | |
| | } | |
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
```
You may then pick a template to build your distribution with providers fitted to your liking.

@ -0,0 +1,61 @@
# Cerebras Distribution
The `llamastack/distribution-cerebras` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::cerebras` |
| memory | `inline::meta-reference` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)`
- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)`
### Prerequisite: API Keys
Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
## Running Llama Stack with Cerebras
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-cerebras \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```
### Via Conda
```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere
| **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| Meta Reference | Single Node | Y | Y | Y | Y | Y |
| Cerebras | Single Node | | Y | | | |
| Fireworks | Hosted | Y | Y | Y | | |
| AWS Bedrock | Hosted | | Y | | Y | |
| Together | Hosted | Y | Y | | Y | |

@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CerebrasImplConfig


async def get_adapter_impl(config: CerebrasImplConfig, _deps):
    from .cerebras import CerebrasInferenceAdapter

    assert isinstance(
        config, CerebrasImplConfig
    ), f"Unexpected config type: {type(config)}"

    impl = CerebrasInferenceAdapter(config)

    await impl.initialize()

    return impl

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator

from cerebras.cloud.sdk import AsyncCerebras

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import *  # noqa: F403
from llama_models.datatypes import CoreModelId

from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
)

from .config import CerebrasImplConfig


model_aliases = [
    build_model_alias(
        "llama3.1-8b",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "llama3.1-70b",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
]


class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: CerebrasImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self,
            model_aliases=model_aliases,
        )
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

        self.client = AsyncCerebras(
            base_url=self.config.base_url, api_key=self.config.api_key
        )

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(
                request,
            )
        else:
            return await self._nonstream_completion(request)

    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        params = self._get_params(request)

        r = await self.client.completions.create(**params)

        return process_completion_response(r, self.formatter)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        params = self._get_params(request)

        r = await self.client.completions.create(**params)

        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: CompletionRequest
    ) -> AsyncGenerator:
        params = self._get_params(request)

        stream = await self.client.completions.create(**params)

        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(
        self, request: Union[ChatCompletionRequest, CompletionRequest]
    ) -> dict:
        if request.sampling_params and request.sampling_params.top_k:
            raise ValueError("`top_k` not supported by Cerebras")

        prompt = ""
        if type(request) == ChatCompletionRequest:
            prompt = chat_completion_request_to_prompt(
                request, self.get_llama_model(request.model), self.formatter
            )
        elif type(request) == CompletionRequest:
            prompt = completion_request_to_prompt(request, self.formatter)
        else:
            raise ValueError(f"Unknown request type {type(request)}")

        return {
            "model": request.model,
            "prompt": prompt,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(BaseModel):
    base_url: str = Field(
        default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
        description="Base URL for the Cerebras API",
    )
    api_key: Optional[str] = Field(
        default=os.environ.get("CEREBRAS_API_KEY"),
        description="Cerebras API Key",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "base_url": DEFAULT_BASE_URL,
            "api_key": "${env.CEREBRAS_API_KEY}",
        }

@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
)
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
    )


@pytest.fixture(scope="session")
def inference_cerebras() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="cerebras",
                provider_type="remote::cerebras",
                config=CerebrasImplConfig(
                    api_key=get_env_or_fail("CEREBRAS_API_KEY"),
                ).model_dump(),
            )
        ],
    )


@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
    inference_model = (
@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [
"vllm_remote",
"remote",
"bedrock",
"cerebras",
"nvidia",
"tgi",
]

@ -94,6 +94,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -139,6 +140,7 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::cerebras",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .cerebras import get_distribution_template # noqa: F401

@ -0,0 +1,17 @@
version: '2'
name: cerebras
distribution_spec:
  description: Use Cerebras for running LLM inference
  docker_image: null
  providers:
    inference:
    - remote::cerebras
    safety:
    - inline::llama-guard
    memory:
    - inline::meta-reference
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
image_type: conda

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_models.sku_list import all_registered_models
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::cerebras"],
"safety": ["inline::llama-guard"],
"memory": ["inline::meta-reference"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
}
inference_provider = Provider(
provider_id="cerebras",
provider_type="remote::cerebras",
config=CerebrasImplConfig.sample_run_config(),
)
core_model_to_hf_repo = {
m.descriptor(): m.huggingface_repo for m in all_registered_models()
}
default_models = [
ModelInput(
model_id=core_model_to_hf_repo[m.llama_model],
provider_model_id=m.provider_model_id,
)
for m in model_aliases
]
return DistributionTemplate(
name="cerebras",
distro_type="self_hosted",
description="Use Cerebras for running LLM inference",
docker_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
default_models=default_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=default_models,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
),
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"CEREBRAS_API_KEY": (
"",
"Cerebras API Key",
),
},
)

@ -0,0 +1,60 @@
# Cerebras Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} ({{ model.provider_model_id }})`
{% endfor %}
{% endif %}
### Prerequisite: API Keys
Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
## Running Llama Stack with Cerebras
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```
### Via Conda
```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```

@ -0,0 +1,63 @@
version: '2'
image_name: cerebras
docker_image: null
conda_env: cerebras
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: cerebras
    provider_type: remote::cerebras
    config:
      base_url: https://api.cerebras.ai
      api_key: ${env.CEREBRAS_API_KEY}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config: {}
  memory:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
models:
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: null
  provider_model_id: llama3.1-8b
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: null
  provider_model_id: llama3.1-70b
shields:
- params: null
  shield_id: meta-llama/Llama-Guard-3-8B
  provider_id: null
  provider_shield_id: null
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []