mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Adding Cerebras Inference as an API provider. ## Testing ### Conda ``` $ llama stack build --template cerebras --image-type conda $ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml ... Listening on ['::', '0.0.0.0']:5000 INFO: Started server process [12443] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit) ``` ### Chat Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "messages": [ { "role": "user", "content": "What is the temperature in Seattle right now?" } ], "stream": false, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 100 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "completion_message": { "role": "assistant", "content": "", "stop_reason": "end_of_message", "tool_calls": [ { "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678", "tool_name": "getTemperature", "arguments": { "location": "Seattle" } } ] }, "logprobs": null } ``` #### Streaming Response ``` data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: 
{"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} 
data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}} data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}} ``` ### Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "content": "1,2,3,", "stream": true, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 10 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "content": "4,5,6,7,8,", "stop_reason": "out_of_tokens", "logprobs": null } ``` #### Streaming Response ``` data: {"delta":"4","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"5","stop_reason":null,"logprobs":null} data: 
{"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"6","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"7","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"8","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null} ``` ### Pre-Commit Checks ``` trim trailing whitespace.................................................Passed check python ast.........................................................Passed check for merge conflicts................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed flake8...................................................................Passed Format files with µfmt...................................................Passed ``` ### Testing with `test_inference.py` ``` $ export CEREBRAS_API_KEY=<insert API key here> $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b" /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =================================================== test session starts =================================================== platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12 cachedir: .pytest_cache rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack configfile: pyproject.toml plugins: anyio-4.6.2.post1, asyncio-0.24.0 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 128 items / 120 deselected / 8 selected llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers inner-inference => cerebras models => __routing_table__ inference => __autorouted__ inspect => __builtin__ Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED ================================ 6 passed, 2 skipped, 120 deselected, 
6 warnings in 3.95s ================================= ``` I also ran the distribution codegen script (`python llama_stack/scripts/distro_codegen.py`).
175 lines
6.5 KiB
Python
175 lines
6.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import List
|
|
|
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
|
|
|
|
# pip packages installed for the "inline::meta-reference" inference provider.
# The "inline::meta-reference-quantized" provider below reuses this list and
# appends its own extra packages on top of it.
META_REFERENCE_DEPS = list(
    (
        "accelerate",
        "blobfile",
        "fairscale",
        "torch",
        "torchvision",
        "transformers",
        "zmq",
        "lm-format-enforcer",
    )
)
|
|
|
|
|
|
def available_providers() -> List[ProviderSpec]:
    """List every provider spec registered for ``Api.inference``.

    The inline entries (meta-reference, its quantized variant, and vLLM) come
    first, followed by the remote adapters.  Each spec names the pip packages
    it needs plus the dotted paths of its implementation module and config
    class (and, for some adapters, a provider-data validator).

    Returns:
        A new list of ``ProviderSpec`` objects, in registry order.
    """

    def _remote(**adapter_fields):
        # Every remote entry targets the inference API; only the AdapterSpec
        # fields differ, so they are forwarded verbatim (no extra defaults).
        return remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(**adapter_fields),
        )

    inline_specs = [
        InlineProviderSpec(
            api=Api.inference,
            provider_type="inline::meta-reference",
            pip_packages=META_REFERENCE_DEPS,
            module="llama_stack.providers.inline.inference.meta_reference",
            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_type="inline::meta-reference-quantized",
            # Same base dependencies as meta-reference, plus quantization extras.
            pip_packages=[*META_REFERENCE_DEPS, "fbgemm-gpu", "torchao==0.5.0"],
            module="llama_stack.providers.inline.inference.meta_reference",
            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_type="inline::vllm",
            pip_packages=["vllm"],
            module="llama_stack.providers.inline.inference.vllm",
            config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
        ),
    ]

    remote_specs = [
        _remote(
            adapter_type="sample",
            pip_packages=[],
            module="llama_stack.providers.remote.inference.sample",
            config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
        ),
        _remote(
            adapter_type="cerebras",
            pip_packages=["cerebras_cloud_sdk"],
            module="llama_stack.providers.remote.inference.cerebras",
            config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
        ),
        _remote(
            adapter_type="ollama",
            pip_packages=["ollama", "aiohttp"],
            module="llama_stack.providers.remote.inference.ollama",
            config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
        ),
        _remote(
            adapter_type="vllm",
            pip_packages=["openai"],
            module="llama_stack.providers.remote.inference.vllm",
            config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
        ),
        _remote(
            adapter_type="tgi",
            pip_packages=["huggingface_hub", "aiohttp"],
            module="llama_stack.providers.remote.inference.tgi",
            config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
        ),
        _remote(
            adapter_type="hf::serverless",
            pip_packages=["huggingface_hub", "aiohttp"],
            module="llama_stack.providers.remote.inference.tgi",
            config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
        ),
        _remote(
            adapter_type="hf::endpoint",
            pip_packages=["huggingface_hub", "aiohttp"],
            module="llama_stack.providers.remote.inference.tgi",
            config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
        ),
        _remote(
            adapter_type="fireworks",
            pip_packages=["fireworks-ai"],
            module="llama_stack.providers.remote.inference.fireworks",
            config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
            provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
        ),
        _remote(
            adapter_type="together",
            pip_packages=["together"],
            module="llama_stack.providers.remote.inference.together",
            config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
            provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
        ),
        _remote(
            adapter_type="bedrock",
            pip_packages=["boto3"],
            module="llama_stack.providers.remote.inference.bedrock",
            config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
        ),
        _remote(
            adapter_type="databricks",
            pip_packages=["openai"],
            module="llama_stack.providers.remote.inference.databricks",
            config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
        ),
        _remote(
            adapter_type="nvidia",
            pip_packages=["openai"],
            module="llama_stack.providers.remote.inference.nvidia",
            config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
        ),
    ]

    return inline_specs + remote_specs
|