From effe7609a92e39693b40674a62194e11ab4a2530 Mon Sep 17 00:00:00 2001
From: are-ces <195810094+are-ces@users.noreply.github.com>
Date: Mon, 13 Oct 2025 14:32:29 +0200
Subject: [PATCH] Fixed WatsonX bugs

Registry:
* Point provider_data_validator at the watsonx.config module, where
  WatsonXProviderDataValidator is defined.

Config:
* WatsonXProviderDataValidator no longer sets a model_config with
  extra="forbid"; it gains an optional watsonx_project_id field and
  watsonx_api_key now defaults to None.

Adapter:
* Assign self.config before calling LiteLLMOpenAIMixin.__init__, because
  the call now evaluates self.get_base_url(), which reads self.config.url.
* Override openai_chat_completion, openai_completion and openai_embeddings
  so that timeout and project_id are passed straight to LiteLLM; this
  replaces the removed _get_params override (which set "time_limit").
* Inject a zeroed usage object when LiteLLM returns none, and normalize
  streamed chunks (usage, delta.refusal, delta.reasoning_content).
* openai_embeddings awaits litellm.aembedding so the async method does
  not block the event loop on a synchronous HTTP call.
---
 llama_stack/providers/registry/inference.py   |   2 +-
 .../remote/inference/watsonx/config.py        |  10 +-
 .../remote/inference/watsonx/watsonx.py       | 243 +++++++++++++++++-
 3 files changed, 236 insertions(+), 19 deletions(-)

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index f89565892..6033c3186 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -271,7 +271,7 @@ Available Models:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-            provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+            provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
             description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
         ),
         RemoteProviderSpec(
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 022dc5ee7..8d8df13b4 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -7,18 +7,18 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None
 
 
 @json_schema_type
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 654d61f34..2c051719b 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,42 +4,259 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from collections.abc import AsyncIterator
 from typing import Any
 
+import litellm
 import requests
 
-from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
+
+logger = get_logger(name=__name__, category="providers::remote::watsonx")
 
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}
 
+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
+            openai_compat_api_base=self.get_base_url(),
+        )
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where usage block is sometimes dropped.
+        """
+
+        # Add usage tracking for streaming when telemetry is active
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        result = await litellm.acompletion(**request_params)
+
+        # If not streaming, check and inject usage if missing
+        if not params.stream:
+            # Use getattr to safely handle cases where usage attribute might not exist
+            if getattr(result, "usage", None) is None:
+                # Create usage object with zeros
+                usage_obj = OpenAIChatCompletionUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                )
+                # Use model_copy to create a new response with the usage injected
+                result = result.model_copy(update={"usage": usage_obj})
+            return result
+
+        # For streaming, wrap the iterator to normalize chunks
+        return self._normalize_stream(result)
+
+    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
+        """
+        Normalize a chunk to ensure it has all expected attributes.
+        This works around LiteLLM not always including all expected attributes.
+        """
+        # Ensure chunk has usage attribute with zeros if missing
+        if not hasattr(chunk, "usage") or chunk.usage is None:
+            usage_obj = OpenAIChatCompletionUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            chunk = chunk.model_copy(update={"usage": usage_obj})
+
+        # Ensure all delta objects in choices have expected attributes
+        if hasattr(chunk, "choices") and chunk.choices:
+            normalized_choices = []
+            for choice in chunk.choices:
+                if hasattr(choice, "delta") and choice.delta:
+                    delta = choice.delta
+                    # Build update dict for missing attributes
+                    delta_updates = {}
+                    if not hasattr(delta, "refusal"):
+                        delta_updates["refusal"] = None
+                    if not hasattr(delta, "reasoning_content"):
+                        delta_updates["reasoning_content"] = None
+
+                    # If we need to update delta, create a new choice with updated delta
+                    if delta_updates:
+                        new_delta = delta.model_copy(update=delta_updates)
+                        new_choice = choice.model_copy(update={"delta": new_delta})
+                        normalized_choices.append(new_choice)
+                    else:
+                        normalized_choices.append(choice)
+                else:
+                    normalized_choices.append(choice)
+
+            # If we modified any choices, create a new chunk with updated choices
+            if any(new is not old for new, old in zip(normalized_choices, chunk.choices)):
+                chunk = chunk.model_copy(update={"choices": normalized_choices})
+
+        return chunk
+
+    async def _normalize_stream(
+        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Normalize all chunks in the stream to ensure they have expected attributes.
+        This works around LiteLLM sometimes not including expected attributes.
+        """
+        try:
+            async for chunk in stream:
+                # Normalize and yield each chunk immediately
+                yield self._normalize_chunk(chunk)
+        except Exception as e:
+            logger.error(f"Error normalizing stream: {e}", exc_info=True)
+            raise
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        return await litellm.atext_completion(**request_params)
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Convert input to list if it's a string
+        input_list = [params.input] if isinstance(params.input, str) else params.input
+
+        # Await LiteLLM's async embedding call (watsonx-specific parameters included) so the event loop is not blocked
+        response = await litellm.aembedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=params.dimensions,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        # Convert response to OpenAI format
+        from llama_stack.apis.inference import OpenAIEmbeddingUsage
+        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
+
+        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
         )
-        self.available_models = None
-        self.config = config
 
     def get_base_url(self) -> str:
         return self.config.url
 
-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add watsonx.ai specific parameters
-        params["project_id"] = self.config.project_id
-        params["time_limit"] = self.config.timeout
-        return params
-
     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:
         """