forked from phoenix-oss/llama-stack-mirror
		
	Cerebras Inference Integration (#265)
Adding Cerebras Inference as an API provider. ## Testing ### Conda ``` $ llama stack build --template cerebras --image-type conda $ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml ... Listening on ['::', '0.0.0.0']:5000 INFO: Started server process [12443] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit) ``` ### Chat Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "messages": [ { "role": "user", "content": "What is the temperature in Seattle right now?" } ], "stream": false, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 100 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "completion_message": { "role": "assistant", "content": "", "stop_reason": "end_of_message", "tool_calls": [ { "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678", "tool_name": "getTemperature", "arguments": { "location": "Seattle" } } ] }, "logprobs": null } ``` #### Streaming Response ``` data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}} data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}} ``` ### Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "content": "1,2,3,", "stream": true, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 10 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "content": "4,5,6,7,8,", "stop_reason": "out_of_tokens", "logprobs": null } ``` #### Streaming Response ``` data: {"delta":"4","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"5","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"6","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"7","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"8","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null} ``` ### Pre-Commit Checks ``` trim trailing whitespace.................................................Passed check python ast.........................................................Passed check for merge conflicts................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed flake8...................................................................Passed Format files with µfmt...................................................Passed ``` ### Testing with `test_inference.py` ``` $ export CEREBRAS_API_KEY=<insert API key here> $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b" /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =================================================== test session starts =================================================== platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12 cachedir: .pytest_cache rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack configfile: pyproject.toml plugins: anyio-4.6.2.post1, asyncio-0.24.0 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 128 items / 120 deselected / 8 selected llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers inner-inference => cerebras models => __routing_table__ inference => __autorouted__ inspect => __builtin__ Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED ================================ 6 passed, 2 skipped, 120 deselected, 6 warnings in 3.95s ================================= ``` I ran `python llama_stack/scripts/distro_codegen.py` to run codegen.
This commit is contained in:
		
							parent
							
								
									b6500974ec
								
							
						
					
					
						commit
						64c6df8392
					
				
					 19 changed files with 1018 additions and 292 deletions
				
			
		
							
								
								
									
										7
									
								
								llama_stack/templates/cerebras/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								llama_stack/templates/cerebras/__init__.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,7 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from .cerebras import get_distribution_template  # noqa: F401 | ||||
							
								
								
									
										17
									
								
								llama_stack/templates/cerebras/build.yaml
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								llama_stack/templates/cerebras/build.yaml
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,17 @@ | |||
| version: '2' | ||||
| name: cerebras | ||||
| distribution_spec: | ||||
|   description: Use Cerebras for running LLM inference | ||||
|   docker_image: null | ||||
|   providers: | ||||
|     inference: | ||||
|     - remote::cerebras | ||||
|     safety: | ||||
|     - inline::llama-guard | ||||
|     memory: | ||||
|     - inline::meta-reference | ||||
|     agents: | ||||
|     - inline::meta-reference | ||||
|     telemetry: | ||||
|     - inline::meta-reference | ||||
| image_type: conda | ||||
							
								
								
									
										71
									
								
								llama_stack/templates/cerebras/cerebras.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								llama_stack/templates/cerebras/cerebras.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,71 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from llama_models.sku_list import all_registered_models | ||||
| 
 | ||||
| from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput | ||||
| from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig | ||||
| from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases | ||||
| 
 | ||||
| from llama_stack.templates.template import DistributionTemplate, RunConfigSettings | ||||
| 
 | ||||
| 
 | ||||
| def get_distribution_template() -> DistributionTemplate: | ||||
|     providers = { | ||||
|         "inference": ["remote::cerebras"], | ||||
|         "safety": ["inline::llama-guard"], | ||||
|         "memory": ["inline::meta-reference"], | ||||
|         "agents": ["inline::meta-reference"], | ||||
|         "telemetry": ["inline::meta-reference"], | ||||
|     } | ||||
| 
 | ||||
|     inference_provider = Provider( | ||||
|         provider_id="cerebras", | ||||
|         provider_type="remote::cerebras", | ||||
|         config=CerebrasImplConfig.sample_run_config(), | ||||
|     ) | ||||
| 
 | ||||
|     core_model_to_hf_repo = { | ||||
|         m.descriptor(): m.huggingface_repo for m in all_registered_models() | ||||
|     } | ||||
|     default_models = [ | ||||
|         ModelInput( | ||||
|             model_id=core_model_to_hf_repo[m.llama_model], | ||||
|             provider_model_id=m.provider_model_id, | ||||
|         ) | ||||
|         for m in model_aliases | ||||
|     ] | ||||
| 
 | ||||
|     return DistributionTemplate( | ||||
|         name="cerebras", | ||||
|         distro_type="self_hosted", | ||||
|         description="Use Cerebras for running LLM inference", | ||||
|         docker_image=None, | ||||
|         template_path=Path(__file__).parent / "doc_template.md", | ||||
|         providers=providers, | ||||
|         default_models=default_models, | ||||
|         run_configs={ | ||||
|             "run.yaml": RunConfigSettings( | ||||
|                 provider_overrides={ | ||||
|                     "inference": [inference_provider], | ||||
|                 }, | ||||
|                 default_models=default_models, | ||||
|                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], | ||||
|             ), | ||||
|         }, | ||||
|         run_config_env_vars={ | ||||
|             "LLAMASTACK_PORT": ( | ||||
|                 "5001", | ||||
|                 "Port for the Llama Stack distribution server", | ||||
|             ), | ||||
|             "CEREBRAS_API_KEY": ( | ||||
|                 "", | ||||
|                 "Cerebras API Key", | ||||
|             ), | ||||
|         }, | ||||
|     ) | ||||
							
								
								
									
										60
									
								
								llama_stack/templates/cerebras/doc_template.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								llama_stack/templates/cerebras/doc_template.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,60 @@ | |||
