From 64c6df8392c8ceea321375bca12af2b025f6693e Mon Sep 17 00:00:00 2001
From: Henry Tu
Date: Wed, 4 Dec 2024 00:15:32 -0500
Subject: [PATCH] Cerebras Inference Integration (#265)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR adds Cerebras Inference as an API provider.

## Testing

### Conda

```
$ llama stack build --template cerebras --image-type conda
$ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml
...
Listening on ['::', '0.0.0.0']:5000
INFO: Started server process [12443]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit)
```

### Chat Completion

```
$ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the temperature in Seattle right now?"
        }
    ],
    "stream": false,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 100
    },
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tools": [
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degrees Celsius.",
                    "required": true
                }
            }
        }
    ]
}'
```

#### Non-Streaming Response

```
{
    "completion_message": {
        "role": "assistant",
        "content": "",
        "stop_reason": "end_of_message",
        "tool_calls": [
            {
                "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678",
                "tool_name": "getTemperature",
                "arguments": {
                    "location": "Seattle"
                }
            }
        ]
    },
    "logprobs": null
}
```

#### Streaming Response

```
data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}}
data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}}
data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}}
```

### Completion

```
$ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "content": "1,2,3,",
    "stream": true,
    "sampling_params": {
        "strategy": "top_p",
        "temperature": 0.5,
        "max_tokens": 10
    },
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tools": [
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degrees Celsius.",
                    "required": true
                }
            }
        }
    ]
}'
```

#### Non-Streaming Response

```
{
    "content": "4,5,6,7,8,",
    "stop_reason": "out_of_tokens",
    "logprobs": null
}
```

#### Streaming Response

```
data: {"delta":"4","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"5","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"6","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"7","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"8","stop_reason":null,"logprobs":null}
data: {"delta":",","stop_reason":null,"logprobs":null}
data: {"delta":"","stop_reason":null,"logprobs":null}
data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null}
```

### Pre-Commit Checks

```
trim trailing whitespace.................................................Passed
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check for added large files..............................................Passed
fix end of files.........................................................Passed
Insert license in comments...............................................Passed
flake8...................................................................Passed
Format files with µfmt...................................................Passed
```

### Testing with `test_inference.py`

```
$ export CEREBRAS_API_KEY=
$ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b"
/net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================================== test session starts ===================================================
platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack
configfile: pyproject.toml
plugins: anyio-4.6.2.post1, asyncio-0.24.0
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 128 items / 120 deselected / 8 selected

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers
 inner-inference => cerebras
 models => __routing_table__
 inference => __autorouted__
 inspect => __builtin__
Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras
PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED

================================ 6 passed, 2 skipped, 120 deselected, 6 warnings in 3.95s =================================
```

I ran `python llama_stack/scripts/distro_codegen.py` to regenerate the distribution files.
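The chat-completion request from the Testing section above can also be exercised from Python rather than curl. Below is a minimal sketch (not part of this patch) using `httpx`, which is already in the distribution's dependency list; it assumes the server started in the Conda section is listening on `localhost:5000` and reuses the same payload as the curl example.

```python
# Minimal sketch: send the same chat-completion request shown above from Python.
# Assumes the Cerebras distribution server is running on http://localhost:5000.
import httpx

payload = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
        {"role": "user", "content": "What is the temperature in Seattle right now?"}
    ],
    "stream": False,
    "sampling_params": {"strategy": "top_p", "temperature": 0.5, "max_tokens": 100},
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tools": [
        {
            "tool_name": "getTemperature",
            "description": "Gets the current temperature of a location.",
            "parameters": {
                "location": {
                    "param_type": "string",
                    "description": "The name of the place to get the temperature from in degrees Celsius.",
                    "required": True,
                }
            },
        }
    ],
}

# Non-streaming call; the response body matches the "Non-Streaming Response" example above.
response = httpx.post(
    "http://localhost:5000/alpha/inference/chat-completion",
    json=payload,
    timeout=60.0,
)
response.raise_for_status()
print(response.json()["completion_message"])
```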
--- README.md | 2 + distributions/cerebras/build.yaml | 1 + distributions/cerebras/compose.yaml | 16 + distributions/cerebras/run.yaml | 1 + distributions/dependencies.json | 380 ++++++++++-------- docs/source/distributions/building_distro.md | 356 ++++++++++------ .../self_hosted_distro/cerebras.md | 61 +++ docs/source/index.md | 1 + llama_stack/providers/registry/inference.py | 11 + .../remote/inference/cerebras/__init__.py | 21 + .../remote/inference/cerebras/cerebras.py | 191 +++++++++ .../remote/inference/cerebras/config.py | 32 ++ .../providers/tests/inference/fixtures.py | 17 + .../tests/inference/test_text_inference.py | 2 + llama_stack/templates/cerebras/__init__.py | 7 + llama_stack/templates/cerebras/build.yaml | 17 + llama_stack/templates/cerebras/cerebras.py | 71 ++++ .../templates/cerebras/doc_template.md | 60 +++ llama_stack/templates/cerebras/run.yaml | 63 +++ 19 files changed, 1018 insertions(+), 292 deletions(-) create mode 120000 distributions/cerebras/build.yaml create mode 100644 distributions/cerebras/compose.yaml create mode 120000 distributions/cerebras/run.yaml create mode 100644 docs/source/distributions/self_hosted_distro/cerebras.md create mode 100644 llama_stack/providers/remote/inference/cerebras/__init__.py create mode 100644 llama_stack/providers/remote/inference/cerebras/cerebras.py create mode 100644 llama_stack/providers/remote/inference/cerebras/config.py create mode 100644 llama_stack/templates/cerebras/__init__.py create mode 100644 llama_stack/templates/cerebras/build.yaml create mode 100644 llama_stack/templates/cerebras/cerebras.py create mode 100644 llama_stack/templates/cerebras/doc_template.md create mode 100644 llama_stack/templates/cerebras/run.yaml diff --git a/README.md b/README.md index 8e57292c3..0dfb1306d 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ Additionally, we have designed every element of the Stack such that APIs as well | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| Cerebras | Single Node | | :heavy_check_mark: | | | | | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | @@ -95,6 +96,7 @@ Additionally, we have designed every element of the Stack such that APIs as well |:----------------: |:------------------------------------------: |:-----------------------: | | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | +| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | 
[Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) | | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) | diff --git a/distributions/cerebras/build.yaml b/distributions/cerebras/build.yaml new file mode 120000 index 000000000..bccbbcf60 --- /dev/null +++ b/distributions/cerebras/build.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/cerebras/build.yaml \ No newline at end of file diff --git a/distributions/cerebras/compose.yaml b/distributions/cerebras/compose.yaml new file mode 100644 index 000000000..f2e9a6f42 --- /dev/null +++ b/distributions/cerebras/compose.yaml @@ -0,0 +1,16 @@ +services: + llamastack: + image: llamastack/distribution-cerebras + network_mode: "host" + volumes: + - ~/.llama:/root/.llama + - ./run.yaml:/root/llamastack-run-cerebras.yaml + ports: + - "5000:5000" + entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml" + deploy: + restart_policy: + condition: on-failure + delay: 3s + max_attempts: 5 + window: 60s diff --git a/distributions/cerebras/run.yaml b/distributions/cerebras/run.yaml new file mode 120000 index 000000000..9f9d20b4b --- /dev/null +++ b/distributions/cerebras/run.yaml @@ -0,0 +1 @@ +../../llama_stack/templates/cerebras/run.yaml \ No newline at end of file diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 36426e862..80468cc73 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -1,4 +1,152 @@ { + "tgi": [ + "aiohttp", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "huggingface_hub", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "remote-vllm": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "openai", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "vllm-gpu": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "vllm", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "meta-reference-quantized-gpu": [ + "accelerate", + 
"aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "fairscale", + "faiss-cpu", + "fastapi", + "fbgemm-gpu", + "fire", + "httpx", + "lm-format-enforcer", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "torch", + "torchao==0.5.0", + "torchvision", + "tqdm", + "transformers", + "uvicorn", + "zmq", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "meta-reference-gpu": [ + "accelerate", + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "fairscale", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "lm-format-enforcer", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "torch", + "torchvision", + "tqdm", + "transformers", + "uvicorn", + "zmq", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], "hf-serverless": [ "aiohttp", "aiosqlite", @@ -54,88 +202,7 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "vllm-gpu": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "vllm", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "remote-vllm": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "openai", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "fireworks": [ - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "fireworks-ai", - "httpx", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "tgi": [ + "ollama": [ "aiohttp", "aiosqlite", "blobfile", @@ -145,10 +212,10 @@ "fastapi", "fire", "httpx", - "huggingface_hub", "matplotlib", "nltk", "numpy", + "ollama", "pandas", "pillow", "psycopg2-binary", @@ -190,100 +257,6 @@ "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" ], - "meta-reference-gpu": [ - "accelerate", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "fairscale", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "lm-format-enforcer", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "torchvision", - "tqdm", - "transformers", - "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "meta-reference-quantized-gpu": [ - "accelerate", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "fairscale", - "faiss-cpu", - "fastapi", - "fbgemm-gpu", - "fire", - 
"httpx", - "lm-format-enforcer", - "matplotlib", - "nltk", - "numpy", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "torch", - "torchao==0.5.0", - "torchvision", - "tqdm", - "transformers", - "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], - "ollama": [ - "aiohttp", - "aiosqlite", - "blobfile", - "chardet", - "chromadb-client", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "nltk", - "numpy", - "ollama", - "pandas", - "pillow", - "psycopg2-binary", - "pypdf", - "redis", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch --index-url https://download.pytorch.org/whl/cpu" - ], "hf-endpoint": [ "aiohttp", "aiosqlite", @@ -311,5 +284,58 @@ "uvicorn", "sentence-transformers --no-deps", "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "fireworks": [ + "aiosqlite", + "blobfile", + "chardet", + "chromadb-client", + "faiss-cpu", + "fastapi", + "fire", + "fireworks-ai", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" + ], + "cerebras": [ + "aiosqlite", + "blobfile", + "cerebras_cloud_sdk", + "chardet", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "matplotlib", + "nltk", + "numpy", + "pandas", + "pillow", + "psycopg2-binary", + "pypdf", + "redis", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch --index-url https://download.pytorch.org/whl/cpu" ] } diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index a45d07ebf..67d39159c 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -66,121 +66,247 @@ llama stack build --list-templates ``` ``` -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| Template Name | Providers | Description | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-serverless | { | Like local, but use Hugging Face Inference API (serverless) for running LLM | -| | "inference": "remote::hf::serverless", | inference. | -| | "memory": "meta-reference", | See https://hf.co/docs/api-inference. 
| -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| together | { | Use Together.ai for running LLM inference | -| | "inference": "remote::together", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| fireworks | { | Use Fireworks.ai for running LLM inference | -| | "inference": "remote::fireworks", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::weaviate", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| databricks | { | Use Databricks for running LLM inference | -| | "inference": "remote::databricks", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| vllm | { | Like local, but use vLLM for running LLM inference | -| | "inference": "vllm", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| tgi | { | Use TGI for running LLM inference | -| | "inference": "remote::tgi", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| bedrock | { | Use Amazon Bedrock APIs. 
| -| | "inference": "remote::bedrock", | | -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| meta-reference-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs | -| | "inference": "meta-reference", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| meta-reference-quantized-gpu | { | Use code from `llama_stack` itself to serve all llama stack APIs | -| | "inference": "meta-reference-quantized", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| ollama | { | Use ollama for running LLM inference | -| | "inference": "remote::ollama", | | -| | "memory": [ | | -| | "meta-reference", | | -| | "remote::chromadb", | | -| | "remote::pgvector" | | -| | ], | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ -| hf-endpoint | { | Like local, but use Hugging Face Inference Endpoints for running LLM inference. | -| | "inference": "remote::hf::endpoint", | See https://hf.co/docs/api-endpoints. 
| -| | "memory": "meta-reference", | | -| | "safety": "meta-reference", | | -| | "agents": "meta-reference", | | -| | "telemetry": "meta-reference" | | -| | } | | -+------------------------------+--------------------------------------------+----------------------------------------------------------------------------------+ ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| Template Name | Providers | Description | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| tgi | { | Use (an external) TGI server for running LLM inference | +| | "inference": [ | | +| | "remote::tgi" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| remote-vllm | { | Use (an external) vLLM server for running LLM inference | +| | "inference": [ | | +| | "remote::vllm" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| vllm-gpu | { | Use a built-in vLLM engine for running LLM inference | +| | "inference": [ | | +| | "inline::vllm" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| meta-reference-quantized-gpu | { | Use Meta Reference with fp8, int4 quantization for running LLM inference | +| | "inference": [ | | +| | "inline::meta-reference-quantized" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| meta-reference-gpu | { | Use Meta Reference for running LLM inference | +| | "inference": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | 
"inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| hf-serverless | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference | +| | "inference": [ | | +| | "remote::hf::serverless" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| together | { | Use Together.AI for running LLM inference | +| | "inference": [ | | +| | "remote::together" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| ollama | { | Use (an external) Ollama server for running LLM inference | +| | "inference": [ | | +| | "remote::ollama" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| bedrock | { | Use AWS Bedrock for running LLM inference and safety | +| | "inference": [ | | +| | "remote::bedrock" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "remote::bedrock" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| hf-endpoint | { | Use (an external) Hugging Face Inference Endpoint for running LLM inference | +| | "inference": [ | | +| | "remote::hf::endpoint" | | +| | ], | | +| | "memory": [ | | +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| fireworks | { | Use Fireworks.AI for running LLM inference | +| | "inference": [ | | +| | "remote::fireworks" | | +| | ], | | +| | "memory": [ | 
| +| | "inline::faiss", | | +| | "remote::chromadb", | | +| | "remote::pgvector" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ +| cerebras | { | Use Cerebras for running LLM inference | +| | "inference": [ | | +| | "remote::cerebras" | | +| | ], | | +| | "safety": [ | | +| | "inline::llama-guard" | | +| | ], | | +| | "memory": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "agents": [ | | +| | "inline::meta-reference" | | +| | ], | | +| | "telemetry": [ | | +| | "inline::meta-reference" | | +| | ] | | +| | } | | ++------------------------------+----------------------------------------+-----------------------------------------------------------------------------+ ``` You may then pick a template to build your distribution with providers fitted to your liking. diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md new file mode 100644 index 000000000..08b35809a --- /dev/null +++ b/docs/source/distributions/self_hosted_distro/cerebras.md @@ -0,0 +1,61 @@ +# Cerebras Distribution + +The `llamastack/distribution-cerebras` distribution consists of the following provider configurations. + +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| inference | `remote::cerebras` | +| memory | `inline::meta-reference` | +| safety | `inline::llama-guard` | +| telemetry | `inline::meta-reference` | + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``) + +### Models + +The following models are available by default: + +- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)` +- `meta-llama/Llama-3.1-70B-Instruct (llama3.1-70b)` + + +### Prerequisite: API Keys + +Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). + + +## Running Llama Stack with Cerebras + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. 
+ +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-cerebras \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template cerebras --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` diff --git a/docs/source/index.md b/docs/source/index.md index 291237843..abfaf51b4 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -45,6 +45,7 @@ Llama Stack already has a number of "adapters" available for some popular Infere | **API Provider** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | | Meta Reference | Single Node | Y | Y | Y | Y | Y | +| Cerebras | Single Node | | Y | | | | | Fireworks | Hosted | Y | Y | Y | | | | AWS Bedrock | Hosted | | Y | | Y | | | Together | Hosted | Y | Y | | Y | | diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index c8d061f6c..13d463ad8 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.sample.SampleConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="cerebras", + pip_packages=[ + "cerebras_cloud_sdk", + ], + module="llama_stack.providers.remote.inference.cerebras", + config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/cerebras/__init__.py b/llama_stack/providers/remote/inference/cerebras/__init__.py new file mode 100644 index 000000000..a24bb2c70 --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import CerebrasImplConfig + + +async def get_adapter_impl(config: CerebrasImplConfig, _deps): + from .cerebras import CerebrasInferenceAdapter + + assert isinstance( + config, CerebrasImplConfig + ), f"Unexpected config type: {type(config)}" + + impl = CerebrasInferenceAdapter(config) + + await impl.initialize() + + return impl diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py new file mode 100644 index 000000000..65022f85e --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import AsyncGenerator + +from cerebras.cloud.sdk import AsyncCerebras + +from llama_models.llama3.api.chat_format import ChatFormat + +from llama_models.llama3.api.datatypes import Message +from llama_models.llama3.api.tokenizer import Tokenizer + +from llama_stack.apis.inference import * # noqa: F403 + +from llama_models.datatypes import CoreModelId + +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, + process_chat_completion_response, + process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, + completion_request_to_prompt, +) + +from .config import CerebrasImplConfig + + +model_aliases = [ + build_model_alias( + "llama3.1-8b", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3.1-70b", + CoreModelId.llama3_1_70b_instruct.value, + ), +] + + +class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): + def __init__(self, config: CerebrasImplConfig) -> None: + ModelRegistryHelper.__init__( + self, + model_aliases=model_aliases, + ) + self.config = config + self.formatter = ChatFormat(Tokenizer.get_instance()) + + self.client = AsyncCerebras( + base_url=self.config.base_url, api_key=self.config.api_key + ) + + async def initialize(self) -> None: + return + + async def shutdown(self) -> None: + pass + + async def completion( + self, + model_id: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion( + request, + ) + else: + return await self._nonstream_completion(request) + + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = self._get_params(request) + + r = await self.client.completions.create(**params) + + return process_completion_response(r, self.formatter) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = self._get_params(request) + + stream = await self.client.completions.create(**params) + + async for chunk in process_completion_stream_response(stream, self.formatter): + yield chunk + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + 
response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + params = self._get_params(request) + + r = await self.client.completions.create(**params) + + return process_chat_completion_response(r, self.formatter) + + async def _stream_chat_completion( + self, request: CompletionRequest + ) -> AsyncGenerator: + params = self._get_params(request) + + stream = await self.client.completions.create(**params) + + async for chunk in process_chat_completion_stream_response( + stream, self.formatter + ): + yield chunk + + def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: + if request.sampling_params and request.sampling_params.top_k: + raise ValueError("`top_k` not supported by Cerebras") + + prompt = "" + if type(request) == ChatCompletionRequest: + prompt = chat_completion_request_to_prompt( + request, self.get_llama_model(request.model), self.formatter + ) + elif type(request) == CompletionRequest: + prompt = completion_request_to_prompt(request, self.formatter) + else: + raise ValueError(f"Unknown request type {type(request)}") + + return { + "model": request.model, + "prompt": prompt, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + } + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/config.py b/llama_stack/providers/remote/inference/cerebras/config.py new file mode 100644 index 000000000..9bae6ca4d --- /dev/null +++ b/llama_stack/providers/remote/inference/cerebras/config.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +from typing import Any, Dict, Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + +DEFAULT_BASE_URL = "https://api.cerebras.ai" + + +@json_schema_type +class CerebrasImplConfig(BaseModel): + base_url: str = Field( + default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), + description="Base URL for the Cerebras API", + ) + api_key: Optional[str] = Field( + default=os.environ.get("CEREBRAS_API_KEY"), + description="Cerebras API Key", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "base_url": DEFAULT_BASE_URL, + "api_key": "${env.CEREBRAS_API_KEY}", + } diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index a427eef12..21e122149 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import ( ) from llama_stack.providers.remote.inference.bedrock import BedrockConfig +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig @@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_cerebras() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="cerebras", + provider_type="remote::cerebras", + config=CerebrasImplConfig( + api_key=get_env_or_fail("CEREBRAS_API_KEY"), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( @@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [ "vllm_remote", "remote", "bedrock", + "cerebras", "nvidia", "tgi", ] diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 9e5c67375..aa2f0b413 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -94,6 +94,7 @@ class TestInference: "remote::tgi", "remote::together", "remote::fireworks", + "remote::cerebras", ): pytest.skip("Other inference providers don't support completion() yet") @@ -139,6 +140,7 @@ class TestInference: "remote::tgi", "remote::together", "remote::fireworks", + "remote::cerebras", ): pytest.skip( "Other inference providers don't support structured output in completions yet" diff --git a/llama_stack/templates/cerebras/__init__.py b/llama_stack/templates/cerebras/__init__.py new file mode 100644 index 000000000..9f9929b52 --- /dev/null +++ b/llama_stack/templates/cerebras/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .cerebras import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml new file mode 100644 index 000000000..a1fe93099 --- /dev/null +++ b/llama_stack/templates/cerebras/build.yaml @@ -0,0 +1,17 @@ +version: '2' +name: cerebras +distribution_spec: + description: Use Cerebras for running LLM inference + docker_image: null + providers: + inference: + - remote::cerebras + safety: + - inline::llama-guard + memory: + - inline::meta-reference + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference +image_type: conda diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py new file mode 100644 index 000000000..58e05adf8 --- /dev/null +++ b/llama_stack/templates/cerebras/cerebras.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_models.sku_list import all_registered_models + +from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput +from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig +from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases + +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::cerebras"], + "safety": ["inline::llama-guard"], + "memory": ["inline::meta-reference"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + } + + inference_provider = Provider( + provider_id="cerebras", + provider_type="remote::cerebras", + config=CerebrasImplConfig.sample_run_config(), + ) + + core_model_to_hf_repo = { + m.descriptor(): m.huggingface_repo for m in all_registered_models() + } + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model], + provider_model_id=m.provider_model_id, + ) + for m in model_aliases + ] + + return DistributionTemplate( + name="cerebras", + distro_type="self_hosted", + description="Use Cerebras for running LLM inference", + docker_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + default_models=default_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=default_models, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "CEREBRAS_API_KEY": ( + "", + "Cerebras API Key", + ), + }, + ) diff --git a/llama_stack/templates/cerebras/doc_template.md b/llama_stack/templates/cerebras/doc_template.md new file mode 100644 index 000000000..77fc6f478 --- /dev/null +++ b/llama_stack/templates/cerebras/doc_template.md @@ -0,0 +1,60 @@ +# Cerebras Distribution + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. 
+ +{{ providers_table }} + +{% if run_config_env_vars %} +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} ({{ model.provider_model_id }})` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). + + +## Running Llama Stack with Cerebras + +You can do this via Conda (build code) or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` + +### Via Conda + +```bash +llama stack build --template cerebras --image-type conda +llama stack run ./run.yaml \ + --port 5001 \ + --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY +``` diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml new file mode 100644 index 000000000..0b41f5b76 --- /dev/null +++ b/llama_stack/templates/cerebras/run.yaml @@ -0,0 +1,63 @@ +version: '2' +image_name: cerebras +docker_image: null +conda_env: cerebras +apis: +- agents +- inference +- memory +- safety +- telemetry +providers: + inference: + - provider_id: cerebras + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + memory: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} +metadata_store: + namespace: null + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db +models: +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: null + provider_model_id: llama3.1-8b +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: null + provider_model_id: llama3.1-70b +shields: +- params: null + shield_id: meta-llama/Llama-Guard-3-8B + provider_id: null + provider_shield_id: null +memory_banks: [] +datasets: [] +scoring_fns: [] +eval_tasks: []