mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
chore(package): migrate to src/ layout (#3920)
Migrates package structure to src/ layout following Python packaging best practices. All code moved from `llama_stack/` to `src/llama_stack/`. Public API unchanged - imports remain `import llama_stack.*`. Updated build configs, pre-commit hooks, scripts, and GitHub workflows accordingly. All hooks pass, package builds cleanly. **Developer note**: Reinstall after pulling: `pip install -e .`
This commit is contained in:
parent
98a5047f9d
commit
471b1b248b
791 changed files with 2983 additions and 456 deletions
5
src/llama_stack/providers/remote/inference/__init__.py
Normal file
5
src/llama_stack/providers/remote/inference/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
impl = AnthropicInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin):
|
||||
config: AnthropicConfig
|
||||
|
||||
provider_data_api_key_field: str = "anthropic_api_key"
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
|
||||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://api.anthropic.com/v1"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Anthropic models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AnthropicConfig(RemoteInferenceProviderConfig):
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
15
src/llama_stack/providers/remote/inference/azure/__init__.py
Normal file
15
src/llama_stack/providers/remote/inference/azure/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AzureConfig, _deps):
|
||||
from .azure import AzureInferenceAdapter
|
||||
|
||||
impl = AzureInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
25
src/llama_stack/providers/remote/inference/azure/azure.py
Normal file
25
src/llama_stack/providers/remote/inference/azure/azure.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin):
|
||||
config: AzureConfig
|
||||
|
||||
provider_data_api_key_field: str = "azure_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Azure API base URL.
|
||||
|
||||
Returns the Azure API base URL from the configuration.
|
||||
"""
|
||||
return urljoin(str(self.config.api_base), "/openai/v1")
|
||||
61
src/llama_stack/providers/remote/inference/azure/config.py
Normal file
61
src/llama_stack/providers/remote/inference/azure/config.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AzureProviderDataValidator(BaseModel):
|
||||
azure_api_key: SecretStr = Field(
|
||||
description="Azure API key for Azure",
|
||||
)
|
||||
azure_api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
azure_api_version: str | None = Field(
|
||||
default=None,
|
||||
description="Azure API version for Azure (e.g., 2024-06-01)",
|
||||
)
|
||||
azure_api_type: str | None = Field(
|
||||
default="azure",
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AzureConfig(RemoteInferenceProviderConfig):
|
||||
api_base: HttpUrl = Field(
|
||||
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
|
||||
description="Azure API version for Azure (e.g., 2024-12-01-preview)",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
|
||||
description="Azure API type for Azure (e.g., azure)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.AZURE_API_KEY:=}",
|
||||
api_base: str = "${env.AZURE_API_BASE:=}",
|
||||
api_version: str = "${env.AZURE_API_VERSION:=}",
|
||||
api_type: str = "${env.AZURE_API_TYPE:=}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
"api_type": api_type,
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .config import BedrockConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: BedrockConfig, _deps):
|
||||
from .bedrock import BedrockInferenceAdapter
|
||||
|
||||
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BedrockInferenceAdapter(config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
||||
142
src/llama_stack/providers/remote/inference/bedrock/bedrock.py
Normal file
142
src/llama_stack/providers/remote/inference/bedrock/bedrock.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_strategy_options,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
REGION_PREFIX_MAP = {
|
||||
"us": "us.",
|
||||
"eu": "eu.",
|
||||
"ap": "ap.",
|
||||
}
|
||||
|
||||
|
||||
def _get_region_prefix(region: str | None) -> str:
|
||||
# AWS requires region prefixes for inference profiles
|
||||
if region is None:
|
||||
return "us." # default to US when we don't know
|
||||
|
||||
# Handle case insensitive region matching
|
||||
region_lower = region.lower()
|
||||
for prefix in REGION_PREFIX_MAP:
|
||||
if region_lower.startswith(f"{prefix}-"):
|
||||
return REGION_PREFIX_MAP[prefix]
|
||||
|
||||
# Fallback to US for anything we don't recognize
|
||||
return "us."
|
||||
|
||||
|
||||
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
||||
# Return ARNs unchanged
|
||||
if model_id.startswith("arn:"):
|
||||
return model_id
|
||||
|
||||
# Return inference profile IDs that already have regional prefixes
|
||||
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
|
||||
return model_id
|
||||
|
||||
# Default to US East when no region is provided
|
||||
if region is None:
|
||||
region = "us-east-1"
|
||||
|
||||
return _get_region_prefix(region) + model_id
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self._config = config
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
if self._client is None:
|
||||
self._client = create_bedrock_client(self._config)
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
|
||||
bedrock_model = request.model
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
options = get_sampling_strategy_options(sampling_params)
|
||||
|
||||
if sampling_params.max_tokens:
|
||||
options["max_gen_len"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
|
||||
|
||||
return {
|
||||
"modelId": inference_profile_id,
|
||||
"body": json.dumps(
|
||||
{
|
||||
"prompt": prompt,
|
||||
**options,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
|
||||
11
src/llama_stack/providers/remote/inference/bedrock/config.py
Normal file
11
src/llama_stack/providers/remote/inference/bedrock/config.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||
|
||||
|
||||
class BedrockConfig(BedrockBaseConfig):
|
||||
pass
|
||||
29
src/llama_stack/providers/remote/inference/bedrock/models.py
Normal file
29
src/llama_stack/providers/remote/inference/bedrock/models.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta.llama3-1-405b-instruct-v1:0",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
|
||||
from .cerebras import CerebrasInferenceAdapter
|
||||
|
||||
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = CerebrasInferenceAdapter(config=config)
|
||||
|
||||
await impl.initialize()
|
||||
|
||||
return impl
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(OpenAIMixin):
|
||||
config: CerebrasImplConfig
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return urljoin(self.config.base_url, "v1")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
DEFAULT_BASE_URL = "https://api.cerebras.ai"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasImplConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_HOST:=}",
|
||||
api_token: str = "${env.DATABRICKS_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(OpenAIMixin):
|
||||
config: DatabricksImplConfig
|
||||
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [
|
||||
endpoint.name
|
||||
for endpoint in WorkspaceClient(
|
||||
host=self.config.url, token=self.get_api_key()
|
||||
).serving_endpoints.list() # TODO: this is not async
|
||||
]
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
class FireworksProviderDataValidator(BaseModel):
|
||||
fireworks_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksImplConfig, _deps):
|
||||
from .fireworks import FireworksInferenceAdapter
|
||||
|
||||
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = FireworksInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(OpenAIMixin):
|
||||
config: FireworksImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
|
||||
}
|
||||
|
||||
provider_data_api_key_field: str = "fireworks_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return "https://api.fireworks.ai/inference/v1"
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GeminiConfig, _deps):
|
||||
from .gemini import GeminiInferenceAdapter
|
||||
|
||||
impl = GeminiInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
28
src/llama_stack/providers/remote/inference/gemini/config.py
Normal file
28
src/llama_stack/providers/remote/inference/gemini/config.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GeminiProviderDataValidator(BaseModel):
|
||||
gemini_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Gemini models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GeminiConfig(RemoteInferenceProviderConfig):
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
82
src/llama_stack/providers/remote/inference/gemini/gemini.py
Normal file
82
src/llama_stack/providers/remote/inference/gemini/gemini.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import GeminiConfig
|
||||
|
||||
|
||||
class GeminiInferenceAdapter(OpenAIMixin):
|
||||
config: GeminiConfig
|
||||
|
||||
provider_data_api_key_field: str = "gemini_api_key"
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"models/text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
|
||||
"models/gemini-embedding-001": {"embedding_dimension": 3072, "context_length": 2048},
|
||||
}
|
||||
|
||||
def get_base_url(self):
|
||||
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Override embeddings method to handle Gemini's missing usage statistics.
|
||||
Gemini's embedding API doesn't return usage information, so we provide default values.
|
||||
"""
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"model": await self._get_provider_model_id(params.model),
|
||||
"input": params.input,
|
||||
"encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
|
||||
"dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
|
||||
"user": params.user if params.user is not None else NOT_GIVEN,
|
||||
}
|
||||
|
||||
# Add extra_body if present
|
||||
extra_body = params.model_extra
|
||||
if extra_body:
|
||||
request_params["extra_body"] = extra_body
|
||||
|
||||
# Call OpenAI embeddings API with properly typed parameters
|
||||
response = await self.client.embeddings.create(**request_params)
|
||||
|
||||
data = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
data.append(
|
||||
OpenAIEmbeddingData(
|
||||
embedding=embedding_data.embedding,
|
||||
index=i,
|
||||
)
|
||||
)
|
||||
|
||||
# Gemini doesn't return usage statistics - use default values
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
else:
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=params.model,
|
||||
usage=usage,
|
||||
)
|
||||
15
src/llama_stack/providers/remote/inference/groq/__init__.py
Normal file
15
src/llama_stack/providers/remote/inference/groq/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import GroqConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqInferenceAdapter
|
||||
|
||||
adapter = GroqInferenceAdapter(config=config)
|
||||
return adapter
|
||||
34
src/llama_stack/providers/remote/inference/groq/config.py
Normal file
34
src/llama_stack/providers/remote/inference/groq/config.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.groq.com",
|
||||
description="The URL for the Groq AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
}
|
||||
18
src/llama_stack/providers/remote/inference/groq/groq.py
Normal file
18
src/llama_stack/providers/remote/inference/groq/groq.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class GroqInferenceAdapter(OpenAIMixin):
|
||||
config: GroqConfig
|
||||
|
||||
provider_data_api_key_field: str = "groq_api_key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/openai/v1"
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
adapter = LlamaCompatInferenceAdapter(config=config)
|
||||
return adapter
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class LlamaProviderDataValidator(BaseModel):
|
||||
llama_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for api.llama models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(RemoteInferenceProviderConfig):
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.llama.com/compat/v1/",
|
||||
description="The URL for the Llama API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin):
|
||||
config: LlamaCompatConfig
|
||||
|
||||
provider_data_api_key_field: str = "llama_api_key"
|
||||
"""
|
||||
Llama API Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The Llama API base URL
|
||||
"""
|
||||
return self.config.openai_compat_api_base
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
183
src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
Normal file
183
src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
# NVIDIA Inference Provider for LlamaStack
|
||||
|
||||
This provider enables running inference using NVIDIA NIM.
|
||||
|
||||
## Features
|
||||
- Endpoints for completions, chat completions, and embeddings for registered models
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
|
||||
- Access to NVIDIA NIM deployment
|
||||
- NIM for model to use for inference is deployed
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = (
|
||||
"" # Required if using hosted NIM endpoint. If self-hosted, not required.
|
||||
)
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
### Create Chat Completion
|
||||
|
||||
The following example shows how to create a chat completion for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You must respond to each message with only one word",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Complete the sentence using one word: Roses are red, violets are:",
|
||||
},
|
||||
],
|
||||
stream=False,
|
||||
max_tokens=50,
|
||||
)
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Tool Calling Example ###
|
||||
|
||||
The following example shows how to do tool calling for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
tool_definition = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather information for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "Temperature unit (celsius or fahrenheit)",
|
||||
"default": "celsius",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool_response = client.chat.completions.create(
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||
tools=[tool_definition],
|
||||
)
|
||||
|
||||
print(f"Response content: {tool_response.choices[0].message.content}")
|
||||
if tool_response.choices[0].message.tool_calls:
|
||||
for tool_call in tool_response.choices[0].message.tool_calls:
|
||||
print(f"Tool Called: {tool_call.function.name}")
|
||||
print(f"Arguments: {tool_call.function.arguments}")
|
||||
```
|
||||
|
||||
### Structured Output Example
|
||||
|
||||
The following example shows how to do structured output for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
person_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"},
|
||||
"occupation": {"type": "string"},
|
||||
},
|
||||
"required": ["name", "age", "occupation"],
|
||||
}
|
||||
|
||||
structured_response = client.chat.completions.create(
|
||||
model="nvidia/meta/llama-3.1-8b-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||
}
|
||||
],
|
||||
extra_body={"nvext": {"guided_json": person_schema}},
|
||||
)
|
||||
print(f"Structured Response: {structured_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Create Embeddings
|
||||
|
||||
The following example shows how to create embeddings for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
response = client.embeddings.create(
|
||||
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
input=["What is the capital of France?"],
|
||||
extra_body={"input_type": "query"},
|
||||
)
|
||||
print(f"Embeddings: {response.data}")
|
||||
```
|
||||
|
||||
### Vision Language Models Example
|
||||
|
||||
The following example shows how to run vision inference by using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
def load_image_as_base64(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
img_bytes = image_file.read()
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
|
||||
image_path = {path_to_the_image}
|
||||
demo_image_b64 = load_image_as_base64(image_path)
|
||||
|
||||
vlm_response = client.chat.completions.create(
|
||||
model="nvidia/meta/llama-3.2-11b-vision-instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{demo_image_b64}",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe what you see in this image in detail.",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import NVIDIAConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
|
||||
# import dynamically so `llama stack list-deps` does not fail due to missing dependencies
|
||||
from .nvidia import NVIDIAInferenceAdapter
|
||||
|
||||
if not isinstance(config, NVIDIAConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
await adapter.initialize()
|
||||
return adapter
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NVIDIAConfig"]
|
||||
64
src/llama_stack/providers/remote/inference/nvidia/config.py
Normal file
64
src/llama_stack/providers/remote/inference/nvidia/config.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||
"""
|
||||
Configuration for the NVIDIA NIM inference endpoint.
|
||||
|
||||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
1. Self-hosted: You can run NVIDIA NIMs on your own infrastructure
|
||||
|
||||
By default the configuration is set to use the hosted APIs. This requires
|
||||
an API key which can be obtained from https://ngc.nvidia.com/.
|
||||
|
||||
By default the configuration will attempt to read the NVIDIA_API_KEY environment
|
||||
variable to set the api_key. Please do not put your API key in code.
|
||||
|
||||
If you are using a self-hosted NVIDIA NIM, you can set the url to the
|
||||
URL of your running NVIDIA NIM and do not need to set the api_key.
|
||||
"""
|
||||
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
|
||||
description="A base url for accessing the NVIDIA NIM",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
append_api_version: bool = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
|
||||
api_key: str = "${env.NVIDIA_API_KEY:=}",
|
||||
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_key": api_key,
|
||||
"append_api_version": append_api_version,
|
||||
}
|
||||
61
src/llama_stack/providers/remote/inference/nvidia/nvidia.py
Normal file
61
src/llama_stack/providers/remote/inference/nvidia/nvidia.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from . import NVIDIAConfig
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin):
|
||||
config: NVIDIAConfig
|
||||
|
||||
"""
|
||||
NVIDIA Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
|
||||
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
|
||||
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
|
||||
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
|
||||
}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(self.config):
|
||||
if not self.config.auth_credential:
|
||||
raise RuntimeError(
|
||||
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
|
||||
)
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
Get the API key for OpenAI mixin.
|
||||
|
||||
:return: The NVIDIA API key
|
||||
"""
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
|
||||
if not _is_nvidia_hosted(self.config):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
return None
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
11
src/llama_stack/providers/remote/inference/nvidia/utils.py
Normal file
11
src/llama_stack/providers/remote/inference/nvidia/utils.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
return "integrate.api.nvidia.com" in config.url
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import OllamaImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OllamaImplConfig, _deps):
|
||||
from .ollama import OllamaInferenceAdapter
|
||||
|
||||
impl = OllamaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
25
src/llama_stack/providers/remote/inference/ollama/config.py
Normal file
25
src/llama_stack/providers/remote/inference/ollama/config.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
|
||||
|
||||
class OllamaImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
}
|
||||
102
src/llama_stack/providers/remote/inference/ollama/ollama.py
Normal file
102
src/llama_stack/providers/remote/inference/ollama/ollama.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(OpenAIMixin):
|
||||
config: OllamaImplConfig
|
||||
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
"nomic-embed-text:latest": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:v1.5": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:137m-v1.5-fp16": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
}
|
||||
|
||||
download_images: bool = True
|
||||
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
def ollama_client(self) -> AsyncOllamaClient:
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
def get_api_key(self):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.config.url.rstrip("/") + "/v1"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
r = await self.health()
|
||||
if r["status"] == HealthStatus.ERROR:
|
||||
logger.warning(
|
||||
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
This method is used by initialize() and the Provider API to verify that the service is running
|
||||
correctly.
|
||||
Returns:
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.ollama_client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
|
||||
model.provider_resource_id = f"{model.provider_model_id}:latest"
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
|
||||
)
|
||||
return model
|
||||
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIConfig, _deps):
|
||||
from .openai import OpenAIInferenceAdapter
|
||||
|
||||
impl = OpenAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
39
src/llama_stack/providers/remote/inference/openai/config.py
Normal file
39
src/llama_stack/providers/remote/inference/openai/config.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class OpenAIProviderDataValidator(BaseModel):
|
||||
openai_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for OpenAI models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIConfig(RemoteInferenceProviderConfig):
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Base URL for OpenAI API",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
api_key: str = "${env.OPENAI_API_KEY:=}",
|
||||
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
38
src/llama_stack/providers/remote/inference/openai/openai.py
Normal file
38
src/llama_stack/providers/remote/inference/openai/openai.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import OpenAIConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::openai")
|
||||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using OpenAIMixin
|
||||
#
|
||||
class OpenAIInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
OpenAI Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
config: OpenAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "openai_api_key"
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the OpenAI API base URL.
|
||||
|
||||
Returns the OpenAI API base URL from the configuration.
|
||||
"""
|
||||
return self.config.base_url
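All of the adapters added in this diff follow the same OpenAIMixin recipe: declare the config type, name the provider-data field that may carry a per-request API key, optionally describe embedding models, and implement get_base_url(). The following sketch shows the pattern with a hypothetical provider; the "Acme" names are invented for illustration and are not part of this PR.

# Illustrative sketch only, not part of this PR: a hypothetical "Acme" provider written in the
# same OpenAIMixin style as the adapters above. All Acme names are invented for illustration.
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class AcmeConfig(RemoteInferenceProviderConfig):
    base_url: str = "https://api.acme.example/v1"  # hypothetical endpoint


class AcmeInferenceAdapter(OpenAIMixin):
    config: AcmeConfig

    # per-request provider-data field that may carry the API key, mirroring the adapters above
    provider_data_api_key_field: str = "acme_api_key"

    def get_base_url(self) -> str:
        # OpenAIMixin points its OpenAI-compatible client at this URL
        return self.config.base_url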
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
||||
|
||||
class PassthroughProviderDataValidator(BaseModel):
|
||||
url: str
|
||||
api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: PassthroughImplConfig, _deps):
|
||||
from .passthrough import PassthroughInferenceAdapter
|
||||
|
||||
assert isinstance(config, PassthroughImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = PassthroughInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PassthroughImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the passthrough endpoint",
|
||||
)
|
||||
|
||||
api_key: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="API Key for the passthrouth endpoint",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"url": url,
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack_client import AsyncLlamaStackClient
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .config import PassthroughImplConfig
|
||||
|
||||
|
||||
class PassthroughInferenceAdapter(Inference):
|
||||
def __init__(self, config: PassthroughImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self)
|
||||
self.config = config
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
def _get_client(self) -> AsyncLlamaStackClient:
|
||||
passthrough_url = None
|
||||
passthrough_api_key = None
|
||||
provider_data = None
|
||||
|
||||
if self.config.url is not None:
|
||||
passthrough_url = self.config.url
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.passthrough_url:
|
||||
raise ValueError(
|
||||
'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
|
||||
)
|
||||
passthrough_url = provider_data.passthrough_url
|
||||
|
||||
if self.config.api_key is not None:
|
||||
passthrough_api_key = self.config.api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.passthrough_api_key:
|
||||
raise ValueError(
|
||||
'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
|
||||
)
|
||||
passthrough_api_key = provider_data.passthrough_api_key
|
||||
|
||||
return AsyncLlamaStackClient(
|
||||
base_url=passthrough_url,
|
||||
api_key=passthrough_api_key,
|
||||
provider_data=provider_data,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
params = params.model_copy()
|
||||
params.model = model_obj.provider_resource_id
|
||||
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
|
||||
return await client.inference.openai_completion(**request_params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
client = self._get_client()
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
params = params.model_copy()
|
||||
params.model = model_obj.provider_resource_id
|
||||
|
||||
request_params = params.model_dump(exclude_none=True)
|
||||
|
||||
return await client.inference.openai_chat_completion(**request_params)
|
||||
|
||||
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
|
||||
json_params = {}
|
||||
for key, value in request_params.items():
|
||||
json_input = convert_pydantic_to_json_value(value)
|
||||
if isinstance(json_input, dict):
|
||||
json_input = {k: v for k, v in json_input.items() if v is not None}
|
||||
elif isinstance(json_input, list):
|
||||
json_input = [x for x in json_input if x is not None]
|
||||
new_input = []
|
||||
for x in json_input:
|
||||
if isinstance(x, dict):
|
||||
x = {k: v for k, v in x.items() if v is not None}
|
||||
new_input.append(x)
|
||||
json_input = new_input
|
||||
|
||||
json_params[key] = json_input
|
||||
|
||||
return json_params
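As the error messages in _get_client() indicate, the passthrough URL and API key can be supplied per request through the X-LlamaStack-Provider-Data header instead of the static config. Below is a minimal sketch of building that header on the caller side; the server address, downstream URL, and key are placeholders.

# Sketch only: building the X-LlamaStack-Provider-Data header on the caller side.
# The header name and field names come from the error messages in _get_client() above;
# the server address, downstream URL, and key are placeholders.
import json

import httpx

provider_data = {
    "passthrough_url": "https://downstream-stack.example.com",
    "passthrough_api_key": "sk-placeholder",
}
headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}

# Any request sent through this client now carries the provider data for the passthrough adapter.
client = httpx.Client(base_url="http://localhost:8321", headers=headers)  # assumed local stack server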
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RunpodImplConfig, _deps):
|
||||
from .runpod import RunpodInferenceAdapter
|
||||
|
||||
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = RunpodInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
32
src/llama_stack/providers/remote/inference/runpod/config.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RunpodImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the Runpod model serving endpoint",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.RUNPOD_URL:=}",
|
||||
"api_token": "${env.RUNPOD_API_TOKEN}",
|
||||
}
|
||||
42
src/llama_stack/providers/remote/inference/runpod/runpod.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import RunpodImplConfig
|
||||
|
||||
|
||||
class RunpodInferenceAdapter(OpenAIMixin):
|
||||
"""
|
||||
Adapter for RunPod's OpenAI-compatible API endpoints.
|
||||
Supports vLLM-based serverless endpoints, whether self-hosted or public.
Works with any RunPod endpoint that exposes an OpenAI-compatible API.
|
||||
"""
|
||||
|
||||
config: RunpodImplConfig
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get base URL for OpenAI client."""
|
||||
return self.config.url
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Override to add RunPod-specific stream_options requirement."""
|
||||
params = params.model_copy()
|
||||
|
||||
if params.stream and not params.stream_options:
|
||||
params.stream_options = {"include_usage": True}
|
||||
|
||||
return await super().openai_chat_completion(params)
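The override above only fills in stream_options when a caller streams without specifying it, so RunPod reports token usage in the final chunk. A rough sketch of the request shape this affects; the model id is a placeholder.

# Sketch only: the request fields the override inspects. The model id is a placeholder.
request = {
    "model": "my-runpod-endpoint/llama-3-8b",
    "messages": [{"role": "user", "content": "hi"}],
    "stream": True,
    # stream_options omitted -> the adapter injects {"include_usage": True} before delegating
    # to OpenAIMixin.openai_chat_completion()
}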
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
|
||||
from .sambanova import SambaNovaInferenceAdapter
|
||||
|
||||
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = SambaNovaInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class SambaNovaProviderDataValidator(BaseModel):
|
||||
sambanova_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="Sambanova Cloud API key",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
|
||||
|
||||
class SambaNovaInferenceAdapter(OpenAIMixin):
|
||||
config: SambaNovaImplConfig
|
||||
|
||||
provider_data_api_key_field: str = "sambanova_api_key"
|
||||
download_images: bool = True  # SambaNova does not support image downloads server-side, so perform them on the client
|
||||
"""
|
||||
SambaNova Inference Adapter for Llama Stack.
|
||||
"""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the base URL for OpenAI mixin.
|
||||
|
||||
:return: The SambaNova base URL
|
||||
"""
|
||||
return self.config.url
|
||||
28
src/llama_stack/providers/remote/inference/tgi/__init__.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: InferenceAPIImplConfig | InferenceEndpointImplConfig | TGIImplConfig,
|
||||
_deps,
|
||||
):
|
||||
from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
|
||||
|
||||
if isinstance(config, TGIImplConfig):
|
||||
impl = TGIAdapter()
|
||||
elif isinstance(config, InferenceAPIImplConfig):
|
||||
impl = InferenceAPIAdapter()
|
||||
elif isinstance(config, InferenceEndpointImplConfig):
|
||||
impl = InferenceEndpointAdapter()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid configuration. Expected 'TGIAdapter', 'InferenceAPIImplConfig' or 'InferenceEndpointImplConfig'. Got {type(config)}."
|
||||
)
|
||||
|
||||
await impl.initialize(config)
|
||||
return impl
|
||||
76
src/llama_stack/providers/remote/inference/tgi/config.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
url: str = Field(
|
||||
description="The URL for the TGI serving endpoint",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.TGI_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceEndpointImplConfig(BaseModel):
|
||||
endpoint_name: str = Field(
|
||||
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
|
||||
)
|
||||
api_token: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
endpoint_name: str = "${env.INFERENCE_ENDPOINT_NAME}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"endpoint_name": endpoint_name,
|
||||
"api_token": api_token,
|
||||
}
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceAPIImplConfig(BaseModel):
|
||||
huggingface_repo: str = Field(
|
||||
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
|
||||
)
|
||||
api_token: SecretStr | None = Field(
|
||||
default=None,
|
||||
description="Your Hugging Face user access token (will default to locally saved token if not provided)",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
repo: str = "${env.INFERENCE_MODEL}",
|
||||
api_token: str = "${env.HF_API_TOKEN}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"huggingface_repo": repo,
|
||||
"api_token": api_token,
|
||||
}
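The three config classes above correspond to three deployment modes: a self-hosted TGI server, the serverless Hugging Face Inference API, and a dedicated Hugging Face Inference Endpoint. Here is a hedged sketch of what their sample run configs resolve to once the environment placeholders are substituted; the TGI URL and tokens are hypothetical.

# Sketch only: the dicts the sample_run_config() helpers above produce, with the env
# placeholders substituted by hypothetical values.
tgi_cfg = {"url": "http://localhost:8080"}  # TGIImplConfig: self-hosted TGI server
hf_api_cfg = {  # InferenceAPIImplConfig: serverless Inference API
    "huggingface_repo": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "api_token": "hf_placeholder",
}
hf_endpoint_cfg = {  # InferenceEndpointImplConfig: dedicated Inference Endpoint
    "endpoint_name": "my-cool-org/meta-llama-3-1-8b-instruct-rce",
    "api_token": "hf_placeholder",
}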
|
||||
85
src/llama_stack/providers/remote/inference/tgi/tgi.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
from pydantic import SecretStr
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference::tgi")
|
||||
|
||||
|
||||
class _HfAdapter(OpenAIMixin):
|
||||
url: str
|
||||
api_key: SecretStr
|
||||
|
||||
hf_client: AsyncInferenceClient
|
||||
max_tokens: int
|
||||
model_id: str
|
||||
|
||||
overwrite_completion_id = True # TGI always returns id=""
|
||||
|
||||
def get_api_key(self):
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [self.model_id]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
self.url = f"{config.url.rstrip('/')}/v1"
|
||||
self.api_key = SecretStr("NO_KEY")
|
||||
|
||||
|
||||
class InferenceAPIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceAPIImplConfig) -> None:
|
||||
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
|
||||
endpoint_info = await self.hf_client.get_endpoint_info()
|
||||
self.max_tokens = endpoint_info["max_total_tokens"]
|
||||
self.model_id = endpoint_info["model_id"]
|
||||
# TODO: how do we set url for this?
|
||||
|
||||
|
||||
class InferenceEndpointAdapter(_HfAdapter):
|
||||
async def initialize(self, config: InferenceEndpointImplConfig) -> None:
|
||||
# Get the inference endpoint details
|
||||
api = HfApi(token=config.api_token.get_secret_value())
|
||||
endpoint = api.get_inference_endpoint(config.endpoint_name)
|
||||
# Wait for the endpoint to be ready (if not already)
|
||||
endpoint.wait(timeout=60)
|
||||
|
||||
# Initialize the adapter
|
||||
self.hf_client = endpoint.async_client
|
||||
self.model_id = endpoint.repository
|
||||
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
|
||||
# TODO: how do we set url for this?
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
||||
from .together import TogetherInferenceAdapter
|
||||
|
||||
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||
}
|
||||
102
src/llama_stack/providers/remote/inference/together/together.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from together import AsyncTogether
|
||||
from together.constants import BASE_URL
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::together")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
||||
config: TogetherImplConfig
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
|
||||
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
|
||||
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
|
||||
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
|
||||
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
provider_data_api_key_field: str = "together_api_key"
|
||||
|
||||
def get_base_url(self):
|
||||
return BASE_URL
|
||||
|
||||
def _get_client(self) -> AsyncTogether:
|
||||
together_api_key = None
|
||||
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
|
||||
if config_api_key:
|
||||
together_api_key = config_api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
|
||||
)
|
||||
together_api_key = provider_data.together_api_key
|
||||
return AsyncTogether(api_key=together_api_key)
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
return [m.id for m in await self._get_client().models.list()]
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Together's OpenAI-compatible embeddings endpoint is not compatible with
|
||||
the standard OpenAI embeddings endpoint.
|
||||
|
||||
Known differences from the standard endpoint:
- not all models return usage information
- the user param is not supported and returns 400 "Unrecognized request arguments supplied: user"
- the dimensions param is not supported and returns 400 "Unrecognized request arguments supplied: dimensions"
|
||||
"""
|
||||
# Together support ticket #13332 -> will not fix
|
||||
if params.user is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support user param.")
|
||||
# Together support ticket #13333 -> escalated
|
||||
if params.dimensions is not None:
|
||||
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=await self._get_provider_model_id(params.model),
|
||||
input=params.input,
|
||||
encoding_format=params.encoding_format,
|
||||
)
|
||||
|
||||
response.model = (
|
||||
params.model
|
||||
) # return the same model id the user provided, avoiding exposure of the provider model id
|
||||
|
||||
# Together support ticket #13330 -> escalated
|
||||
# - togethercomputer/m2-bert-80M-32k-retrieval *does not* return usage information
|
||||
if not hasattr(response, "usage") or response.usage is None:
|
||||
logger.warning(
|
||||
f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
|
||||
)
|
||||
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
||||
|
||||
return response # type: ignore[no-any-return]
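Given the limitations documented above, a request routed through this adapter should carry only model, input, and optionally encoding_format. A minimal sketch follows; the input texts are placeholders.

# Sketch only: the fields this adapter forwards to Together's embeddings endpoint.
# user and dimensions must be omitted, otherwise the adapter raises ValueError before calling Together.
embeddings_request = {
    "model": "togethercomputer/m2-bert-80M-32k-retrieval",  # listed in embedding_model_metadata above
    "input": ["first document", "second document"],
    "encoding_format": "float",
}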
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import VertexAIConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VertexAIConfig, _deps):
|
||||
from .vertexai import VertexAIInferenceAdapter
|
||||
|
||||
impl = VertexAIInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class VertexAIProviderDataValidator(BaseModel):
|
||||
vertex_project: str | None = Field(
|
||||
default=None,
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
vertex_location: str | None = Field(
|
||||
default=None,
|
||||
description="Google Cloud location for Vertex AI (e.g., us-central1)",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VertexAIConfig(RemoteInferenceProviderConfig):
|
||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||
|
||||
project: str = Field(
|
||||
description="Google Cloud project ID for Vertex AI",
|
||||
)
|
||||
location: str = Field(
|
||||
default="us-central1",
|
||||
description="Google Cloud location for Vertex AI",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
project: str = "${env.VERTEX_AI_PROJECT:=}",
|
||||
location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"project": project,
|
||||
"location": location,
|
||||
}
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import google.auth.transport.requests
|
||||
from google.auth import default
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VertexAIConfig
|
||||
|
||||
|
||||
class VertexAIInferenceAdapter(OpenAIMixin):
|
||||
config: VertexAIConfig
|
||||
|
||||
provider_data_api_key_field: str = "vertex_project"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
"""
|
||||
Get an access token for Vertex AI using Application Default Credentials.
|
||||
|
||||
Vertex AI uses ADC instead of API keys. This method obtains an access token
|
||||
from the default credentials and returns it for use with the OpenAI-compatible client.
|
||||
"""
|
||||
try:
|
||||
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
|
||||
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
|
||||
credentials.refresh(google.auth.transport.requests.Request())
|
||||
return str(credentials.token)
|
||||
except Exception:
|
||||
# If we can't get credentials, return empty string to let the env work with ADC directly
|
||||
return ""
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""
|
||||
Get the Vertex AI OpenAI-compatible API base URL.
|
||||
|
||||
Returns the Vertex AI OpenAI-compatible endpoint URL.
|
||||
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
|
||||
"""
|
||||
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
|
||||
22
src/llama_stack/providers/remote/inference/vllm/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
class VLLMProviderDataValidator(BaseModel):
|
||||
vllm_api_token: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
|
||||
from .vllm import VLLMInferenceAdapter
|
||||
|
||||
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = VLLMInferenceAdapter(config=config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
59
src/llama_stack/providers/remote/inference/vllm/config.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
|
||||
url: str | None = Field(
|
||||
default=None,
|
||||
description="The URL for the vLLM model serving endpoint",
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
auth_credential: SecretStr | None = Field(
|
||||
default=None,
|
||||
alias="api_token",
|
||||
description="The API token",
|
||||
)
|
||||
tls_verify: bool | str = Field(
|
||||
default=True,
|
||||
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
def validate_tls_verify(cls, v):
|
||||
if isinstance(v, str):
|
||||
# A string value is treated as a path to a CA certificate file
|
||||
cert_path = Path(v).expanduser().resolve()
|
||||
if not cert_path.exists():
|
||||
raise ValueError(f"TLS certificate file does not exist: {v}")
|
||||
if not cert_path.is_file():
|
||||
raise ValueError(f"TLS certificate path is not a file: {v}")
|
||||
return v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.VLLM_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
"url": url,
|
||||
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
|
||||
"api_token": "${env.VLLM_API_TOKEN:=fake}",
|
||||
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
|
||||
}
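tls_verify accepts either a boolean or a filesystem path; the adapter later hands it straight to httpx's verify parameter via get_extra_client_params(). A minimal sketch, assuming a hypothetical self-hosted endpoint:

# Sketch only: configuring TLS verification for a hypothetical self-hosted vLLM endpoint.
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig

cfg = VLLMInferenceAdapterConfig(
    url="https://vllm.internal.example:8000/v1",  # hypothetical endpoint
    tls_verify=False,  # disable verification, e.g. for self-signed certificates in a lab setup
)
# To pin a private CA instead, pass a path (the validator checks that the file exists):
#   tls_verify="/etc/ssl/certs/internal-ca.pem"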
|
||||
111
src/llama_stack/providers/remote/inference/vllm/vllm.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from collections.abc import AsyncIterator
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||
)
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
log = get_logger(name=__name__, category="inference::vllm")
|
||||
|
||||
|
||||
class VLLMInferenceAdapter(OpenAIMixin):
|
||||
config: VLLMInferenceAdapterConfig
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider_data_api_key_field: str = "vllm_api_token"
|
||||
|
||||
def get_api_key(self) -> str | None:
|
||||
if self.config.auth_credential:
|
||||
return self.config.auth_credential.get_secret_value()
|
||||
return "NO KEY REQUIRED"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
"""Get the base URL from config."""
|
||||
if not self.config.url:
|
||||
raise ValueError("No base URL configured")
|
||||
return self.config.url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the remote vLLM server.
|
||||
This method is used by the Provider API to verify
|
||||
that the service is running correctly.
|
||||
Uses the unauthenticated /health endpoint.
|
||||
Returns:
|
||||
|
||||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
base_url = self.get_base_url()
|
||||
health_url = urljoin(base_url, "health")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(health_url)
|
||||
response.raise_for_status()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
def get_extra_client_params(self):
|
||||
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check model availability only when running without an API token.
When a token is configured, the check is skipped because listing models may trigger an OAuth workflow.
|
||||
"""
|
||||
if not self.config.auth_credential:
|
||||
model_ids = []
|
||||
async for m in self.client.models.list():
|
||||
if m.id == model: # Found exact match
|
||||
return True
|
||||
model_ids.append(m.id)
|
||||
raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
|
||||
log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
|
||||
return True
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
params = params.model_copy()
|
||||
|
||||
# Apply vLLM-specific defaults
|
||||
if params.max_tokens is None and self.config.max_tokens:
|
||||
params.max_tokens = self.config.max_tokens
|
||||
|
||||
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
|
||||
# References:
|
||||
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
# * https://github.com/vllm-project/vllm/pull/10000
|
||||
if not params.tools and params.tool_choice is not None:
|
||||
params.tool_choice = ToolChoice.none.value
|
||||
|
||||
return await super().openai_chat_completion(params)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import WatsonXConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: WatsonXConfig, _deps):
|
||||
# import dynamically so the dependency is only loaded when it is needed
|
||||
from .watsonx import WatsonXInferenceAdapter
|
||||
|
||||
adapter = WatsonXInferenceAdapter(config)
|
||||
return adapter
|
||||
45
src/llama_stack/providers/remote/inference/watsonx/config.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class WatsonXProviderDataValidator(BaseModel):
|
||||
watsonx_project_id: str | None = Field(
|
||||
default=None,
|
||||
description="IBM WatsonX project ID",
|
||||
)
|
||||
watsonx_api_key: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WatsonXConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||
description="A base url for accessing the watsonx.ai",
|
||||
)
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="The watsonx.ai project ID",
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60,
|
||||
description="Timeout for the HTTP requests",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
|
||||
"api_key": "${env.WATSONX_API_KEY:=}",
|
||||
"project_id": "${env.WATSONX_PROJECT_ID:=}",
|
||||
}
|
||||
340
src/llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionUsage,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIEmbeddingsResponse,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.core.telemetry.tracing import get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::remote::watsonx")
|
||||
|
||||
|
||||
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
provider_data_api_key_field: str = "watsonx_api_key"
|
||||
|
||||
def __init__(self, config: WatsonXConfig):
|
||||
self.available_models = None
|
||||
self.config = config
|
||||
api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
litellm_provider_name="watsonx",
|
||||
api_key_from_config=api_key,
|
||||
provider_data_api_key_field="watsonx_api_key",
|
||||
openai_compat_api_base=self.get_base_url(),
|
||||
)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""
|
||||
Override parent method to add timeout and inject usage object when missing.
|
||||
This works around a LiteLLM defect where usage block is sometimes dropped.
|
||||
"""
|
||||
|
||||
# Add usage tracking for streaming when telemetry is active
|
||||
stream_options = params.stream_options
|
||||
if params.stream and get_current_span() is not None:
|
||||
if stream_options is None:
|
||||
stream_options = {"include_usage": True}
|
||||
elif "include_usage" not in stream_options:
|
||||
stream_options = {**stream_options, "include_usage": True}
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
messages=params.messages,
|
||||
frequency_penalty=params.frequency_penalty,
|
||||
function_call=params.function_call,
|
||||
functions=params.functions,
|
||||
logit_bias=params.logit_bias,
|
||||
logprobs=params.logprobs,
|
||||
max_completion_tokens=params.max_completion_tokens,
|
||||
max_tokens=params.max_tokens,
|
||||
n=params.n,
|
||||
parallel_tool_calls=params.parallel_tool_calls,
|
||||
presence_penalty=params.presence_penalty,
|
||||
response_format=params.response_format,
|
||||
seed=params.seed,
|
||||
stop=params.stop,
|
||||
stream=params.stream,
|
||||
stream_options=stream_options,
|
||||
temperature=params.temperature,
|
||||
tool_choice=params.tool_choice,
|
||||
tools=params.tools,
|
||||
top_logprobs=params.top_logprobs,
|
||||
top_p=params.top_p,
|
||||
user=params.user,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
# These are watsonx-specific parameters
|
||||
timeout=self.config.timeout,
|
||||
project_id=self.config.project_id,
|
||||
)
|
||||
|
||||
result = await litellm.acompletion(**request_params)
|
||||
|
||||
# If not streaming, check and inject usage if missing
|
||||
if not params.stream:
|
||||
# Use getattr to safely handle cases where usage attribute might not exist
|
||||
if getattr(result, "usage", None) is None:
|
||||
# Create usage object with zeros
|
||||
usage_obj = OpenAIChatCompletionUsage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
# Use model_copy to create a new response with the usage injected
|
||||
result = result.model_copy(update={"usage": usage_obj})
|
||||
return result
|
||||
|
||||
# For streaming, wrap the iterator to normalize chunks
|
||||
return self._normalize_stream(result)
|
||||
|
||||
def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
|
||||
"""
|
||||
Normalize a chunk to ensure it has all expected attributes.
|
||||
This works around LiteLLM not always including all expected attributes.
|
||||
"""
|
||||
# Ensure chunk has usage attribute with zeros if missing
|
||||
if not hasattr(chunk, "usage") or chunk.usage is None:
|
||||
usage_obj = OpenAIChatCompletionUsage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
chunk = chunk.model_copy(update={"usage": usage_obj})
|
||||
|
||||
# Ensure all delta objects in choices have expected attributes
|
||||
if hasattr(chunk, "choices") and chunk.choices:
|
||||
normalized_choices = []
|
||||
for choice in chunk.choices:
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
delta = choice.delta
|
||||
# Build update dict for missing attributes
|
||||
delta_updates = {}
|
||||
if not hasattr(delta, "refusal"):
|
||||
delta_updates["refusal"] = None
|
||||
if not hasattr(delta, "reasoning_content"):
|
||||
delta_updates["reasoning_content"] = None
|
||||
|
||||
# If we need to update delta, create a new choice with updated delta
|
||||
if delta_updates:
|
||||
new_delta = delta.model_copy(update=delta_updates)
|
||||
new_choice = choice.model_copy(update={"delta": new_delta})
|
||||
normalized_choices.append(new_choice)
|
||||
else:
|
||||
normalized_choices.append(choice)
|
||||
else:
|
||||
normalized_choices.append(choice)
|
||||
|
||||
# If we modified any choices, create a new chunk with updated choices
|
||||
if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
|
||||
chunk = chunk.model_copy(update={"choices": normalized_choices})
|
||||
|
||||
return chunk
|
||||
|
||||
async def _normalize_stream(
|
||||
self, stream: AsyncIterator[OpenAIChatCompletionChunk]
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""
|
||||
Normalize all chunks in the stream to ensure they have expected attributes.
|
||||
This works around LiteLLM sometimes not including expected attributes.
|
||||
"""
|
||||
try:
|
||||
async for chunk in stream:
|
||||
# Normalize and yield each chunk immediately
|
||||
yield self._normalize_chunk(chunk)
|
||||
except Exception as e:
|
||||
logger.error(f"Error normalizing stream: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
"""
|
||||
Override parent method to add watsonx-specific parameters.
|
||||
"""
|
||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
prompt=params.prompt,
|
||||
best_of=params.best_of,
|
||||
echo=params.echo,
|
||||
frequency_penalty=params.frequency_penalty,
|
||||
logit_bias=params.logit_bias,
|
||||
logprobs=params.logprobs,
|
||||
max_tokens=params.max_tokens,
|
||||
n=params.n,
|
||||
presence_penalty=params.presence_penalty,
|
||||
seed=params.seed,
|
||||
stop=params.stop,
|
||||
stream=params.stream,
|
||||
stream_options=params.stream_options,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
user=params.user,
|
||||
suffix=params.suffix,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
# These are watsonx-specific parameters
|
||||
timeout=self.config.timeout,
|
||||
project_id=self.config.project_id,
|
||||
)
|
||||
return await litellm.atext_completion(**request_params)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
"""
|
||||
Override parent method to add watsonx-specific parameters.
|
||||
"""
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [params.input] if isinstance(params.input, str) else params.input
|
||||
|
||||
# Call litellm embedding function with watsonx-specific parameters
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
input=input_list,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
dimensions=params.dimensions,
|
||||
# These are watsonx-specific parameters
|
||||
timeout=self.config.timeout,
|
||||
project_id=self.config.project_id,
|
||||
)
|
||||
|
||||
# Convert response to OpenAI format
|
||||
from llama_stack.apis.inference import OpenAIEmbeddingUsage
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
|
||||
|
||||
data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response["usage"]["prompt_tokens"],
|
||||
total_tokens=response["usage"]["total_tokens"],
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return self.config.url
|
||||
|
||||
# Copied from OpenAIMixin
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from the provider's /v1/models.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
if not self._model_cache:
|
||||
await self.list_models()
|
||||
return model in self._model_cache
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
models = []
|
||||
for model_spec in self._get_model_specs():
|
||||
functions = [f["id"] for f in model_spec.get("functions", [])]
|
||||
# Format: {"embedding_dimension": 1536, "context_length": 8192}
|
||||
|
||||
# Example of an embedding model:
|
||||
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
|
||||
# 'label': 'granite-embedding-278m-multilingual',
|
||||
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
|
||||
# ...
|
||||
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
|
||||
if "embedding" in functions:
|
||||
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
|
||||
context_length = model_spec["model_limits"]["max_sequence_length"]
|
||||
embedding_metadata = {
|
||||
"embedding_dimension": embedding_dimension,
|
||||
"context_length": context_length,
|
||||
}
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata=embedding_metadata,
|
||||
model_type=ModelType.embedding,
|
||||
)
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
if "text_chat" in functions:
|
||||
model = Model(
|
||||
identifier=model_spec["model_id"],
|
||||
provider_resource_id=provider_resource_id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
# A model could be both an embedding model and a text chat model. In that case the cache records the
# text chat Model object (it is written last), while the returned list contains both the embedding and
# the text chat Model objects. That is fine because the cache is only used for check_model_availability().
|
||||
self._model_cache[provider_resource_id] = model
|
||||
models.append(model)
|
||||
return models
|
||||
|
||||
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
|
||||
# So we need to implement our own method to list models by calling the watsonx.ai API.
|
||||
def _get_model_specs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieves foundation model specifications from the watsonx.ai API.
|
||||
"""
|
||||
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
|
||||
headers = {
|
||||
# Note that there is no authorization header. Listing models does not require authentication.
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
# --- Process the Response ---
|
||||
# Raise an exception for bad status codes (4xx or 5xx)
|
||||
response.raise_for_status()
|
||||
|
||||
# If the request is successful, parse and return the JSON response.
|
||||
# The response should contain a list of model specifications
|
||||
response_data = response.json()
|
||||
if "resources" not in response_data:
|
||||
raise ValueError("Resources not found in response")
|
||||
return response_data["resources"]
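list_models() consumes the raw specs returned by _get_model_specs(); the inline comment in list_models() shows the shape for an embedding model. A hedged sketch of the two spec shapes the loop distinguishes; the chat model id is hypothetical.

# Sketch only: minimal model-spec dicts in the shape list_models() expects.
embedding_spec = {
    "model_id": "ibm/granite-embedding-278m-multilingual",  # from the inline example above
    "functions": [{"id": "embedding"}],
    "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
}
chat_spec = {
    "model_id": "ibm/granite-3-8b-instruct",  # hypothetical chat model id
    "functions": [{"id": "text_chat"}],
    "model_limits": {"max_sequence_length": 131072},  # not consulted for text_chat models
}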
|
||||