mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Adding Cerebras Inference as an API provider. ## Testing ### Conda ``` $ llama stack build --template cerebras --image-type conda $ llama stack run ~/.llama/distributions/llamastack-cerebras/cerebras-run.yaml ... Listening on ['::', '0.0.0.0']:5000 INFO: Started server process [12443] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit) ``` ### Chat Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/chat-completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "messages": [ { "role": "user", "content": "What is the temperature in Seattle right now?" } ], "stream": false, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 100 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "completion_message": { "role": "assistant", "content": "", "stop_reason": "end_of_message", "tool_calls": [ { "call_id": "6f42fdcc-6cbb-46ad-a17b-5d20ac64b678", "tool_name": "getTemperature", "arguments": { "location": "Seattle" } } ] }, "logprobs": null } ``` #### Streaming Response ``` data: {"event":{"event_type":"start","delta":"","logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"","parse_status":"started"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"{\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"type","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: 
{"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"function","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"name","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"get","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Temperature","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\",","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"parameters","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" {\"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} 
data: {"event":{"event_type":"progress","delta":{"content":"location","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\":","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":" \"","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"Seattle","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":"\"}}","parse_status":"in_progress"},"logprobs":null,"stop_reason":null}} data: {"event":{"event_type":"progress","delta":{"content":{"call_id":"e742df1f-0ae9-40ad-a49e-18e5c905484f","tool_name":"getTemperature","arguments":{"location":"Seattle"}},"parse_status":"success"},"logprobs":null,"stop_reason":"end_of_message"}} data: {"event":{"event_type":"complete","delta":"","logprobs":null,"stop_reason":"end_of_message"}} ``` ### Completion ``` $ curl --location 'http://localhost:5000/alpha/inference/completion' --header 'Content-Type: application/json' --data '{ "model_id": "meta-llama/Llama-3.1-8B-Instruct", "content": "1,2,3,", "stream": true, "sampling_params": { "strategy": "top_p", "temperature": 0.5, "max_tokens": 10 }, "tool_choice": "auto", "tool_prompt_format": "json", "tools": [ { "tool_name": "getTemperature", "description": "Gets the current temperature of a location.", "parameters": { "location": { "param_type": "string", "description": "The name of the place to get the temperature from in degress celsius.", "required": true } } } ] }' ``` #### Non-Streaming Response ``` { "content": "4,5,6,7,8,", "stop_reason": "out_of_tokens", "logprobs": null } ``` #### Streaming Response ``` data: {"delta":"4","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"5","stop_reason":null,"logprobs":null} data: 
{"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"6","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"7","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"8","stop_reason":null,"logprobs":null} data: {"delta":",","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":null,"logprobs":null} data: {"delta":"","stop_reason":"out_of_tokens","logprobs":null} ``` ### Pre-Commit Checks ``` trim trailing whitespace.................................................Passed check python ast.........................................................Passed check for merge conflicts................................................Passed check for added large files..............................................Passed fix end of files.........................................................Passed Insert license in comments...............................................Passed flake8...................................................................Passed Format files with µfmt...................................................Passed ``` ### Testing with `test_inference.py` ``` $ export CEREBRAS_API_KEY=<insert API key here> $ pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py -m "cerebras and llama_8b" /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/lib/python3.12/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. 
Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =================================================== test session starts =================================================== platform linux -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0 -- /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack/.venv/bin/python3.12 cachedir: .pytest_cache rootdir: /net/henryt-dev/srv/nfs/henryt-data/ws/llama-stack configfile: pyproject.toml plugins: anyio-4.6.2.post1, asyncio-0.24.0 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 128 items / 120 deselected / 8 selected llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[llama_8b-cerebras] Resolved 4 providers inner-inference => cerebras models => __routing_table__ inference => __autorouted__ inspect => __builtin__ Models: meta-llama/Llama-3.1-8B-Instruct served by cerebras PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completions_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[llama_8b-cerebras] SKIPPED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[llama_8b-cerebras] PASSED llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[llama_8b-cerebras] PASSED ================================ 6 passed, 2 skipped, 120 deselected, 
6 warnings in 3.95s ================================= ``` I ran `python llama_stack/scripts/distro_codegen.py` to run codegen.
242 lines
6.9 KiB
Python
242 lines
6.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.models import ModelInput
|
|
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
from llama_stack.providers.inline.inference.meta_reference import (
|
|
MetaReferenceInferenceConfig,
|
|
)
|
|
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
|
|
|
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
|
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
|
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
|
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
|
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
|
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
|
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
from ..env import get_env_or_fail
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_model(request):
    """Provide the model under test for the session.

    A value supplied through fixture parametrization (``request.param``)
    takes precedence. Without one, the ``--inference-model`` command-line
    option is read, with ``None`` passed as the default.
    """
    if not hasattr(request, "param"):
        return request.config.getoption("--inference-model", None)
    return request.param
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
    """Return the result of ``remote_stack_fixture()`` unchanged."""
    remote_fixture = remote_stack_fixture()
    return remote_fixture
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_meta_reference(inference_model) -> ProviderFixture:
    """Build one ``inline::meta-reference`` provider per requested model.

    ``inference_model`` may be a single model name or an iterable of names;
    a lone string is wrapped in a list first. Providers are numbered
    ``meta-reference-0``, ``meta-reference-1``, ... in iteration order, and
    each config reads its checkpoint directory from the
    ``MODEL_CHECKPOINT_DIR`` environment variable (``None`` if unset).
    """
    models = inference_model
    if isinstance(models, str):
        models = [models]

    providers = []
    for index, model_name in enumerate(models):
        config = MetaReferenceInferenceConfig(
            model=model_name,
            max_seq_len=4096,
            create_distributed_process_group=False,
            checkpoint_dir=os.getenv("MODEL_CHECKPOINT_DIR", None),
        )
        providers.append(
            Provider(
                provider_id=f"meta-reference-{index}",
                provider_type="inline::meta-reference",
                config=config.model_dump(),
            )
        )

    return ProviderFixture(providers=providers)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_cerebras() -> ProviderFixture:
    """Build a fixture holding a single ``remote::cerebras`` provider.

    The API key is obtained through ``get_env_or_fail("CEREBRAS_API_KEY")``.
    """
    cerebras_config = CerebrasImplConfig(
        api_key=get_env_or_fail("CEREBRAS_API_KEY"),
    )
    cerebras_provider = Provider(
        provider_id="cerebras",
        provider_type="remote::cerebras",
        config=cerebras_config.model_dump(),
    )
    return ProviderFixture(providers=[cerebras_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
    """Build a fixture holding a single ``remote::ollama`` provider.

    The provider points at ``localhost`` on the port given by the
    ``OLLAMA_PORT`` environment variable (default ``11434``). The fixture is
    skipped when ``Llama3.1-8B-Instruct`` is among the requested models.
    """
    models = inference_model
    if isinstance(models, str):
        models = [models]

    if "Llama3.1-8B-Instruct" in models:
        pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

    ollama_config = OllamaImplConfig(
        host="localhost",
        port=os.getenv("OLLAMA_PORT", 11434),
    )
    ollama_provider = Provider(
        provider_id="ollama",
        provider_type="remote::ollama",
        config=ollama_config.model_dump(),
    )
    return ProviderFixture(providers=[ollama_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_vllm_remote() -> ProviderFixture:
    """Build a fixture holding a single ``remote::vllm`` provider.

    The server URL is obtained through ``get_env_or_fail("VLLM_URL")``.
    """
    vllm_config = VLLMInferenceAdapterConfig(
        url=get_env_or_fail("VLLM_URL"),
    )
    vllm_provider = Provider(
        provider_id="remote::vllm",
        provider_type="remote::vllm",
        config=vllm_config.model_dump(),
    )
    return ProviderFixture(providers=[vllm_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_fireworks() -> ProviderFixture:
    """Build a fixture holding a single ``remote::fireworks`` provider.

    The API key is obtained through ``get_env_or_fail("FIREWORKS_API_KEY")``.
    """
    fireworks_config = FireworksImplConfig(
        api_key=get_env_or_fail("FIREWORKS_API_KEY"),
    )
    fireworks_provider = Provider(
        provider_id="fireworks",
        provider_type="remote::fireworks",
        config=fireworks_config.model_dump(),
    )
    return ProviderFixture(providers=[fireworks_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_together() -> ProviderFixture:
    """Build a fixture holding a single ``remote::together`` provider.

    The provider config uses ``TogetherImplConfig`` defaults; the API key is
    passed separately as ``provider_data["together_api_key"]``, obtained
    through ``get_env_or_fail("TOGETHER_API_KEY")``.
    """
    together_provider = Provider(
        provider_id="together",
        provider_type="remote::together",
        config=TogetherImplConfig().model_dump(),
    )
    # Looked up after the provider is built, matching the original
    # evaluation order.
    api_key_data = dict(
        together_api_key=get_env_or_fail("TOGETHER_API_KEY"),
    )
    return ProviderFixture(
        providers=[together_provider],
        provider_data=api_key_data,
    )
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_bedrock() -> ProviderFixture:
    """Build a fixture holding a single ``remote::bedrock`` provider.

    The provider config is a default-constructed ``BedrockConfig``.
    """
    bedrock_provider = Provider(
        provider_id="bedrock",
        provider_type="remote::bedrock",
        config=BedrockConfig().model_dump(),
    )
    return ProviderFixture(providers=[bedrock_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_nvidia() -> ProviderFixture:
    """Build a fixture holding a single ``remote::nvidia`` provider.

    The provider config is a default-constructed ``NVIDIAConfig``.
    """
    nvidia_provider = Provider(
        provider_id="nvidia",
        provider_type="remote::nvidia",
        config=NVIDIAConfig().model_dump(),
    )
    return ProviderFixture(providers=[nvidia_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def inference_tgi() -> ProviderFixture:
    """Build a fixture holding a single ``remote::tgi`` provider.

    The endpoint URL is obtained through ``get_env_or_fail("TGI_URL")``; the
    API token is read from ``TGI_API_TOKEN`` and is ``None`` when unset.
    """
    tgi_config = TGIImplConfig(
        url=get_env_or_fail("TGI_URL"),
        api_token=os.getenv("TGI_API_TOKEN", None),
    )
    tgi_provider = Provider(
        provider_id="tgi",
        provider_type="remote::tgi",
        config=tgi_config.model_dump(),
    )
    return ProviderFixture(providers=[tgi_provider])
|
|
|
|
|
|
def get_model_short_name(model_name: str) -> str:
    """Convert model name to a short test identifier.

    Matching is case-insensitive. A name containing "vision" maps to
    "llama_vision"; otherwise a "3b" or "8b" size tag maps to "llama_3b" or
    "llama_8b" (checked in that order). A size tag only counts when it is
    not preceded by a digit or a dot, so sizes such as "13b" or "1.8b" are
    not mistaken for "3b" / "8b". Any other name is returned lowercased with
    "." and "-" replaced by "_".

    Args:
        model_name: Full model name like "Llama3.1-8B-Instruct"

    Returns:
        Short name like "llama_8b" suitable for test markers
    """
    import re  # function-scope import keeps the helper self-contained

    lowered = model_name.lower()
    if "vision" in lowered:
        return "llama_vision"

    for size in ("3b", "8b"):
        # The lookbehind rejects matches that are the tail of a larger
        # number ("13b", "128b") or of a fractional size ("1.8b").
        if re.search(rf"(?<![\d.]){size}", lowered):
            return f"llama_{size}"

    return lowered.replace(".", "_").replace("-", "_")
|
|
|
|
|
|
@pytest.fixture(scope="session")
def model_id(inference_model) -> str:
    """Return the short identifier ``get_model_short_name`` derives for the model."""
    short_name = get_model_short_name(inference_model)
    return short_name
|
|
|
|
|
|
# Provider fixture names available for inference tests. Each entry matches
# an ``inference_<name>`` fixture defined in this module.
INFERENCE_FIXTURES = [
    "meta_reference",
    "ollama",
    "fireworks",
    "together",
    "vllm_remote",
    "remote",
    "bedrock",
    "cerebras",
    "nvidia",
    "tgi",
]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def inference_stack(request, inference_model):
    """Construct a test stack for the parametrized inference provider.

    ``request.param`` names the provider; the fixture called
    ``inference_<param>`` supplies the providers and provider data handed to
    ``construct_stack_for_test``, together with a single ``ModelInput`` for
    ``inference_model``.

    Returns:
        A two-tuple of the stack's ``impls`` entries for ``Api.inference``
        and ``Api.models``, in that order.
    """
    provider_fixture = request.getfixturevalue(f"inference_{request.param}")

    stack = await construct_stack_for_test(
        [Api.inference],
        {"inference": provider_fixture.providers},
        provider_fixture.provider_data,
        models=[ModelInput(model_id=inference_model)],
    )

    inference_impl = stack.impls[Api.inference]
    models_impl = stack.impls[Api.models]
    return inference_impl, models_impl
|