diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx
index 33bc5bbc3..f081703ab 100644
--- a/docs/docs/providers/inference/remote_watsonx.mdx
+++ b/docs/docs/providers/inference/remote_watsonx.mdx
@@ -17,8 +17,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
-| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
-| `project_id` | `str \| None` | No | | The Project ID key |
+| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
+| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `` | No | 60 | Timeout for the HTTP requests |
 
 ## Sample Configuration
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index c4338e614..847f6a2d2 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -611,7 +611,7 @@ class InferenceRouter(Inference):
                 completion_text += "".join(choice_data["content_parts"])
 
         # Add metrics to the chunk
-        if self.telemetry and chunk.usage:
+        if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
             metrics = self._construct_metrics(
                 prompt_tokens=chunk.usage.prompt_tokens,
                 completion_tokens=chunk.usage.completion_tokens,
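Review note: not every streamed chunk type exposes a `usage` attribute, which is why the guard above checks `hasattr` before truthiness. A minimal, self-contained sketch of the pattern (the chunk objects below are stand-ins, not the actual llama-stack types):

```python
from types import SimpleNamespace

# Stand-in chunks: one carries usage data, one omits the attribute entirely.
chunk_with_usage = SimpleNamespace(
    usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
)
chunk_without_usage = SimpleNamespace()

for chunk in (chunk_with_usage, chunk_without_usage):
    # hasattr() guards against chunk classes that lack the field;
    # the truthiness check guards against usage=None on chunks that have it.
    if hasattr(chunk, "usage") and chunk.usage:
        print(f"total tokens: {chunk.usage.total_tokens}")
    else:
        print("no usage data on this chunk")
```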
diff --git a/llama_stack/distributions/watsonx/__init__.py b/llama_stack/distributions/watsonx/__init__.py
index 756f351d8..078d86144 100644
--- a/llama_stack/distributions/watsonx/__init__.py
+++ b/llama_stack/distributions/watsonx/__init__.py
@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .watsonx import get_distribution_template  # noqa: F401
diff --git a/llama_stack/distributions/watsonx/build.yaml b/llama_stack/distributions/watsonx/build.yaml
index bf4be7eaf..06349a741 100644
--- a/llama_stack/distributions/watsonx/build.yaml
+++ b/llama_stack/distributions/watsonx/build.yaml
@@ -3,44 +3,33 @@ distribution_spec:
   description: Use watsonx for running LLM inference
   providers:
     inference:
-    - provider_id: watsonx
-      provider_type: remote::watsonx
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::watsonx
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: faiss
-      provider_type: inline::faiss
+    - provider_type: inline::faiss
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
    agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
     - provider_type: remote::brave-search
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+    files:
+    - provider_type: inline::localfs
 image_type: venv
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]
-- aiosqlite
-- aiosqlite
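Review note: re-exporting `get_distribution_template` from the package root lets the distribution tooling import it uniformly. A quick usage sketch (assuming `llama_stack` is importable; the printed values come from the function body later in this diff):

```python
from llama_stack.distributions.watsonx import get_distribution_template

template = get_distribution_template()
print(template.name)         # "watsonx"
print(template.description)  # "Use watsonx for running LLM inference"
```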
diff --git a/llama_stack/distributions/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml
index 92f367910..e0c337f9d 100644
--- a/llama_stack/distributions/watsonx/run.yaml
+++ b/llama_stack/distributions/watsonx/run.yaml
@@ -4,13 +4,13 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
 - telemetry
 - tool_runtime
 - vector_io
-- files
 providers:
   inference:
   - provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
       url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
       api_key: ${env.WATSONX_API_KEY:=}
       project_id: ${env.WATSONX_PROJECT_ID:=}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
-      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sinks: ${env.TELEMETRY_SINKS:=sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   eval:
@@ -109,102 +107,7 @@ metadata_store:
 inference_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
-models:
-- metadata: {}
-  model_id: meta-llama/llama-3-3-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.3-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-2-13b-chat
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-2-13b
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-8b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-11b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-1b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-3b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-90b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-guard-3-11b-vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
-  model_type: embedding
+models: []
 shields: []
 vector_dbs: []
 datasets: []
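Review note: with the hard-coded model list dropped (`models: []`), models are discovered at runtime from watsonx.ai. A hedged client-side sketch — the `llama-stack-client` package and the local endpoint URL are assumptions, not part of this diff:

```python
from llama_stack_client import LlamaStackClient

# Assumes a stack built from this distribution is running locally.
client = LlamaStackClient(base_url="http://localhost:8321")

# Models are now populated dynamically from the provider rather than
# pre-registered in run.yaml, so ask the running stack what it knows about.
for model in client.models.list():
    print(model.identifier, model.model_type)
```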
diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py
index c3cab5d1b..645770612 100644
--- a/llama_stack/distributions/watsonx/watsonx.py
+++ b/llama_stack/distributions/watsonx/watsonx.py
@@ -4,17 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pathlib import Path
-
-from llama_stack.apis.models import ModelType
-from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
-from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
 from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
-from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
 
 
 def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         config=WatsonXConfig.sample_run_config(),
     )
 
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-
-    available_models = {
-        "watsonx": MODEL_ENTRIES,
-    }
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         ),
     ]
 
-    embedding_model = ModelInput(
-        model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": 384,
-        },
-    )
-
     files_provider = Provider(
         provider_id="meta-reference-files",
         provider_type="inline::localfs",
         config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
-        template_path=Path(__file__).parent / "doc_template.md",
+        template_path=None,
         providers=providers,
-        available_models_by_provider=available_models,
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
+                    "inference": [inference_provider],
                     "files": [files_provider],
                 },
-                default_models=default_models + [embedding_model],
+                default_models=[],
                 default_tool_groups=default_tool_groups,
             ),
         },
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index bf6a09b6c..f89565892 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -268,7 +268,7 @@ Available Models:
         api=Api.inference,
         adapter_type="watsonx",
         provider_type="remote::watsonx",
-        pip_packages=["ibm_watsonx_ai"],
+        pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
         provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
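Review note: the pip dependency moves from `ibm_watsonx_ai` to `litellm` because the adapter is now LiteLLM-backed. A sketch of how LiteLLM reaches watsonx.ai directly — the env var names and the `watsonx/` model prefix follow my reading of LiteLLM's watsonx support and should be treated as assumptions:

```python
import os

import litellm

# LiteLLM's watsonx provider reads credentials from the environment
# (names per LiteLLM's docs; verify against the version you pin).
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_APIKEY"] = "..."      # placeholder
os.environ["WATSONX_PROJECT_ID"] = "..."  # placeholder

response = litellm.completion(
    model="watsonx/meta-llama/llama-3-3-70b-instruct",  # "watsonx/" routes to the watsonx backend
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```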
diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py
index e59e873b6..35e74a720 100644
--- a/llama_stack/providers/remote/inference/watsonx/__init__.py
+++ b/llama_stack/providers/remote/inference/watsonx/__init__.py
@@ -4,19 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.inference import Inference
-
 from .config import WatsonXConfig
 
 
-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
-    # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+    # import dynamically so the dependency is only loaded when it is needed
     from .watsonx import WatsonXInferenceAdapter
 
-    if not isinstance(config, WatsonXConfig):
-        raise RuntimeError(f"Unexpected config type: {type(config)}")
     adapter = WatsonXInferenceAdapter(config)
     return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 4bc0173c4..9e98d4003 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -7,16 +7,18 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, ConfigDict, Field, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 class WatsonXProviderDataValidator(BaseModel):
-    url: str
-    api_key: str
-    project_id: str
+    model_config = ConfigDict(
+        from_attributes=True,
+        extra="forbid",
+    )
+    watsonx_api_key: str | None = None
 
 
 @json_schema_type
@@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
+    # This seems like it should be required, but none of the other remote inference
+    # providers require it, so it is optional here too for consistency.
+    # The OpenAIConfig uses default=None instead, so this follows that precedent.
     api_key: SecretStr | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
-        description="The watsonx API key",
+        default=None,
+        description="The watsonx.ai API key",
     )
+    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key",
+        default=None,
+        description="The watsonx.ai project ID",
     )
     timeout: int = Field(
         default=60,
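Review note: since the config fields no longer read env vars via `default_factory`, resolution now happens through the `${env.WATSONX_API_KEY:=}`-style substitution in run.yaml. A small sketch of constructing the config directly (values are placeholders):

```python
from pydantic import SecretStr

from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig

# api_key and project_id now default to None; in a running stack they are
# filled from WATSONX_API_KEY / WATSONX_PROJECT_ID via run.yaml substitution.
config = WatsonXConfig(
    url="https://us-south.ml.cloud.ibm.com",
    api_key=SecretStr("my-api-key"),  # placeholder
    project_id="my-project-id",       # placeholder
)
print(config.timeout)  # 60 by default
```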
diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py
deleted file mode 100644
index d98f0510a..000000000
--- a/llama_stack/providers/remote/inference/watsonx/models.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
-
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-3-70b-instruct",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-2-13b-chat",
-        CoreModelId.llama2_13b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-3b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
-]
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index fc58691e2..d04472936 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,240 +4,120 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
-from ibm_watsonx_ai.foundation_models import Model
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
+import requests
 
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    GreedySamplingStrategy,
-    Inference,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-    request_has_media,
-)
-
-from . import WatsonXConfig
-from .models import MODEL_ENTRIES
-
-logger = get_logger(name=__name__, category="inference::watsonx")
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 
-# Note on structured output
-# WatsonX returns responses with a json embedded into a string.
-# Examples:
+class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _model_cache: dict[str, Model] = {}
 
-# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
-#    "first_name": "Michael",\n    "last_name": "Jordan",\n'...)
-# Not even a valid JSON, but we can still extract the JSON from the content
+    def __init__(self, config: WatsonXConfig):
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            litellm_provider_name="watsonx",
+            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+            provider_data_api_key_field="watsonx_api_key",
+        )
+        self.available_models = None
+        self.config = config
 
-# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
-# "year_born": "1963", "year_retired": "2003"\\}}$')
-# Find the start of the boxed content
+    def get_base_url(self) -> str:
+        return self.config.url
 
+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get base parameters from parent
+        params = await super()._get_params(request)
 
-class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
-    def __init__(self, config: WatsonXConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-
-        logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
-        self._config = config
-        self._openai_client: AsyncOpenAI | None = None
-
-        self._project_id = self._config.project_id
-
-    def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
-        config_url = self._config.url
-        project_id = self._config.project_id
-        credentials = {"url": config_url, "apikey": config_api_key}
-
-        return Model(model_id=model_id, credentials=credentials, project_id=project_id)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self._config.url}/openai/v1",
-                api_key=self._config.api_key,
-            )
-        return self._openai_client
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        input_dict = {"params": {}}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
-        else:
-            assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request)
-        if request.sampling_params:
-            if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
-            if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
-            if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-
-            if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
-            if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
-            if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
-                input_dict["params"][GenParams.TEMPERATURE] = 0.0
-
-        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
-
-        params = {
-            **input_dict,
-        }
+        # Add watsonx.ai specific parameters
+        params["project_id"] = self.config.project_id
+        params["time_limit"] = self.config.timeout
         return params
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    # Copied from OpenAIMixin
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider's model listing.
 
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().completions.create(**params)  # type: ignore
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+ """ + if not self._model_cache: + await self.list_models() + return model in self._model_cache - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - if params.get("stream", False): - return self._stream_openai_chat_completion(params) - return await self._get_openai_client().chat.completions.create(**params) # type: ignore + async def list_models(self) -> list[Model] | None: + self._model_cache = {} + models = [] + for model_spec in self._get_model_specs(): + functions = [f["id"] for f in model_spec.get("functions", [])] + # Format: {"embedding_dimension": 1536, "context_length": 8192} - async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: - # watsonx.ai sometimes adds usage data to the stream - include_usage = False - if params.get("stream_options", None): - include_usage = params["stream_options"].get("include_usage", False) - stream = await self._get_openai_client().chat.completions.create(**params) + # Example of an embedding model: + # {'model_id': 'ibm/granite-embedding-278m-multilingual', + # 'label': 'granite-embedding-278m-multilingual', + # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768}, + # ... 
+            provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
+            if "embedding" in functions:
+                embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
+                context_length = model_spec["model_limits"]["max_sequence_length"]
+                embedding_metadata = {
+                    "embedding_dimension": embedding_dimension,
+                    "context_length": context_length,
+                }
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata=embedding_metadata,
+                    model_type=ModelType.embedding,
+                )
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+            if "text_chat" in functions:
+                model = Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=provider_resource_id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+                # In theory, a model could be both an embedding model and a text chat model.
+                # In that case, the cache ends up holding the text chat Model object (it is
+                # written last), while the returned list has both the embedding and the text
+                # chat Model objects. That's fine because the cache is only used for
+                # check_model_availability() anyway.
+                self._model_cache[provider_resource_id] = model
+                models.append(model)
+        return models
 
-        seen_finish_reason = False
-        async for chunk in stream:
-            # Final usage chunk with no choices that the user didn't request, so discard
-            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
-                break
-            yield chunk
-            for choice in chunk.choices:
-                if choice.finish_reason:
-                    seen_finish_reason = True
-                    break
+    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+    # So we need to implement our own method to list models by calling the watsonx.ai API.
+    def _get_model_specs(self) -> list[dict[str, Any]]:
+        """
+        Retrieves foundation model specifications from the watsonx.ai API.
+        """
+        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        headers = {
+            # Note that there is no authorization header. Listing models does not require authentication.
+            "Content-Type": "application/json",
+        }
+
+        response = requests.get(url, headers=headers)
+
+        # --- Process the Response ---
+        # Raise an exception for bad status codes (4xx or 5xx)
+        response.raise_for_status()
+
+        # If the request is successful, parse and return the JSON response.
+        # The response should contain a list of model specifications.
+        response_data = response.json()
+        if "resources" not in response_data:
+            raise ValueError("Resources not found in response")
+        return response_data["resources"]
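Review note: a sketch of exercising the new dynamic listing path end to end. It assumes a reachable watsonx.ai endpoint; the model id in the availability check is illustrative, and `__provider_id__` is normally set by the stack when it instantiates the provider:

```python
import asyncio

from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter


async def main() -> None:
    adapter = WatsonXInferenceAdapter(WatsonXConfig())
    adapter.__provider_id__ = "watsonx"  # normally assigned by the stack

    models = await adapter.list_models()  # hits /ml/v1/foundation_model_specs
    print(f"discovered {len(models)} models")

    # Availability checks are keyed by the namespaced id built in list_models().
    print(await adapter.check_model_availability("watsonx/meta-llama/llama-3-3-70b-instruct"))


asyncio.run(main())
```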
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 6c8f61c3b..6bef97dd5 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import struct
 from collections.abc import AsyncIterator
 from typing import Any
 
@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
-    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
@@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
             return False
 
         return model in litellm.models_by_provider[self.litellm_provider_name]
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: list[dict], encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data["embedding"]:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data["embedding"]
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
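Review note: in the base64 branch above, each float is packed as a 4-byte IEEE-754 single (native byte order) and the whole buffer is base64-encoded, which is what OpenAI-compatible clients expect to decode. A self-contained round trip:

```python
import base64
import struct

embedding = [0.25, -1.5, 3.0]  # exactly representable in float32, so the round trip is exact

# Encode: pack each float as 4 bytes, then base64 the whole buffer.
buf = bytearray()
for value in embedding:
    buf.extend(struct.pack("f", value))
encoded = base64.b64encode(buf).decode("utf-8")

# Decode: the inverse, as a client consuming encoding_format="base64" would do.
raw = base64.b64decode(encoded)
decoded = list(struct.unpack(f"{len(raw) // 4}f", raw))
assert decoded == embedding
```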
- """ - data = [] - for i, embedding_data in enumerate(response_data): - if encoding_format == "base64": - byte_array = bytearray() - for embedding_value in embedding_data.embedding: - byte_array.extend(struct.pack("f", float(embedding_value))) - - response_embedding = base64.b64encode(byte_array).decode("utf-8") - else: - response_embedding = embedding_data.embedding - data.append( - OpenAIEmbeddingData( - embedding=response_embedding, - index=i, - ) - ) - return data diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index d30b5b12a..55a6793c2 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter @pytest.mark.parametrize( @@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): assert inference_adapter.client.api_key == api_key + + +@pytest.mark.parametrize( + "config_cls,adapter_cls,provider_data_validator", + [ + ( + WatsonXConfig, + WatsonXInferenceAdapter, + "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", + ), + ], +) +def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str): + """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the + assumption that there is an OpenAI-compatible client object.""" + + inference_adapter = adapter_cls(config=config_cls()) + + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator + + for api_key in ["test1", "test2"]: + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} + ): + assert inference_adapter.get_api_key() == api_key