From 0dbf79c328d8444cd9fa90891be9a4e9c36588df Mon Sep 17 00:00:00 2001
From: Cesare Pompeiano <195810094+are-ces@users.noreply.github.com>
Date: Tue, 14 Oct 2025 14:52:32 +0200
Subject: [PATCH] fix: Fixed WatsonX remote inference provider (#3801)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

This PR fixes the WatsonX provider so that it works correctly with LiteLLM.

The main problem was that WatsonX requests failed because the provider data validator did not handle the API key and project ID properly. This is fixed by updating the WatsonXProviderDataValidator and making sure the provider data is loaded correctly.

The openai_chat_completion, openai_completion, and openai_embeddings methods were also updated to match the behavior of other providers, while passing the WatsonX-specific project_id and timeout parameters and injecting a zeroed usage object when LiteLLM omits it from a response or a streaming chunk.

After these changes, WatsonX requests run correctly.

## Test Plan

The changes were tested by running chat completion requests and confirming that credentials and project parameters are passed through correctly.

I tested with my own WatsonX credentials, using the CLI with `uv run llama-stack-client inference chat-completion --session`.
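For reference, here is a minimal sketch of how the fixed provider can be exercised end to end. It assumes a locally running stack on the default port 8321, that the OpenAI-compatible routes are mounted under `/v1/openai/v1`, and that per-request credentials are passed via the `X-LlamaStack-Provider-Data` header; the base URL, route prefix, and model id below are illustrative assumptions, not part of this patch:

```python
import json

from openai import OpenAI

# Hypothetical local Llama Stack endpoint; adjust base_url to your deployment.
client = OpenAI(
    base_url="http://localhost:8321/v1/openai/v1",
    api_key="none",  # WatsonX credentials come from provider data, not this key
    default_headers={
        # Picked up by WatsonXProviderDataValidator (watsonx_api_key / watsonx_project_id)
        "X-LlamaStack-Provider-Data": json.dumps(
            {
                "watsonx_api_key": "YOUR_WATSONX_API_KEY",
                "watsonx_project_id": "YOUR_WATSONX_PROJECT_ID",
            }
        )
    },
)

# Illustrative model id; use whichever WatsonX model is registered in your stack.
response = client.chat.completions.create(
    model="watsonx/meta-llama/llama-3-3-70b-instruct",
    messages=[{"role": "user", "content": "Say hello from WatsonX"}],
)
print(response.choices[0].message.content)
```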
---------

Signed-off-by: Sébastien Han
Co-authored-by: Sébastien Han
---
 llama_stack/providers/registry/inference.py   |   2 +-
 .../remote/inference/watsonx/config.py        |  10 +-
 .../remote/inference/watsonx/watsonx.py       | 243 +++++++++++++++++-
 .../inference/test_openai_completion.py       |  11 +-
 .../inference/test_openai_embeddings.py       |  14 +-
 5 files changed, 254 insertions(+), 26 deletions(-)

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index f89565892..6033c3186 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -271,7 +271,7 @@ Available Models:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-            provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+            provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
             description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
         ),
         RemoteProviderSpec(
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 022dc5ee7..8d8df13b4 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -7,18 +7,18 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None
 
 
 @json_schema_type
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 654d61f34..2c051719b 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,42 +4,259 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from typing import Any
 
+import litellm
 import requests
 
-from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
+
+logger = get_logger(name=__name__, category="providers::remote::watsonx")
 
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}
 
+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
+            openai_compat_api_base=self.get_base_url(),
+        )
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where usage block is sometimes dropped.
+ """ + + # Add usage tracking for streaming when telemetry is active + stream_options = params.stream_options + if params.stream and get_current_span() is not None: + if stream_options is None: + stream_options = {"include_usage": True} + elif "include_usage" not in stream_options: + stream_options = {**stream_options, "include_usage": True} + + model_obj = await self.model_store.get_model(params.model) + + request_params = await prepare_openai_completion_params( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + messages=params.messages, + frequency_penalty=params.frequency_penalty, + function_call=params.function_call, + functions=params.functions, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_completion_tokens=params.max_completion_tokens, + max_tokens=params.max_tokens, + n=params.n, + parallel_tool_calls=params.parallel_tool_calls, + presence_penalty=params.presence_penalty, + response_format=params.response_format, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=stream_options, + temperature=params.temperature, + tool_choice=params.tool_choice, + tools=params.tools, + top_logprobs=params.top_logprobs, + top_p=params.top_p, + user=params.user, + api_key=self.get_api_key(), + api_base=self.api_base, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + + result = await litellm.acompletion(**request_params) + + # If not streaming, check and inject usage if missing + if not params.stream: + # Use getattr to safely handle cases where usage attribute might not exist + if getattr(result, "usage", None) is None: + # Create usage object with zeros + usage_obj = OpenAIChatCompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + # Use model_copy to create a new response with the usage injected + result = result.model_copy(update={"usage": usage_obj}) + return result + + # For streaming, wrap the iterator to normalize chunks + return self._normalize_stream(result) + + def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk: + """ + Normalize a chunk to ensure it has all expected attributes. + This works around LiteLLM not always including all expected attributes. 
+ """ + # Ensure chunk has usage attribute with zeros if missing + if not hasattr(chunk, "usage") or chunk.usage is None: + usage_obj = OpenAIChatCompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + chunk = chunk.model_copy(update={"usage": usage_obj}) + + # Ensure all delta objects in choices have expected attributes + if hasattr(chunk, "choices") and chunk.choices: + normalized_choices = [] + for choice in chunk.choices: + if hasattr(choice, "delta") and choice.delta: + delta = choice.delta + # Build update dict for missing attributes + delta_updates = {} + if not hasattr(delta, "refusal"): + delta_updates["refusal"] = None + if not hasattr(delta, "reasoning_content"): + delta_updates["reasoning_content"] = None + + # If we need to update delta, create a new choice with updated delta + if delta_updates: + new_delta = delta.model_copy(update=delta_updates) + new_choice = choice.model_copy(update={"delta": new_delta}) + normalized_choices.append(new_choice) + else: + normalized_choices.append(choice) + else: + normalized_choices.append(choice) + + # If we modified any choices, create a new chunk with updated choices + if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))): + chunk = chunk.model_copy(update={"choices": normalized_choices}) + + return chunk + + async def _normalize_stream( + self, stream: AsyncIterator[OpenAIChatCompletionChunk] + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """ + Normalize all chunks in the stream to ensure they have expected attributes. + This works around LiteLLM sometimes not including expected attributes. + """ + try: + async for chunk in stream: + # Normalize and yield each chunk immediately + yield self._normalize_chunk(chunk) + except Exception as e: + logger.error(f"Error normalizing stream: {e}", exc_info=True) + raise + + async def openai_completion( + self, + params: OpenAICompletionRequestWithExtraBody, + ) -> OpenAICompletion: + """ + Override parent method to add watsonx-specific parameters. + """ + from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params + + model_obj = await self.model_store.get_model(params.model) + + request_params = await prepare_openai_completion_params( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + prompt=params.prompt, + best_of=params.best_of, + echo=params.echo, + frequency_penalty=params.frequency_penalty, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_tokens=params.max_tokens, + n=params.n, + presence_penalty=params.presence_penalty, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=params.stream_options, + temperature=params.temperature, + top_p=params.top_p, + user=params.user, + suffix=params.suffix, + api_key=self.get_api_key(), + api_base=self.api_base, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + return await litellm.atext_completion(**request_params) + + async def openai_embeddings( + self, + params: OpenAIEmbeddingsRequestWithExtraBody, + ) -> OpenAIEmbeddingsResponse: + """ + Override parent method to add watsonx-specific parameters. 
+ """ + model_obj = await self.model_store.get_model(params.model) + + # Convert input to list if it's a string + input_list = [params.input] if isinstance(params.input, str) else params.input + + # Call litellm embedding function with watsonx-specific parameters + response = litellm.embedding( + model=self.get_litellm_model_name(model_obj.provider_resource_id), + input=input_list, + api_key=self.get_api_key(), + api_base=self.api_base, + dimensions=params.dimensions, + # These are watsonx-specific parameters + timeout=self.config.timeout, + project_id=self.config.project_id, + ) + + # Convert response to OpenAI format + from llama_stack.apis.inference import OpenAIEmbeddingUsage + from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response + + data = b64_encode_openai_embeddings_response(response.data, params.encoding_format) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response["usage"]["prompt_tokens"], + total_tokens=response["usage"]["total_tokens"], + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=model_obj.provider_resource_id, + usage=usage, ) - self.available_models = None - self.config = config def get_base_url(self) -> str: return self.config.url - async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: - # Get base parameters from parent - params = await super()._get_params(request) - - # Add watsonx.ai specific parameters - params["project_id"] = self.config.project_id - params["time_limit"] = self.config.timeout - return params - # Copied from OpenAIMixin async def check_model_availability(self, model: str) -> bool: """ diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 3f0cffb2d..65f773889 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) # does not work with the specified model, gpt-5-mini. Please choose different model and try # again. 
         # again. You can learn more about which models can be used with each operation here:
         # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"}
-        "remote::watsonx",  # return 404 when hitting the /openai/v1 endpoint
         "remote::llama-openai-compat",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@@ -68,6 +67,7 @@
 def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id):
     provider_type = provider_from_model(client_with_models, model_id).provider_type
     if provider_type in (
         "remote::ollama",  # logprobs is ignored
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.")
@@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'}
         "remote::cerebras",
         "remote::databricks",  # Bad request: parameter "n" must be equal to 1 for streaming mode
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
 
@@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::databricks",
         "remote::cerebras",
         "remote::runpod",
-        "remote::watsonx",  # watsonx returns 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
 
@@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
     assert "hello world" in normalized_content
 
 
+def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id):
+    provider_type = provider_from_model(client_with_models, model_id).provider_type
+    if provider_type in ("remote::watsonx",):  # openai.BadRequestError: Error code: 400
+        pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.")
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
 )
 def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id)
 
     tc = TestCase(test_case)
 
diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py
index 84e92706a..0c1d4d08e 100644
--- a/tests/integration/inference/test_openai_embeddings.py
+++ b/tests/integration/inference/test_openai_embeddings.py
@@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
 
 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if provider.provider_type in (
-        "remote::together",  # returns 400
-        "inline::sentence-transformers",
-        # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
-        "remote::databricks",
+    if (
+        provider.provider_type
+        in (
+            "remote::together",  # returns 400
+            "inline::sentence-transformers",
+            # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
+            "remote::databricks",
"litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384} + ) ): pytest.skip( f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."