| # Cerebras Distribution | ||||
| 
 | ||||
| The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. | ||||
| 
 | ||||
| {{ providers_table }} | ||||
| 
 | ||||
| {% if run_config_env_vars %} | ||||
| ### Environment Variables | ||||
| 
 | ||||
| The following environment variables can be configured: | ||||
| 
 | ||||
| {% for var, (default_value, description) in run_config_env_vars.items() %} | ||||
| - `{{ var }}`: {{ description }} (default: `{{ default_value }}`) | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| 
 | ||||
| {% if default_models %} | ||||
| ### Models | ||||
| 
 | ||||
| The following models are available by default: | ||||
| 
 | ||||
| {% for model in default_models %} | ||||
| - `{{ model.model_id }} ({{ model.provider_model_id }})` | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| 
 | ||||
| 
 | ||||
| ### Prerequisite: API Keys | ||||
| 
 | ||||
| Make sure you have access to a Cerebras API Key. You can get one by visiting [cloud.cerebras.ai](https://cloud.cerebras.ai/). | ||||
| 
 | ||||
| 
 | ||||
| ## Running Llama Stack with Cerebras | ||||
| 
 | ||||
| You can do this via Conda (build code) or Docker which has a pre-built image. | ||||
| 
 | ||||
| ### Via Docker | ||||
| 
 | ||||
| This method allows you to get started quickly without having to build the distribution code. | ||||
| 
 | ||||
| ```bash | ||||
| LLAMA_STACK_PORT=5001 | ||||
| docker run \ | ||||
|   -it \ | ||||
|   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ | ||||
|   -v ./run.yaml:/root/my-run.yaml \ | ||||
|   llamastack/distribution-{{ name }} \ | ||||
|   --yaml-config /root/my-run.yaml \ | ||||
|   --port $LLAMA_STACK_PORT \ | ||||
|   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY | ||||
| ``` | ||||
| 
 | ||||
| ### Via Conda | ||||
| 
 | ||||
| ```bash | ||||
| llama stack build --template cerebras --image-type conda | ||||
| llama stack run ./run.yaml \ | ||||
|   --port 5001 \ | ||||
|   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY | ||||
| ``` | ||||
							
								
								
									
										63
									
								
								llama_stack/templates/cerebras/run.yaml
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								llama_stack/templates/cerebras/run.yaml
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,63 @@ | |||
| version: '2' | ||||
| image_name: cerebras | ||||
| docker_image: null | ||||
| conda_env: cerebras | ||||
| apis: | ||||
| - agents | ||||
| - inference | ||||
| - memory | ||||
| - safety | ||||
| - telemetry | ||||
| providers: | ||||
|   inference: | ||||
|   - provider_id: cerebras | ||||
|     provider_type: remote::cerebras | ||||
|     config: | ||||
|       base_url: https://api.cerebras.ai | ||||
|       api_key: ${env.CEREBRAS_API_KEY} | ||||
|   safety: | ||||
|   - provider_id: llama-guard | ||||
|     provider_type: inline::llama-guard | ||||
|     config: {} | ||||
|   memory: | ||||
|   - provider_id: meta-reference | ||||
|     provider_type: inline::meta-reference | ||||
|     config: | ||||
|       kvstore: | ||||
|         type: sqlite | ||||
|         namespace: null | ||||
|         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/faiss_store.db | ||||
|   agents: | ||||
|   - provider_id: meta-reference | ||||
|     provider_type: inline::meta-reference | ||||
|     config: | ||||
|       persistence_store: | ||||
|         type: sqlite | ||||
|         namespace: null | ||||
|         db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db | ||||
|   telemetry: | ||||
|   - provider_id: meta-reference | ||||
|     provider_type: inline::meta-reference | ||||
|     config: {} | ||||
| metadata_store: | ||||
|   namespace: null | ||||
|   type: sqlite | ||||
|   db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/registry.db | ||||
| models: | ||||
| - metadata: {} | ||||
|   model_id: meta-llama/Llama-3.1-8B-Instruct | ||||
|   provider_id: null | ||||
|   provider_model_id: llama3.1-8b | ||||
| - metadata: {} | ||||
|   model_id: meta-llama/Llama-3.1-70B-Instruct | ||||
|   provider_id: null | ||||
|   provider_model_id: llama3.1-70b | ||||
| shields: | ||||
| - params: null | ||||
|   shield_id: meta-llama/Llama-Guard-3-8B | ||||
|   provider_id: null | ||||
|   provider_shield_id: null | ||||
| memory_banks: [] | ||||
| datasets: [] | ||||
| scoring_fns: [] | ||||
| eval_tasks: [] | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue