Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-11 19:56:03 +00:00)

Commit 22981facb5 ("fixes"), parent 9eb9a37ee4.
4 changed files with 10 additions and 14 deletions.
```diff
@@ -16,7 +16,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
```
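For orientation, here is a minimal sketch of the documented fields as a pydantic model. The class name, structure, and construction below are illustrative assumptions, not the actual llama-stack `NVIDIAConfig`:

```python
# Hypothetical sketch of the documented NVIDIA provider fields.
# Class name and shape are assumptions for illustration only.
from pydantic import BaseModel, SecretStr


class NVIDIASettingsSketch(BaseModel):
    allowed_models: list[str] | None = None  # None => all models are allowed
    refresh_models: bool = False  # periodically refresh models from the provider
    api_key: SecretStr | None = None  # only needed for the hosted service
    url: str = "https://integrate.api.nvidia.com"  # base URL of the NIM endpoint
    timeout: int = 60  # HTTP request timeout, in seconds
    append_api_version: bool = True  # append the API version to the base URL


config = NVIDIASettingsSketch(api_key=SecretStr("nvapi-example"), timeout=30)
print(config.url, config.timeout, config.api_key)  # SecretStr prints masked
```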
```diff
@@ -6,7 +6,10 @@
 
 from urllib.parse import urljoin
 
-from llama_stack.apis.inference import OpenAIEmbeddingsResponse
+from llama_stack.apis.inference import (
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import CerebrasImplConfig
```
```diff
@@ -18,17 +21,15 @@ class CerebrasInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "cerebras_api_key"
 
     def get_api_key(self) -> str:
-        return self.config.api_key.get_secret_value()
+        if self.config.auth_credential is None:
+            raise ValueError("Cerebras API key is required")
+        return self.config.auth_credential.get_secret_value()
 
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")
 
     async def openai_embeddings(
         self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
         raise NotImplementedError()
```
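Two things change in this hunk: the adapter reads the shared `auth_credential` field instead of a provider-specific `api_key`, and it fails fast with a clear error when no key is configured. A minimal, self-contained sketch of that pattern, using illustrative stand-in class names rather than the real llama-stack ones:

```python
# Sketch of the credential-handling pattern adopted above; class names
# are stand-ins, not the actual llama-stack classes.
from pydantic import BaseModel, SecretStr


class ProviderConfigSketch(BaseModel):
    auth_credential: SecretStr | None = None  # assumed shared config field


class AdapterSketch:
    def __init__(self, config: ProviderConfigSketch) -> None:
        self.config = config

    def get_api_key(self) -> str:
        # Fail fast with a clear message instead of an AttributeError
        # when no credential was configured.
        if self.config.auth_credential is None:
            raise ValueError("Cerebras API key is required")
        return self.config.auth_credential.get_secret_value()


adapter = AdapterSketch(ProviderConfigSketch(auth_credential=SecretStr("csk-example")))
print(adapter.get_api_key())  # prints the raw key: csk-example
```

The same hunk also collapses the individual keyword arguments of `openai_embeddings` into a single `OpenAIEmbeddingsRequestWithExtraBody` parameter, which, judging by the name, lets extra body fields ride along without further signature changes.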
```diff
@@ -7,7 +7,7 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
@@ -28,10 +28,6 @@ class CerebrasImplConfig(RemoteInferenceProviderConfig):
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
-    api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
-        description="Cerebras API Key",
-    )
 
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
```
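The config drops its own `api_key: SecretStr` field (and the now-unused `SecretStr` import) because the credential is inherited from `RemoteInferenceProviderConfig`. A hedged sketch of that inheritance shape, with stand-in class names, an assumed `auth_credential` field on the base, and an illustrative default URL:

```python
# Illustrative sketch only: stand-ins for RemoteInferenceProviderConfig and
# CerebrasImplConfig showing the credential moving to the base class.
import os

from pydantic import BaseModel, Field, SecretStr


class RemoteProviderConfigSketch(BaseModel):
    # Assumed base-class field; the real RemoteInferenceProviderConfig may differ.
    auth_credential: SecretStr | None = None


class CerebrasConfigSketch(RemoteProviderConfigSketch):
    base_url: str = Field(
        # Illustrative default, standing in for DEFAULT_BASE_URL.
        default=os.environ.get("CEREBRAS_BASE_URL", "https://api.cerebras.ai"),
        description="Base URL for the Cerebras API",
    )
    # No api_key field here anymore; auth_credential is inherited.


cfg = CerebrasConfigSketch(auth_credential=SecretStr("csk-example"))
print(cfg.base_url, cfg.auth_credential)  # the secret prints masked
```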
```diff
@@ -25,7 +25,6 @@ classifiers = [
 ]
 dependencies = [
     "aiohttp",
-    "databricks-sdk",
     "fastapi>=0.115.0,<1.0",  # server
     "fire",  # for MCP in LLS client
     "httpx",
```