mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Cerebras Inference Integration (#265)
Adding Cerebras Inference as an API provider. ## Testing ### Conda ``` $ llama stack build --template cerebras --image-type conda $ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml ... Listening on ['::', '0.0.0.0']:5000 INFO: Started server process [12443] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit) ``` ### Chat Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "messages": [ { "role": "user", "content": "What is the temperature in Seattle right now?" } ], "stream": false, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 100 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "completion_message": { "role": "assistant", "content": "", "stop_reason": "end_of_message", "tool_calls": [ { "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678", "tool_name": "getTemperature", "arguments": { "location": "Seattle" } } ] }, "logprobs": null } ``` #### Streaming Response ``` data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}} data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}} ``` ### Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "content": "1,2,3,", "stream": true, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 10 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "content": "4,5,6,7,8,", "stop_reason": "out_of_tokens", "logprobs": null } ``` #### Streaming Response ``` data: {"delta":"4","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"5","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"6","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"7","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"8","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null} ``` ### Pre-Commit Checks ``` trim trailing whitespace.................................................Passed check python ast.........................................................Passed check for merge conflicts................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed flake8...................................................................Passed Format files with µfmt...................................................Passed ``` ### Testing with `test_inference.py` ``` $ export CEREBRAS_API_KEY=<insert API key here> $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b" /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =================================================== test session starts =================================================== platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12 cachedir: .pytest_cache rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack configfile: pyproject.toml plugins: anyio-4.6.2.post1, asyncio-0.24.0 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 128 items / 120 deselected / 8 selected llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers inner-inference => cerebras models => __routing_table__ inference => __autorouted__ inspect => __builtin__ Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED ================================ 6 passed, 2 skipped, 120 deselected, 6 warnings in 3.95s ================================= ``` I ran `python llama_stack/scripts/distro_codegen.py` to run codegen.
This commit is contained in:
parent
b6500974ec
commit
64c6df8392
19 changed files with 1018 additions and 292 deletions
|
@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
|
|||
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|
||||
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
|
||||
| Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Cerebras | Single Node | | :heavy_check_mark: | | | |
|
||||
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
|
||||
|
@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
|
|||
|:----------------: |:------------------------------------------: |:-----------------------: |
|
||||
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
|
||||
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
|
||||
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
|
||||
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
|
||||
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
|
||||
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
|
||||
|
|
1
distributions/cerebras/build.yaml
Symbolic link
1
distributions/cerebras/build.yaml
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../llama_stack/templates/cerebras/build.yaml
|
16
distributions/cerebras/compose.yaml
Normal file
16
distributions/cerebras/compose.yaml
Normal file
|
@ -0,0 +1,16 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-cerebras
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/llamastack-run-cerebras.yaml
|
||||
ports:
|
||||
- "5000:5000"
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
1
distributions/cerebras/run.yaml
Symbolic link
1
distributions/cerebras/run.yaml
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../llama_stack/templates/cerebras/run.yaml
|
|
@ -1,4 +1,152 @@
|
|||
{
|
||||
"tgi": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"remote-vllm": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"vllm-gpu": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"meta-reference-quantized-gpu": [
|
||||
"accelerate",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"fairscale",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fbgemm-gpu",
|
||||
"fire",
|
||||
"httpx",
|
||||
"lm-format-enforcer",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"torch",
|
||||
"torchao==0.5.0",
|
||||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"zmq",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"meta-reference-gpu": [
|
||||
"accelerate",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"fairscale",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"lm-format-enforcer",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"zmq",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"hf-serverless": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
|
@ -54,88 +202,7 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"vllm-gpu": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"vllm",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"remote-vllm": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"fireworks": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"fireworks-ai",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"tgi": [
|
||||
"ollama": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
|
@ -145,10 +212,10 @@
|
|||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"huggingface_hub",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"ollama",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
|
@ -190,100 +257,6 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"meta-reference-gpu": [
|
||||
"accelerate",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"fairscale",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"lm-format-enforcer",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"zmq",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"meta-reference-quantized-gpu": [
|
||||
"accelerate",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"fairscale",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fbgemm-gpu",
|
||||
"fire",
|
||||
"httpx",
|
||||
"lm-format-enforcer",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"torch",
|
||||
"torchao==0.5.0",
|
||||
"torchvision",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"zmq",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"ollama": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"ollama",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"hf-endpoint": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
|
@ -311,5 +284,58 @@
|
|||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"fireworks": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"fireworks-ai",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"cerebras": [
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"cerebras_cloud_sdk",
|
||||
"chardet",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu"
|
||||
]
|
||||
}
|
||||
|
|
|
@ -66,121 +66,247 @@ llama stack build --list-templates
|
|||
```
|
||||
|
||||
```
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| Template Name | Providers | Description |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM |
|
||||
| | "inference": "remote::hf::serverless", | inference. |
|
||||
| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| together | { | Use Together.ai for running LLM inference |
|
||||
| | "inference": "remote::together", | |
|
||||
| | "memory": [ | |
|
||||
| | "meta-reference", | |
|
||||
| | "remote::weaviate" | |
|
||||
| | ], | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| fireworks | { | Use Fireworks.ai for running LLM inference |
|
||||
| | "inference": "remote::fireworks", | |
|
||||
| | "memory": [ | |
|
||||
| | "meta-reference", | |
|
||||
| | "remote::weaviate", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| databricks | { | Use Databricks for running LLM inference |
|
||||
| | "inference": "remote::databricks", | |
|
||||
| | "memory": "meta-reference", | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| vllm | { | Like local, but use vLLM for running LLM inference |
|
||||
| | "inference": "vllm", | |
|
||||
| | "memory": "meta-reference", | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| tgi | { | Use TGI for running LLM inference |
|
||||
| | "inference": "remote::tgi", | |
|
||||
| | "memory": [ | |
|
||||
| | "meta-reference", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| bedrock | { | Use Amazon Bedrock APIs. |
|
||||
| | "inference": "remote::bedrock", | |
|
||||
| | "memory": "meta-reference", | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
|
||||
| | "inference": "meta-reference", | |
|
||||
| | "memory": [ | |
|
||||
| | "meta-reference", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs |
|
||||
| | "inference": "meta-reference-quantized", | |
|
||||
| | "memory": [ | |
|
||||
| | "meta-reference", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| ollama | { | Use ollama for running LLM inference |
|
||||
| | "inference": "remote::ollama", | |
|
||||
| | "memory": [ | |
|
||||
| | "meta-reference", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. |
|
||||
| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. |
|
||||
| | "memory": "meta-reference", | |
|
||||
| | "safety": "meta-reference", | |
|
||||
| | "agents": "meta-reference", | |
|
||||
| | "telemetry": "meta-reference" | |
|
||||
| | } | |
|
||||
+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| Template Name | Providers | Description |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| tgi | { | Use (an external) TGI server for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::tgi" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| remote-vllm | { | Use (an external) vLLM server for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::vllm" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "inline::vllm" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "inline::meta-reference-quantized" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| meta-reference-gpu | { | Use Meta Reference for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::hf::serverless" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| together | { | Use Together.AI for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::together" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| ollama | { | Use (an external) Ollama server for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::ollama" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| bedrock | { | Use AWS Bedrock for running LLM inference and safety |
|
||||
| | "inference": [ | |
|
||||
| | "remote::bedrock" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "remote::bedrock" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::hf::endpoint" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| fireworks | { | Use Fireworks.AI for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::fireworks" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::faiss", | |
|
||||
| | "remote::chromadb", | |
|
||||
| | "remote::pgvector" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
| cerebras | { | Use Cerebras for running LLM inference |
|
||||
| | "inference": [ | |
|
||||
| | "remote::cerebras" | |
|
||||
| | ], | |
|
||||
| | "safety": [ | |
|
||||
| | "inline::llama-guard" | |
|
||||
| | ], | |
|
||||
| | "memory": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "agents": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ], | |
|
||||
| | "telemetry": [ | |
|
||||
| | "inline::meta-reference" | |
|
||||
| | ] | |
|
||||
| | } | |
|
||||
+------------------------------+----------------------------------------+-----------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
You may then pick a template to build your distribution with providers fitted to your liking.
|
||||
|
|
61
docs/source/distributions/self_hosted_distro/cerebras.md
Normal file
61
docs/source/distributions/self_hosted_distro/cerebras.md
Normal file
|
@ -0,0 +1,61 @@
|
|||
# Cerebras Distribution
|
||||
|
||||
The `llamastack/distribution-cerebras` distribution consists of the following provider configurations.
|
||||
|
||||
| API | Provider(s) |
|
||||
|-----|-------------|
|
||||
| agents | `inline::meta-reference` |
|
||||
| inference | `remote::cerebras` |
|
||||
| memory | `inline::meta-reference` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)`
|
||||
- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
||||
Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
|
||||
|
||||
|
||||
## Running Llama Stack with Cerebras
|
||||
|
||||
You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
docker run \
|
||||
-it \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-cerebras \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
llama stack build --template cerebras --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 5001 \
|
||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||
```
|
|
@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere
|
|||
| **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|
||||
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
|
||||
| Meta Reference | Single Node | Y | Y | Y | Y | Y |
|
||||
| Cerebras | Single Node | | Y | | | |
|
||||
| Fireworks | Hosted | Y | Y | Y | | |
|
||||
| AWS Bedrock | Hosted | | Y | | Y | |
|
||||
| Together | Hosted | Y | Y | | Y | |
|
||||
|
|
|
@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
21
llama_stack/providers/remote/inference/cerebras/__init__.py
Normal file
21
llama_stack/providers/remote/inference/cerebras/__init__.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
|
||||
from .cerebras import CerebrasInferenceAdapter
|
||||
|
||||
assert isinstance(
|
||||
config, CerebrasImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = CerebrasInferenceAdapter(config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
191
llama_stack/providers/remote/inference/cerebras/cerebras.py
Normal file
191
llama_stack/providers/remote/inference/cerebras/cerebras.py
Normal file
|
@ -0,0 +1,191 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
build_model_alias(
|
||||
"llama3.1-8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_alias(
|
||||
"llama3.1-70b",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_aliases=model_aliases,
|
||||
)
|
||||
self.config = config
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
self.client = AsyncCerebras(
|
||||
base_url=self.config.base_url, api_key=self.config.api_key
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(
|
||||
request,
|
||||
)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
params = self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
|
||||
return process_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
params = self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: CompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
if request.sampling_params and request.sampling_params.top_k:
|
||||
raise ValueError("`top_k` not supported by Cerebras")
|
||||
|
||||
prompt = ""
|
||||
if type(request) == ChatCompletionRequest:
|
||||
prompt = chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model), self.formatter
|
||||
)
|
||||
elif type(request) == CompletionRequest:
|
||||
prompt = completion_request_to_prompt(request, self.formatter)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"prompt": prompt,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
32
llama_stack/providers/remote/inference/cerebras/config.py
Normal file
32
llama_stack/providers/remote/inference/cerebras/config.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(BaseModel):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=os.environ.get("CEREBRAS_API_KEY"),
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": "${env.CEREBRAS_API_KEY}",
|
||||
}
|
|
@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
|
|||
)
|
||||
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
||||
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
|
@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_cerebras() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
config=CerebrasImplConfig(
|
||||
api_key=get_env_or_fail("CEREBRAS_API_KEY"),
|
||||
).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_ollama(inference_model) -> ProviderFixture:
|
||||
inference_model = (
|
||||
|
@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [
|
|||
"vllm_remote",
|
||||
"remote",
|
||||
"bedrock",
|
||||
"cerebras",
|
||||
"nvidia",
|
||||
"tgi",
|
||||
]
|
||||
|
|
|
@ -94,6 +94,7 @@ class TestInference:
|
|||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
"remote::cerebras",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
|
@ -139,6 +140,7 @@ class TestInference:
|
|||
"remote::tgi",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
"remote::cerebras",
|
||||
):
|
||||
pytest.skip(
|
||||
"Other inference providers don't support structured output in completions yet"
|
||||
|
|
7
llama_stack/templates/cerebras/__init__.py
Normal file
7
llama_stack/templates/cerebras/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .cerebras import get_distribution_template # noqa: F401
|
17
llama_stack/templates/cerebras/build.yaml
Normal file
17
llama_stack/templates/cerebras/build.yaml
Normal file
|
@ -0,0 +1,17 @@
|
|||
version: '2'
|
||||
name: cerebras
|
||||
distribution_spec:
|
||||
description: Use Cerebras for running LLM inference
|
||||
docker_image: null
|
||||
providers:
|
||||
inference:
|
||||
- remote::cerebras
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
memory:
|
||||
- inline::meta-reference
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
image_type: conda
|
71
llama_stack/templates/cerebras/cerebras.py
Normal file
71
llama_stack/templates/cerebras/cerebras.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
|
||||
|
||||
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["remote::cerebras"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"memory": ["inline::meta-reference"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
provider_id="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
config=CerebrasImplConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
core_model_to_hf_repo = {
|
||||
m.descriptor(): m.huggingface_repo for m in all_registered_models()
|
||||
}
|
||||
default_models = [
|
||||
ModelInput(
|
||||
model_id=core_model_to_hf_repo[m.llama_model],
|
||||
provider_model_id=m.provider_model_id,
|
||||
)
|
||||
for m in model_aliases
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name="cerebras",
|
||||
distro_type="self_hosted",
|
||||
description="Use Cerebras for running LLM inference",
|
||||
docker_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
default_models=default_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
},
|
||||
default_models=default_models,
|
||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMASTACK_PORT": (
|
||||
"5001",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
"CEREBRAS_API_KEY": (
|
||||
"",
|
||||
"Cerebras API Key",
|
||||
),
|
||||
},
|
||||
)
|
60
llama_stack/templates/cerebras/doc_template.md
Normal file
60
llama_stack/templates/cerebras/doc_template.md
Normal file
|
@ -0,0 +1,60 @@
|
|||
# Cerebras Distribution
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if default_models %}
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }} ({{ model.provider_model_id }})`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
||||
Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/).
|
||||
|
||||
|
||||
## Running Llama Stack with Cerebras
|
||||
|
||||
You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
This method allows you to get started quickly without having to build the distribution code.
|
||||
|
||||
```bash
|
||||
LLAMA_STACK_PORT=5001
|
||||
docker run \
|
||||
-it \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||
```
|
||||
|
||||
### Via Conda
|
||||
|
||||
```bash
|
||||
llama stack build --template cerebras --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 5001 \
|
||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||
```
|
63
llama_stack/templates/cerebras/run.yaml
Normal file
63
llama_stack/templates/cerebras/run.yaml
Normal file
|
@ -0,0 +1,63 @@
|
|||
version: '2'
|
||||
image_name: cerebras
|
||||
docker_image: null
|
||||
conda_env: cerebras
|
||||
apis:
|
||||
- agents
|
||||
- inference
|
||||
- memory
|
||||
- safety
|
||||
- telemetry
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: cerebras
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
api_key: ${env.CEREBRAS_API_KEY}
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
memory:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
metadata_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: null
|
||||
provider_model_id: llama3.1-8b
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: null
|
||||
provider_model_id: llama3.1-70b
|
||||
shields:
|
||||
- params: null
|
||||
shield_id: meta-llama/Llama-Guard-3-8B
|
||||
provider_id: null
|
||||
provider_shield_id: null
|
||||
memory_banks: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
eval_tasks: []
|
Loading…
Add table
Add a link
Reference in a new issue