mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 04:22:35 +00:00

Merge branch 'main' into use-secret-str

commit 39854f4562
14 changed files with 196 additions and 473 deletions
@@ -18,7 +18,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
 | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `project_id` | `str \| None` | No | | The Project ID key |
+| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

 ## Sample Configuration
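The body of the sample configuration did not survive this capture. As a rough sketch only, based on the environment variables referenced in the run.yaml later in this diff (`WATSONX_BASE_URL`, `WATSONX_API_KEY`, `WATSONX_PROJECT_ID`) and the config fields in the table above, an equivalent direct construction of the config class might look like this (assumed usage, not the generated sample block itself):

```python
import os

from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig

# Illustrative only: mirrors the ${env.WATSONX_*} defaults used in the
# distribution's run.yaml shown further down in this diff.
config = WatsonXConfig(
    url=os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
    api_key=os.getenv("WATSONX_API_KEY"),
    project_id=os.getenv("WATSONX_PROJECT_ID"),
)
print(config.url, config.timeout)
```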
@@ -611,7 +611,7 @@ class InferenceRouter(Inference):
                     completion_text += "".join(choice_data["content_parts"])

                 # Add metrics to the chunk
-                if self.telemetry and chunk.usage:
+                if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
                     metrics = self._construct_metrics(
                         prompt_tokens=chunk.usage.prompt_tokens,
                         completion_tokens=chunk.usage.completion_tokens,
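The only functional change in this hunk is the extra `hasattr` guard. A standalone toy sketch (made-up classes, not the real chunk types) of why it matters: a chunk object that never defines a `usage` attribute would otherwise raise `AttributeError` before the truthiness check runs.

```python
# Toy stand-ins for streaming chunk objects; illustration only.
class ChunkWithoutUsage:
    pass


class ChunkWithUsage:
    def __init__(self, usage):
        self.usage = usage


for chunk in (ChunkWithoutUsage(), ChunkWithUsage(None), ChunkWithUsage({"prompt_tokens": 3})):
    # Mirrors the guarded condition introduced above.
    if hasattr(chunk, "usage") and chunk.usage:
        print("metrics recorded:", chunk.usage)
    else:
        print("no usage on this chunk, skipped")
```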
@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .watsonx import get_distribution_template  # noqa: F401
@@ -3,44 +3,33 @@ distribution_spec:
   description: Use watsonx for running LLM inference
   providers:
     inference:
-    - provider_id: watsonx
-      provider_type: remote::watsonx
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::watsonx
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: faiss
-      provider_type: inline::faiss
+    - provider_type: inline::faiss
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
     - provider_type: remote::brave-search
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+    files:
+    - provider_type: inline::localfs
 image_type: venv
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]
-- aiosqlite
-- aiosqlite
@@ -4,13 +4,13 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
 - telemetry
 - tool_runtime
 - vector_io
-- files
 providers:
   inference:
   - provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
       url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
       api_key: ${env.WATSONX_API_KEY:=}
       project_id: ${env.WATSONX_PROJECT_ID:=}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
-      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sinks: ${env.TELEMETRY_SINKS:=sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   eval:
@@ -109,102 +107,7 @@ metadata_store:
 inference_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
-models:
-- metadata: {}
-  model_id: meta-llama/llama-3-3-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.3-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-2-13b-chat
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-2-13b
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-8b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-11b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-1b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-3b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-90b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-guard-3-11b-vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
-  model_type: embedding
+models: []
 shields: []
 vector_dbs: []
 datasets: []
@@ -4,17 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pathlib import Path
-
-from llama_stack.apis.models import ModelType
-from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
-from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
 from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
-from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES


 def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         config=WatsonXConfig.sample_run_config(),
     )

-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-
-    available_models = {
-        "watsonx": MODEL_ENTRIES,
-    }
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         ),
     ]

-    embedding_model = ModelInput(
-        model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": 384,
-        },
-    )
-
     files_provider = Provider(
         provider_id="meta-reference-files",
         provider_type="inline::localfs",
         config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
-        template_path=Path(__file__).parent / "doc_template.md",
+        template_path=None,
         providers=providers,
-        available_models_by_provider=available_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
+                    "inference": [inference_provider],
                     "files": [files_provider],
                 },
-                default_models=default_models + [embedding_model],
+                default_models=[],
                 default_tool_groups=default_tool_groups,
             ),
         },
@@ -268,7 +268,7 @@ Available Models:
         api=Api.inference,
         adapter_type="watsonx",
         provider_type="remote::watsonx",
-        pip_packages=["ibm_watsonx_ai"],
+        pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
         provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
@@ -4,19 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.inference import Inference
-
 from .config import WatsonXConfig


-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
-    # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+    # import dynamically so the import is used only when it is needed
     from .watsonx import WatsonXInferenceAdapter

-    if not isinstance(config, WatsonXConfig):
-        raise RuntimeError(f"Unexpected config type: {type(config)}")
     adapter = WatsonXInferenceAdapter(config)
     return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
@@ -7,16 +7,18 @@
 import os
 from typing import Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 class WatsonXProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
-    project_id: str
+    model_config = ConfigDict(
+        from_attributes=True,
+        extra="forbid",
+    )
+    watsonx_api_key: str | None


 @json_schema_type
@@ -26,8 +28,8 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         description="A base url for accessing the watsonx.ai",
     )
     project_id: str | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key",
+        default=None,
+        description="The watsonx.ai project ID",
     )
     timeout: int = Field(
         default=60,
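With the validator now declaring a single optional `watsonx_api_key` field, a caller can supply the credential per request through the provider-data header instead of baking it into the static config; the header and field names below mirror the ones exercised by the new test at the bottom of this diff. A minimal illustration with a placeholder key value:

```python
import json

# Placeholder credential; "x-llamastack-provider-data" and "watsonx_api_key"
# come from the provider-data plumbing shown in the tests below.
provider_data = {"watsonx_api_key": "example-ibm-cloud-api-key"}
headers = {"x-llamastack-provider-data": json.dumps(provider_data)}
print(headers)
```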
@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
-
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-3-70b-instruct",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-2-13b-chat",
-        CoreModelId.llama2_13b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-3b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
-]
@@ -4,240 +4,120 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
-
-from ibm_watsonx_ai.foundation_models import Model
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    GreedySamplingStrategy,
-    Inference,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-    request_has_media,
-)
-
-from . import WatsonXConfig
-from .models import MODEL_ENTRIES
-
-logger = get_logger(name=__name__, category="inference::watsonx")
-
-
-# Note on structured output
-# WatsonX returns responses with a json embedded into a string.
-# Examples:
-
-# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
-#   "first_name": "Michael",\n  "last_name": "Jordan",\n'...)
-# Not even a valid JSON, but we can still extract the JSON from the content
-
-# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
-# "year_born": "1963", "year_retired": "2003"\\}}$')
-# Find the start of the boxed content
-
-
-class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
-    def __init__(self, config: WatsonXConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-
-        logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
-        self._config = config
-        self._openai_client: AsyncOpenAI | None = None
-
-        self._project_id = self._config.project_id
-
-    def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
-        config_url = self._config.url
-        project_id = self._config.project_id
-        credentials = {"url": config_url, "apikey": config_api_key}
-
-        return Model(model_id=model_id, credentials=credentials, project_id=project_id)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self._config.url}/openai/v1",
-                api_key=self._config.api_key,
-            )
-        return self._openai_client
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        input_dict = {"params": {}}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
-        else:
-            assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request)
-        if request.sampling_params:
-            if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
-            if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
-            if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-
-            if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
-            if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
-            if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
-                input_dict["params"][GenParams.TEMPERATURE] = 0.0
-
-            input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
-
-        params = {
-            **input_dict,
-        }
-        return params
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().completions.create(**params)  # type: ignore
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        if params.get("stream", False):
-            return self._stream_openai_chat_completion(params)
-        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
-
-    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
-        # watsonx.ai sometimes adds usage data to the stream
-        include_usage = False
-        if params.get("stream_options", None):
-            include_usage = params["stream_options"].get("include_usage", False)
-        stream = await self._get_openai_client().chat.completions.create(**params)
-
-        seen_finish_reason = False
-        async for chunk in stream:
-            # Final usage chunk with no choices that the user didn't request, so discard
-            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
-                break
-            yield chunk
-            for choice in chunk.choices:
-                if choice.finish_reason:
-                    seen_finish_reason = True
-                    break
+from typing import Any
+
+import requests
+
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+
+
+class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _model_cache: dict[str, Model] = {}
+
+    def __init__(self, config: WatsonXConfig):
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            litellm_provider_name="watsonx",
+            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+            provider_data_api_key_field="watsonx_api_key",
+        )
+        self.available_models = None
+        self.config = config
+
+    def get_base_url(self) -> str:
+        return self.config.url
+
+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get base parameters from parent
+        params = await super()._get_params(request)
+
+        # Add watsonx.ai specific parameters
+        params["project_id"] = self.config.project_id
+        params["time_limit"] = self.config.timeout
+        return params
+
+    # Copied from OpenAIMixin
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider's /v1/models.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if not self._model_cache:
+            await self.list_models()
+        return model in self._model_cache
+
+    async def list_models(self) -> list[Model] | None:
+        self._model_cache = {}
+        models = []
+        for model_spec in self._get_model_specs():
+            functions = [f["id"] for f in model_spec.get("functions", [])]
+            # Format: {"embedding_dimension": 1536, "context_length": 8192}
+
+            # Example of an embedding model:
+            # {'model_id': 'ibm/granite-embedding-278m-multilingual',
+            #  'label': 'granite-embedding-278m-multilingual',
+            #  'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
+            #  ...
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
+            if "embedding" in functions:
+                embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
+                context_length = model_spec["model_limits"]["max_sequence_length"]
+                embedding_metadata = {
+                    "embedding_dimension": embedding_dimension,
+                    "context_length": context_length,
+                }
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
+                )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+            if "text_chat" in functions:
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+                # In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
+                # In that case, the cache will record the generator Model object, and the list which we return will have
+                # both the generator Model object and the text chat Model object. That's fine because the cache is
+                # only used for check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+        return models
+
+    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+    # So we need to implement our own method to list models by calling the watsonx.ai API.
+    def _get_model_specs(self) -> list[dict[str, Any]]:
+        """
+        Retrieves foundation model specifications from the watsonx.ai API.
+        """
+        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        headers = {
+            # Note that there is no authorization header. Listing models does not require authentication.
+            "Content-Type": "application/json",
+        }
+
+        response = requests.get(url, headers=headers)
+
+        # --- Process the Response ---
+        # Raise an exception for bad status codes (4xx or 5xx)
+        response.raise_for_status()
+
+        # If the request is successful, parse and return the JSON response.
+        # The response should contain a list of model specifications
+        response_data = response.json()
+        if "resources" not in response_data:
+            raise ValueError("Resources not found in response")
+        return response_data["resources"]
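The new `list_models()` classifies each spec returned by `/ml/v1/foundation_model_specs` by its `functions` list. A self-contained sketch of that classification on a made-up spec dict (shaped like the inline comment above, not a real API response):

```python
from typing import Any

# Trimmed, hypothetical spec modeled on the example comment in the adapter above.
sample_spec: dict[str, Any] = {
    "model_id": "ibm/granite-embedding-278m-multilingual",
    "functions": [{"id": "embedding"}],
    "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
}

functions = [f["id"] for f in sample_spec.get("functions", [])]
if "embedding" in functions:
    metadata = {
        "embedding_dimension": sample_spec["model_limits"]["embedding_dimension"],
        "context_length": sample_spec["model_limits"]["max_sequence_length"],
    }
    print("register as embedding model:", sample_spec["model_id"], metadata)
if "text_chat" in functions:
    print("register as LLM:", sample_spec["model_id"])
```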
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import base64
+import struct
 from collections.abc import AsyncIterator
 from typing import Any

@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
-    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
@@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
             return False

         return model in litellm.models_by_provider[self.litellm_provider_name]
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: list[dict], encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data["embedding"]:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data["embedding"]
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
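The relocated helper packs each float as a native-endian 32-bit value before base64-encoding it, so a consumer can recover the vector with the matching `struct.unpack`. A small round-trip sketch with toy values, independent of the mixin above:

```python
import base64
import struct

embedding = [0.1, -0.25, 3.0]  # toy vector

# Encode the same way the helper does for encoding_format == "base64".
byte_array = bytearray()
for value in embedding:
    byte_array.extend(struct.pack("f", float(value)))
encoded = base64.b64encode(byte_array).decode("utf-8")

# Decode on the consumer side.
decoded = struct.unpack(f"{len(embedding)}f", base64.b64decode(encoded))
print(encoded, [round(v, 4) for v in decoded])
```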
@@ -3,9 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
 import json
-import struct
 import time
 import uuid
 import warnings
@@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     SamplingParams,
@@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
         params["user"] = user

     return params
-
-
-def b64_encode_openai_embeddings_response(
-    response_data: dict, encoding_format: str | None = "float"
-) -> list[OpenAIEmbeddingData]:
-    """
-    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
-    """
-    data = []
-    for i, embedding_data in enumerate(response_data):
-        if encoding_format == "base64":
-            byte_array = bytearray()
-            for embedding_value in embedding_data.embedding:
-                byte_array.extend(struct.pack("f", float(embedding_value)))
-
-            response_embedding = base64.b64encode(byte_array).decode("utf-8")
-        else:
-            response_embedding = embedding_data.embedding
-        data.append(
-            OpenAIEmbeddingData(
-                embedding=response_embedding,
-                index=i,
-            )
-        )
-    return data
@@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter


 @pytest.mark.parametrize(
@@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
             {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
         ):
             assert inference_adapter.client.api_key == api_key
+
+
+@pytest.mark.parametrize(
+    "config_cls,adapter_cls,provider_data_validator",
+    [
+        (
+            WatsonXConfig,
+            WatsonXInferenceAdapter,
+            "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
+        ),
+    ],
+)
+def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
+    """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
+    assumption that there is an OpenAI-compatible client object."""
+
+    inference_adapter = adapter_cls(config=config_cls())
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            assert inference_adapter.get_api_key() == api_key