This commit is contained in:
Ashwin Bharambe 2025-10-27 11:54:43 -07:00
parent 9eb9a37ee4
commit 22981facb5
4 changed files with 10 additions and 14 deletions

View file

@ -6,7 +6,10 @@
from urllib.parse import urljoin
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.apis.inference import (
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import CerebrasImplConfig
@ -18,17 +21,15 @@ class CerebrasInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "cerebras_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
if self.config.auth_credential is None:
raise ValueError("Cerebras API key is required")
return self.config.auth_credential.get_secret_value()
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -28,10 +28,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]: