Merge branch 'main' into use-secret-str

Matthew Farrellee 2025-10-08 08:44:28 -04:00
commit 39854f4562
14 changed files with 196 additions and 473 deletions

View file

@@ -18,7 +18,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
| `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | The base URL for accessing watsonx.ai |
| `project_id` | `str \| None` | No | | The Project ID key |
| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
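The same fields can be constructed programmatically as a pydantic config object. A minimal sketch, assuming the defaults documented in the table above; the key and project ID values are hypothetical:

```python
# A minimal sketch using the field defaults listed above; the api_key and
# project_id values are hypothetical placeholders.
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig

config = WatsonXConfig(
    api_key="my-watsonx-api-key",  # plain str is coerced to pydantic SecretStr
    url="https://us-south.ml.cloud.ibm.com",  # the documented default
    project_id="my-project-id",  # hypothetical
    timeout=60,
)
# SecretStr keeps the credential out of reprs; unwrap it explicitly when needed.
print(config.api_key.get_secret_value() if config.api_key else None)
```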
## Sample Configuration

View file

@@ -611,7 +611,7 @@ class InferenceRouter(Inference):
completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk
if self.telemetry and chunk.usage:
if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
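The added `hasattr` guard matters because not every chunk type is guaranteed to define a `usage` attribute; `chunk.usage` alone would raise `AttributeError` rather than evaluate falsy. A toy illustration (the class here is hypothetical, not the real chunk type):

```python
# Toy illustration of the guard above: accessing chunk.usage on an object
# without that attribute raises AttributeError, while the hasattr check
# short-circuits safely.
class ChunkWithoutUsage:
    pass

chunk = ChunkWithoutUsage()
telemetry = True
if telemetry and hasattr(chunk, "usage") and chunk.usage:
    print("record metrics")
else:
    print("skip metrics")  # reached without raising AttributeError
```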

View file

@@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .watsonx import get_distribution_template # noqa: F401

View file

@@ -3,44 +3,33 @@ distribution_spec:
description: Use watsonx for running LLM inference
providers:
inference:
- provider_id: watsonx
provider_type: remote::watsonx
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
- provider_type: remote::watsonx
- provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
- provider_type: inline::faiss
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
- provider_type: inline::llama-guard
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_type: inline::meta-reference
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
- provider_type: remote::huggingface
- provider_type: inline::localfs
scoring:
- provider_id: basic
provider_type: inline::basic
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
- provider_type: inline::basic
- provider_type: inline::llm-as-judge
- provider_type: inline::braintrust
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]
- aiosqlite
- aiosqlite

View file

@@ -4,13 +4,13 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
- files
providers:
inference:
- provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:
- provider_id: faiss
provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval:
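Values such as `${env.WATSONX_API_KEY:=}` above are environment references with inline defaults. A hedged sketch of how this `${env.NAME:=default}` syntax behaves; an illustration of the observable behavior only, not llama-stack's actual resolver:

```python
# A hedged sketch of the ${env.NAME:=default} substitution used in run.yaml;
# this mirrors the observable behavior, not llama-stack's actual code.
import os
import re

_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}")

def expand_env(value: str) -> str:
    # Replace each reference with the variable's value, or the default if unset.
    return _PATTERN.sub(lambda m: os.getenv(m.group(1), m.group(2)), value)

print(expand_env("${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}"))
print(expand_env("${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db"))
```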
@@ -109,102 +107,7 @@ metadata_store:
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
models:
- metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-2-13b-chat
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-2-13b
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-8b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-11b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-1b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-3b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-90b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-guard-3-11b-vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
models: []
shields: []
vector_dbs: []
datasets: []
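With the static `models:` list removed, the watsonx distribution relies on runtime discovery instead. A hedged sketch of inspecting what a running stack serves, assuming the standard llama-stack-client API and a hypothetical local endpoint:

```python
# A hedged sketch; assumes the llama-stack-client package and a stack running
# locally. The base_url is hypothetical, and attribute names are assumed from
# the llama-stack models API.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
for model in client.models.list():
    print(model.identifier, model.model_type)
```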

View file

@@ -4,17 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
config=WatsonXConfig.sample_run_config(),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
available_models = {
"watsonx": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
),
]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name=name,
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"inference": [inference_provider],
"files": [files_provider],
},
default_models=default_models + [embedding_model],
default_models=[],
default_tool_groups=default_tool_groups,
),
},

View file

@@ -268,7 +268,7 @@ Available Models:
api=Api.inference,
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",

View file

@@ -4,19 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
async def get_adapter_impl(config: WatsonXConfig, _deps):
    # import dynamically so the provider's dependencies are only imported when actually needed
from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]
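A hedged usage sketch of the factory above; `WatsonXConfig()` relies on its field defaults, and the second argument is an empty dict since this adapter ignores `_deps`:

```python
# A hedged usage sketch of get_adapter_impl; both names are exported via
# __all__ above. The adapter ignores _deps, so an empty dict suffices here.
import asyncio

from llama_stack.providers.remote.inference.watsonx import WatsonXConfig, get_adapter_impl

adapter = asyncio.run(get_adapter_impl(WatsonXConfig(), {}))
print(type(adapter).__name__)  # WatsonXInferenceAdapter
```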

View file

@@ -7,16 +7,18 @@
import os
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
model_config = ConfigDict(
from_attributes=True,
extra="forbid",
)
watsonx_api_key: str | None
@json_schema_type
@@ -26,8 +28,8 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
description="A base url for accessing the watsonx.ai",
)
project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key",
default=None,
description="The watsonx.ai project ID",
)
timeout: int = Field(
default=60,
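Given the `WatsonXProviderDataValidator` above, a request can override the config-level key with per-request provider data, as the tests at the end of this diff exercise. A minimal sketch of the header a client would send; the key value is hypothetical:

```python
# A minimal sketch of the provider-data header exercised by the tests at the
# end of this diff; the key value is a hypothetical placeholder.
import json

headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"watsonx_api_key": "my-request-key"}),
}
```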

View file

@@ -1,47 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-2-13b-chat",
CoreModelId.llama2_13b.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]

View file

@@ -4,240 +4,120 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ibm_watsonx_ai.foundation_models import Model
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
import requests
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
GreedySamplingStrategy,
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
from . import WatsonXConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::watsonx")
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
# Note on structured output
# WatsonX returns responses with JSON embedded in a string.
# Examples:
class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
_model_cache: dict[str, Model] = {}
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
# "first_name": "Michael",\n "last_name": "Jordan",\n'...)
# Not even valid JSON, but we can still extract the JSON from the content
def __init__(self, config: WatsonXConfig):
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="watsonx",
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
provider_data_api_key_field="watsonx_api_key",
)
self.available_models = None
self.config = config
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
# "year_born": "1963", "year_retired": "2003"\\}}$')
# Find the start of the boxed content
def get_base_url(self) -> str:
return self.config.url
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._openai_client: AsyncOpenAI | None = None
self._project_id = self._config.project_id
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self._config.url}/openai/v1",
api_key=self._config.api_key,
)
return self._openai_client
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,
}
# Add watsonx.ai specific parameters
params["project_id"] = self.config.project_id
params["time_limit"] = self.config.timeout
return params
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
# Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from the provider's model listing.
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
if not self._model_cache:
await self.list_models()
return model in self._model_cache
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
models = []
for model_spec in self._get_model_specs():
functions = [f["id"] for f in model_spec.get("functions", [])]
# Format: {"embedding_dimension": 1536, "context_length": 8192}
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# watsonx.ai sometimes adds usage data to the stream
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
# Example of an embedding model:
# {'model_id': 'ibm/granite-embedding-278m-multilingual',
# 'label': 'granite-embedding-278m-multilingual',
# 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
# ...
provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
if "embedding" in functions:
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
context_length = model_spec["model_limits"]["max_sequence_length"]
embedding_metadata = {
"embedding_dimension": embedding_dimension,
"context_length": context_length,
}
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata=embedding_metadata,
model_type=ModelType.embedding,
)
self._model_cache[provider_resource_id] = model
models.append(model)
if "text_chat" in functions:
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
# In theory, a model could be both an embedding model and a text chat model.
# In that case, the cache will record the generator Model object, and the list which we return will have
# both the generator Model object and the text chat Model object. That's fine because the cache is
# only used for check_model_availability() anyway.
self._model_cache[provider_resource_id] = model
models.append(model)
return models
seen_finish_reason = False
async for chunk in stream:
# The final usage-only chunk has no choices; discard it if the user didn't request usage
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break
# LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
# So we need to implement our own method to list models by calling the watsonx.ai API.
def _get_model_specs(self) -> list[dict[str, Any]]:
"""
Retrieves foundation model specifications from the watsonx.ai API.
"""
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = {
# Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json",
}
response = requests.get(url, headers=headers)
# --- Process the Response ---
# Raise an exception for bad status codes (4xx or 5xx)
response.raise_for_status()
# If the request is successful, parse and return the JSON response.
# The response should contain a list of model specifications
response_data = response.json()
if "resources" not in response_data:
raise ValueError("Resources not found in response")
return response_data["resources"]

View file

@@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import struct
from collections.abc import AsyncIterator
from typing import Any
@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import (
b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
get_sampling_options,
@@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
return False
return model in litellm.models_by_provider[self.litellm_provider_name]
def b64_encode_openai_embeddings_response(
response_data: list[dict], encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data["embedding"]:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data["embedding"]
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data
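Since `struct.pack("f", ...)` emits 32-bit floats, a consumer must decode and unpack with the matching format. A minimal round-trip check of the encoding above:

```python
# A minimal round-trip sketch for the base64 encoding above: pack three floats
# as 32-bit values, then decode and unpack them again.
import base64
import struct

values = [0.1, 0.2, 0.3]
byte_array = bytearray()
for v in values:
    byte_array.extend(struct.pack("f", float(v)))
encoded = base64.b64encode(byte_array).decode("utf-8")

decoded = struct.unpack(f"{len(values)}f", base64.b64decode(encoded))
print(encoded, decoded)  # values match up to float32 precision
```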

View file

@@ -3,9 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import json
import struct
import time
import uuid
import warnings
@@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
OpenAIChatCompletion,
OpenAIEmbeddingData,
OpenAIMessageParam,
OpenAIResponseFormatParam,
SamplingParams,
@@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
params["user"] = user
return params
def b64_encode_openai_embeddings_response(
response_data: dict, encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data.embedding:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data.embedding
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data

View file

@@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.mark.parametrize(
@@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.client.api_key == api_key
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
WatsonXConfig,
WatsonXInferenceAdapter,
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
),
],
)
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.get_api_key() == api_key