diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index f89565892..6033c3186 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -271,7 +271,7 @@ Available Models:
             pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-            provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+            provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
             description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
         ),
         RemoteProviderSpec(
diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 022dc5ee7..8d8df13b4 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -7,18 +7,18 @@ import os
 from typing import Any

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None


 @json_schema_type
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 654d61f34..2c051719b 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -4,42 +4,259 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from collections.abc import AsyncIterator
 from typing import Any

+import litellm
 import requests

-from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
+
+logger = get_logger(name=__name__, category="providers::remote::watsonx")


 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}

+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
+            openai_compat_api_base=self.get_base_url(),
+        )
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where usage block is sometimes dropped.
+        """
+
+        # Add usage tracking for streaming when telemetry is active
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        result = await litellm.acompletion(**request_params)
+
+        # If not streaming, check and inject usage if missing
+        if not params.stream:
+            # Use getattr to safely handle cases where usage attribute might not exist
+            if getattr(result, "usage", None) is None:
+                # Create usage object with zeros
+                usage_obj = OpenAIChatCompletionUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                )
+                # Use model_copy to create a new response with the usage injected
+                result = result.model_copy(update={"usage": usage_obj})
+            return result
+
+        # For streaming, wrap the iterator to normalize chunks
+        return self._normalize_stream(result)
+
+    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
+        """
+        Normalize a chunk to ensure it has all expected attributes.
+        This works around LiteLLM not always including all expected attributes.
+        """
+        # Ensure chunk has usage attribute with zeros if missing
+        if not hasattr(chunk, "usage") or chunk.usage is None:
+            usage_obj = OpenAIChatCompletionUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            chunk = chunk.model_copy(update={"usage": usage_obj})
+
+        # Ensure all delta objects in choices have expected attributes
+        if hasattr(chunk, "choices") and chunk.choices:
+            normalized_choices = []
+            for choice in chunk.choices:
+                if hasattr(choice, "delta") and choice.delta:
+                    delta = choice.delta
+                    # Build update dict for missing attributes
+                    delta_updates = {}
+                    if not hasattr(delta, "refusal"):
+                        delta_updates["refusal"] = None
+                    if not hasattr(delta, "reasoning_content"):
+                        delta_updates["reasoning_content"] = None
+
+                    # If we need to update delta, create a new choice with updated delta
+                    if delta_updates:
+                        new_delta = delta.model_copy(update=delta_updates)
+                        new_choice = choice.model_copy(update={"delta": new_delta})
+                        normalized_choices.append(new_choice)
+                    else:
+                        normalized_choices.append(choice)
+                else:
+                    normalized_choices.append(choice)
+
+            # If we modified any choices, create a new chunk with updated choices
+            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
+                chunk = chunk.model_copy(update={"choices": normalized_choices})
+
+        return chunk
+
+    async def _normalize_stream(
+        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Normalize all chunks in the stream to ensure they have expected attributes.
+        This works around LiteLLM sometimes not including expected attributes.
+        """
+        try:
+            async for chunk in stream:
+                # Normalize and yield each chunk immediately
+                yield self._normalize_chunk(chunk)
+        except Exception as e:
+            logger.error(f"Error normalizing stream: {e}", exc_info=True)
+            raise
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        return await litellm.atext_completion(**request_params)
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Convert input to list if it's a string
+        input_list = [params.input] if isinstance(params.input, str) else params.input
+
+        # Call litellm embedding function with watsonx-specific parameters
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=params.dimensions,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        # Convert response to OpenAI format
+        from llama_stack.apis.inference import OpenAIEmbeddingUsage
+        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
+
+        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
         )
-        self.available_models = None
-        self.config = config

     def get_base_url(self) -> str:
         return self.config.url

-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add watsonx.ai specific parameters
-        params["project_id"] = self.config.project_id
-        params["time_limit"] = self.config.timeout
-        return params
-
     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:
         """
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 3f0cffb2d..65f773889 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         # does not work with the specified model, gpt-5-mini. Please choose different model and try
         # again. You can learn more about which models can be used with each operation here:
         # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"}
-        "remote::watsonx",  # return 404 when hitting the /openai/v1 endpoint
         "remote::llama-openai-compat",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@@ -68,6 +67,7 @@ def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id):
     provider_type = provider_from_model(client_with_models, model_id).provider_type
     if provider_type in (
         "remote::ollama",  # logprobs is ignored
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.")

@@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'}
         "remote::cerebras",
         "remote::databricks",  # Bad request: parameter "n" must be equal to 1 for streaming mode
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

@@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::databricks",
         "remote::cerebras",
         "remote::runpod",
-        "remote::watsonx",  # watsonx returns 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")

@@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
     assert "hello world" in normalized_content


+def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id):
+    provider_type = provider_from_model(client_with_models, model_id).provider_type
+    if provider_type in ("remote::watsonx",):  # openai.BadRequestError: Error code: 400
+        pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.")
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
 )
 def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id)

     tc = TestCase(test_case)
diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py
index 84e92706a..0c1d4d08e 100644
--- a/tests/integration/inference/test_openai_embeddings.py
+++ b/tests/integration/inference/test_openai_embeddings.py
@@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):

 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if provider.provider_type in (
-        "remote::together",  # returns 400
-        "inline::sentence-transformers",
-        # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
-        "remote::databricks",
+    if (
+        provider.provider_type
+        in (
+            "remote::together",  # returns 400
+            "inline::sentence-transformers",
+            # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
+            "remote::databricks",
+            "remote::watsonx",  # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
+        )
     ):
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."