From 999c28e8098ed411d1c70223099128817e42c698 Mon Sep 17 00:00:00 2001
From: Bill Murdock
Date: Fri, 3 Oct 2025 15:07:15 -0400
Subject: [PATCH] fix: Update Watsonx provider to use LiteLLM mixin and list all models

Signed-off-by: Bill Murdock
---
 llama_stack/core/routers/inference.py         |   2 +-
 .../remote/inference/watsonx/__init__.py      |  11 +-
 .../remote/inference/watsonx/config.py        |   4 +-
 .../remote/inference/watsonx/models.py        |  47 ---
 .../remote/inference/watsonx/watsonx.py       | 309 +++++----------
 .../test_inference_client_caching.py          |  20 ++
 6 files changed, 109 insertions(+), 284 deletions(-)
 delete mode 100644 llama_stack/providers/remote/inference/watsonx/models.py

diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index c4338e614..847f6a2d2 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -611,7 +611,7 @@ class InferenceRouter(Inference):
                 completion_text += "".join(choice_data["content_parts"])
 
             # Add metrics to the chunk
-            if self.telemetry and chunk.usage:
+            if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
                 metrics = self._construct_metrics(
                     prompt_tokens=chunk.usage.prompt_tokens,
                     completion_tokens=chunk.usage.completion_tokens,
diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py
index e59e873b6..35e74a720 100644
--- a/llama_stack/providers/remote/inference/watsonx/__init__.py
+++ b/llama_stack/providers/remote/inference/watsonx/__init__.py
@@ -4,19 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.inference import Inference
-
 from .config import WatsonXConfig
 
 
-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
-    # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+    # import dynamically so the import is used only when it is needed
     from .watsonx import WatsonXInferenceAdapter
 
-    if not isinstance(config, WatsonXConfig):
-        raise RuntimeError(f"Unexpected config type: {type(config)}")
     adapter = WatsonXInferenceAdapter(config)
     return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 4bc0173c4..8417288f7 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -27,11 +27,11 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
     )
     api_key: SecretStr | None = Field(
         default_factory=lambda: os.getenv("WATSONX_API_KEY"),
-        description="The watsonx API key",
+        description="The watsonx.ai API key",
     )
     project_id: str | None = Field(
         default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key",
+        description="The watsonx.ai project ID",
     )
     timeout: int = Field(
         default=60,
diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py
deleted file mode 100644
index d98f0510a..000000000
--- a/llama_stack/providers/remote/inference/watsonx/models.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
-
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-3-70b-instruct",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-2-13b-chat",
-        CoreModelId.llama2_13b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-3b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
-]
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 0557aff5f..9584789e3 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,246 +4,105 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator, AsyncIterator
+import asyncio
 from typing import Any
 
-from ibm_watsonx_ai.foundation_models import Model
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
+import requests
 
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    GreedySamplingStrategy,
-    Inference,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-    request_has_media,
-)
-
-from . import WatsonXConfig
-from .models import MODEL_ENTRIES
-
-logger = get_logger(name=__name__, category="inference::watsonx")
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 
-# Note on structured output
-# WatsonX returns responses with a json embedded into a string.
-# Examples:
+class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _config: WatsonXConfig
+    __provider_id__: str = "watsonx"
 
-# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
-#    "first_name": "Michael",\n    "last_name": "Jordan",\n'...)
-# Not even a valid JSON, but we can still extract the JSON from the content
+    def __init__(self, config: WatsonXConfig):
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            litellm_provider_name="watsonx",
+            # api_key may be unset in the config (it defaults to the WATSONX_API_KEY env var),
+            # so guard against None before unwrapping the SecretStr
+            api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
+            provider_data_api_key_field="watsonx_api_key",
+        )
+        self.available_models = None
+        self.config = config
 
-# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
-# "year_born": "1963", "year_retired": "2003"\\}}$')
-# Find the start of the boxed content
+    # get_api_key = LiteLLMOpenAIMixin.get_api_key
 
+    def get_base_url(self) -> str:
+        return self.config.url
 
-class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
-    def __init__(self, config: WatsonXConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
+    async def initialize(self):
+        await super().initialize()
 
-        logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
-        self._config = config
-        self._openai_client: AsyncOpenAI | None = None
+    async def shutdown(self):
+        await super().shutdown()
 
-        self._project_id = self._config.project_id
+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get base parameters from parent
+        params = await super()._get_params(request)
 
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
-
-    def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
-        config_url = self._config.url
-        project_id = self._config.project_id
-        credentials = {"url": config_url, "apikey": config_api_key}
-
-        return Model(model_id=model_id, credentials=credentials, project_id=project_id)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self._config.url}/openai/v1",
-                api_key=self._config.api_key,
-            )
-        return self._openai_client
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        input_dict = {"params": {}}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
-        else:
-            assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request)
-        if request.sampling_params:
-            if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
-            if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
-            if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-
-            if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
-            if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
-            if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
-                input_dict["params"][GenParams.TEMPERATURE] = 0.0
-
-        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
-
-        params = {
-            **input_dict,
-        }
+        # Add watsonx.ai specific parameters
+        params["project_id"] = self.config.project_id
+        params["time_limit"] = self.config.timeout
         return params
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    async def check_model_availability(self, model):
+        return True
 
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().completions.create(**params)  # type: ignore
+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for model_spec in self._get_model_specs():
+            models.append(
+                Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+            )
+        return models
 
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        if params.get("stream", False):
-            return self._stream_openai_chat_completion(params)
-        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
+    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+    # So we need to implement our own method to list models by calling the watsonx.ai API.
+    def _get_model_specs(self) -> list[dict[str, Any]]:
+        """
+        Retrieves foundation model specifications from the watsonx.ai API.
+        """
+        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        headers = {
+            # Note that there is no authorization header. Listing models does not require authentication.
+            "Content-Type": "application/json",
+        }
 
-    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
-        # watsonx.ai sometimes adds usage data to the stream
-        include_usage = False
-        if params.get("stream_options", None):
-            include_usage = params["stream_options"].get("include_usage", False)
-        stream = await self._get_openai_client().chat.completions.create(**params)
+        response = requests.get(url, headers=headers)
 
-        seen_finish_reason = False
-        async for chunk in stream:
-            # Final usage chunk with no choices that the user didn't request, so discard
-            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
-                break
-            yield chunk
-            for choice in chunk.choices:
-                if choice.finish_reason:
-                    seen_finish_reason = True
-                    break
+        # --- Process the Response ---
+        # Raise an exception for bad status codes (4xx or 5xx)
+        response.raise_for_status()
+
+        # If the request is successful, parse and return the JSON response.
+        # The response should contain a list of model specifications
+        response_data = response.json()
+        if "resources" not in response_data:
+            raise ValueError("Resources not found in response")
+        return response_data["resources"]
+
+
+# TO DO: Delete the test main method.
+if __name__ == "__main__":
+    config = WatsonXConfig(url="https://us-south.ml.cloud.ibm.com", api_key="xxx", project_id="xxx", timeout=60)
+    adapter = WatsonXInferenceAdapter(config)
+    model_specs = adapter._get_model_specs()
+    models = asyncio.run(adapter.list_models())
+    for model in models:
+        print(model.identifier)
+        print(model.provider_resource_id)
+        print(model.provider_id)
+        print(model.metadata)
+        print(model.model_type)
+        print("--------------------------------")
diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py
index f4b3201e9..ee71149de 100644
--- a/tests/unit/providers/inference/test_inference_client_caching.py
+++ b/tests/unit/providers/inference/test_inference_client_caching.py
@@ -16,6 +16,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
 
 
 def test_groq_provider_openai_client_caching():
@@ -36,6 +38,24 @@ def test_groq_provider_openai_client_caching():
         assert inference_adapter.client.api_key == api_key
 
 
+def test_watsonx_provider_openai_client_caching():
+    """Ensure the WatsonX provider does not cache api keys across client requests"""
+
+    config = WatsonXConfig()
+    inference_adapter = WatsonXInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            assert inference_adapter.client.api_key == api_key
+
+
 def test_openai_provider_openai_client_caching():
     """Ensure the OpenAI provider does not cache api keys across client requests"""
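
Note (not part of the patch): a minimal sketch of how the new list_models() path can be exercised against a stubbed watsonx.ai response, without network access. The sample payload, the model ids in it, and the mock setup are illustrative assumptions rather than real watsonx.ai output; the module paths and constructor arguments are taken from the patch above.

    # Sketch only: drive WatsonXInferenceAdapter.list_models() with a faked
    # /ml/v1/foundation_model_specs response. The payload below is made up;
    # the real endpoint returns many more fields per resource.
    import asyncio
    from unittest.mock import MagicMock, patch

    from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
    from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter

    fake_response = MagicMock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {
        "resources": [
            {"model_id": "ibm/granite-3-8b-instruct"},
            {"model_id": "meta-llama/llama-3-3-70b-instruct"},
        ]
    }

    config = WatsonXConfig(url="https://us-south.ml.cloud.ibm.com", api_key="fake-key", project_id="fake-project")
    adapter = WatsonXInferenceAdapter(config)

    # _get_model_specs() calls requests.get, so patching requests.get keeps this offline.
    with patch("requests.get", return_value=fake_response):
        models = asyncio.run(adapter.list_models())

    for model in models:
        # identifier comes from the watsonx.ai model_id; provider_resource_id is prefixed with "watsonx/"
        print(model.identifier, "->", model.provider_resource_id)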