llama-stack-mirror/llama_stack/providers/remote/inference/cerebras/cerebras.py
Matthew Farrellee ce7a3b4dff
feat: update Cerebras inference provider to support dynamic model listing (#3481)
# What does this PR do?

- update Cerebras to use OpenAIMixin
- enable openai completions tests
- enable openai chat completions tests
- disable with n > 1 tests
- add recording for --setup cerebras --subdirs inference --pattern
openai


## Test Plan

`./scripts/integration-tests.sh --stack-config server:ci-tests --setup
cerebras --subdirs inference --pattern openai`

```
tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming[txt=cerebras/llama-3.3-70b-inference:completion:sanity] 
instantiating llama_stack_client
Port 8321 is already in use, assuming server is already running...
llama_stack_client instantiated in 0.053s
PASSED                                                                                            [  2%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming_suffix[txt=cerebras/llama-3.3-70b-inference:completion:suffix] SKIPPED (Suffix is not supported for the model: cerebras/llama-3.3-70b.)                   [  4%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_streaming[txt=cerebras/llama-3.3-70b-inference:completion:sanity] PASSED                                                                                                [  6%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_prompt_logprobs[txt=cerebras/llama-3.3-70b-1] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support vllm extra_body parameters.)             [  8%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_guided_choice[txt=cerebras/llama-3.3-70b] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support vllm extra_body parameters.)                 [ 10%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:non_streaming_01] PASSED                                                          [ 12%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_01] PASSED                                                                  [ 14%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_01] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote::cere...) [ 17%]
tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=cerebras/llama-3.3-70b-True] PASSED                                                                                                                     [ 19%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=cerebras/llama-3.3-70b-True] PASSED                                                                                                          [ 21%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming_with_file[txt=cerebras/llama-3.3-70b] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support chat completion calls wit...) [ 23%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                               [ 25%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                            [ 27%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_float[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                  [ 29%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_dimensions[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                             [ 31%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                         [ 34%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_empty_list_error[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                            [ 36%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                         [ 38%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                          [ 40%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_base64[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                 [ 42%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[openai_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                     [ 44%]
tests/integration/inference/test_openai_completion.py::test_openai_completion_prompt_logprobs[txt=cerebras/llama-3.3-70b-0] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support vllm extra_body parameters.)             [ 46%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:non_streaming_02] PASSED                                                          [ 48%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_02] PASSED                                                                  [ 51%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_02] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote::cere...) [ 53%]
tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=cerebras/llama-3.3-70b-False] PASSED                                                                                                                    [ 55%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=cerebras/llama-3.3-70b-False] PASSED                                                                                                         [ 57%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_single_string[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                          [ 59%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_multiple_strings[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                       [ 61%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_float[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                             [ 63%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_dimensions[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                        [ 65%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_user_parameter[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                    [ 68%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_empty_list_error[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                       [ 70%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_invalid_model_error[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                    [ 72%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_different_inputs_different_outputs[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                     [ 74%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_with_encoding_format_base64[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                            [ 76%]
tests/integration/inference/test_openai_embeddings.py::test_openai_embeddings_base64_batch_processing[llama_stack_client-cerebras/llama-3.3-70b-None-None-None-384] SKIPPED (embedding_model_id empty - skipping test)                                [ 78%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:non_streaming_01] PASSED                                                     [ 80%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_01] PASSED                                                             [ 82%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_01] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote:...) [ 85%]
tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=cerebras/llama-3.3-70b-True] PASSED                                                                                                                [ 87%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=cerebras/llama-3.3-70b-True] PASSED                                                                                                     [ 89%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:non_streaming_02] PASSED                                                     [ 91%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_02] PASSED                                                             [ 93%]
tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming_with_n[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_02] SKIPPED (Model cerebras/llama-3.3-70b hosted by remote:...) [ 95%]
tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=cerebras/llama-3.3-70b-False] PASSED                                                                                                               [ 97%]
tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=cerebras/llama-3.3-70b-False] PASSED                                                                                                    [100%]

=================================================================================================================== slowest 10 durations ====================================================================================================================
0.37s call     tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[openai_client-txt=cerebras/llama-3.3-70b-inference:chat_completion:non_streaming_01]
0.34s call     tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=cerebras/llama-3.3-70b-False]
0.18s call     tests/integration/inference/test_openai_completion.py::test_inference_store[client_with_models-txt=cerebras/llama-3.3-70b-True]
0.17s setup    tests/integration/inference/test_openai_completion.py::test_openai_completion_non_streaming[txt=cerebras/llama-3.3-70b-inference:completion:sanity]
0.15s call     tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=cerebras/llama-3.3-70b-True]
0.13s call     tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=cerebras/llama-3.3-70b-True]
0.12s call     tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=cerebras/llama-3.3-70b-False]
0.12s call     tests/integration/inference/test_openai_completion.py::test_inference_store[openai_client-txt=cerebras/llama-3.3-70b-True]
0.12s call     tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=cerebras/llama-3.3-70b-False]
0.08s call     tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=cerebras/llama-3.3-70b-inference:chat_completion:streaming_02]
================================================================================================================== short test summary info ==================================================================================================================
SKIPPED [1] tests/integration/inference/test_openai_completion.py:75: Suffix is not supported for the model: cerebras/llama-3.3-70b.
SKIPPED [3] tests/integration/inference/test_openai_completion.py:123: Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support vllm extra_body parameters.
SKIPPED [4] tests/integration/inference/test_openai_completion.py:103: Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support n param.
SKIPPED [1] tests/integration/inference/test_openai_completion.py:129: Model cerebras/llama-3.3-70b hosted by remote::cerebras doesn't support chat completion calls with base64 encoded files.
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:90: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:112: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:136: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:154: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:175: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:195: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:206: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:217: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:244: embedding_model_id empty - skipping test
SKIPPED [2] tests/integration/inference/test_openai_embeddings.py:278: embedding_model_id empty - skipping test
================================================================================================= 18 passed, 29 skipped, 50 deselected, 4 warnings in 3.02s =================================================================================================
```
2025-09-23 16:26:00 -04:00

213 lines
6.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
Inference,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_entries=MODEL_ENTRIES,
)
self.config = config
# TODO: make this use provider data, etc. like other providers
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(
request,
)
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._cerebras_client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._cerebras_client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()