fix: Update Watsonx provider to use LiteLLM mixin and list all models

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
Bill Murdock 2025-10-03 15:07:15 -04:00
parent 9f6c658f2a
commit 999c28e809
6 changed files with 109 additions and 284 deletions

View file

@@ -611,7 +611,7 @@ class InferenceRouter(Inference):
                     completion_text += "".join(choice_data["content_parts"])
             # Add metrics to the chunk
-            if self.telemetry and chunk.usage:
+            if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
                 metrics = self._construct_metrics(
                     prompt_tokens=chunk.usage.prompt_tokens,
                     completion_tokens=chunk.usage.completion_tokens,
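The tightened guard matters because not every chunk type that flows through the router defines a `usage` attribute, so plain attribute access can raise instead of merely being falsy. A minimal sketch of the failure mode, with a stand-in chunk class rather than the real streamed type:

class ChunkWithoutUsage:
    # Stand-in for a streamed chunk type that defines no `usage` attribute at all.
    pass


chunk = ChunkWithoutUsage()
# Old check: `chunk.usage` raises AttributeError before truthiness is ever tested.
# New check: evaluates to False instead of raising, so the metrics block is skipped.
print(hasattr(chunk, "usage") and chunk.usage)  # False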

View file

@@ -4,19 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.inference import Inference
-
 from .config import WatsonXConfig
 
 
-async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
-    # import dynamically so `llama stack build` does not fail due to missing dependencies
+async def get_adapter_impl(config: WatsonXConfig, _deps):
+    # import dynamically so the import is used only when it is needed
     from .watsonx import WatsonXInferenceAdapter
 
-    if not isinstance(config, WatsonXConfig):
-        raise RuntimeError(f"Unexpected config type: {type(config)}")
     adapter = WatsonXInferenceAdapter(config)
     return adapter
-
-
-__all__ = ["get_adapter_impl", "WatsonXConfig"]
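The entry point now returns the adapter without the isinstance guard or the `Inference` annotation. A minimal usage sketch, assuming WATSONX_API_KEY and WATSONX_PROJECT_ID are exported in the environment (the empty deps dict is a placeholder):

import asyncio

from llama_stack.providers.remote.inference.watsonx import WatsonXConfig, get_adapter_impl

# WatsonXConfig reads WATSONX_API_KEY / WATSONX_PROJECT_ID from the environment by default,
# and the adapter's __init__ calls config.api_key.get_secret_value(), so the key must be set.
adapter = asyncio.run(get_adapter_impl(WatsonXConfig(), {}))
print(type(adapter).__name__)  # WatsonXInferenceAdapter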

View file

@@ -27,11 +27,11 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
     )
     api_key: SecretStr | None = Field(
         default_factory=lambda: os.getenv("WATSONX_API_KEY"),
-        description="The watsonx API key",
+        description="The watsonx.ai API key",
     )
     project_id: str | None = Field(
         default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
-        description="The Project ID key",
+        description="The watsonx.ai project ID",
     )
     timeout: int = Field(
         default=60,
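Both descriptions now name watsonx.ai explicitly. Because the fields default from environment variables via `default_factory`, a config can be built with no explicit arguments; a small sketch with placeholder values:

import os

from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig

# Placeholder credentials; real values come from your watsonx.ai account.
os.environ["WATSONX_API_KEY"] = "my-api-key"
os.environ["WATSONX_PROJECT_ID"] = "my-project-id"

config = WatsonXConfig()  # api_key and project_id are picked up from the env
print(config.project_id)                  # my-project-id
print(config.api_key.get_secret_value())  # my-api-key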

View file

@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
-
-MODEL_ENTRIES = [
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-3-70b-instruct",
-        CoreModelId.llama3_3_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-2-13b-chat",
-        CoreModelId.llama2_13b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-70b-instruct",
-        CoreModelId.llama3_1_70b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-11b-vision-instruct",
-        CoreModelId.llama3_2_11b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-3b-instruct",
-        CoreModelId.llama3_2_3b_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-3-2-90b-vision-instruct",
-        CoreModelId.llama3_2_90b_vision_instruct.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
-]

View file

@@ -4,246 +4,105 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator, AsyncIterator
+import asyncio
 from typing import Any
 
-from ibm_watsonx_ai.foundation_models import Model
-from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
+import requests
 
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    CompletionRequest,
-    GreedySamplingStrategy,
-    Inference,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
-)
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
-    prepare_openai_completion_params,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    completion_request_to_prompt,
-    request_has_media,
-)
-
-from . import WatsonXConfig
-from .models import MODEL_ENTRIES
-
-logger = get_logger(name=__name__, category="inference::watsonx")
+from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 
-# Note on structured output
-# WatsonX returns responses with a json embedded into a string.
-# Examples:
-# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n
-#   "first_name": "Michael",\n "last_name": "Jordan",\n'...)
-# Not even a valid JSON, but we can still extract the JSON from the content
-
-# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan",
-# "year_born": "1963", "year_retired": "2003"\\}}$')
-# Find the start of the boxed content
-
-
-class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
-    def __init__(self, config: WatsonXConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-
-        logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
-        self._config = config
-        self._openai_client: AsyncOpenAI | None = None
-
-        self._project_id = self._config.project_id
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
-
-    def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
-        config_url = self._config.url
-        project_id = self._config.project_id
-        credentials = {"url": config_url, "apikey": config_api_key}
-        return Model(model_id=model_id, credentials=credentials, project_id=project_id)
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self._config.url}/openai/v1",
-                api_key=self._config.api_key,
-            )
-        return self._openai_client
-
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        input_dict = {"params": {}}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
-        else:
-            assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request)
-
-        if request.sampling_params:
-            if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
-            if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
-            if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
-
-            if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
-            if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
-            if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
-                input_dict["params"][GenParams.TEMPERATURE] = 0.0
-
-        input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
-
-        params = {
-            **input_dict,
-        }
+class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
+    _config: WatsonXConfig
+    __provider_id__: str = "watsonx"
+
+    def __init__(self, config: WatsonXConfig):
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            litellm_provider_name="watsonx",
+            api_key_from_config=config.api_key.get_secret_value(),
+            provider_data_api_key_field="watsonx_api_key",
+        )
+        self.available_models = None
+        self.config = config
+
+    # get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        return self.config.url
+
+    async def initialize(self):
+        await super().initialize()
+
+    async def shutdown(self):
+        await super().shutdown()
+
+    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
+        # Get base parameters from parent
+        params = await super()._get_params(request)
+
+        # Add watsonx.ai specific parameters
+        params["project_id"] = self.config.project_id
+        params["time_limit"] = self.config.timeout
         return params
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+    async def check_model_availability(self, model):
+        return True
 
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().completions.create(**params)  # type: ignore
+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for model_spec in self._get_model_specs():
+            models.append(
+                Model(
+                    identifier=model_spec["model_id"],
+                    provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+            )
+        return models
 
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_obj = await self.model_store.get_model(model)
-        params = await prepare_openai_completion_params(
-            model=model_obj.provider_resource_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        if params.get("stream", False):
-            return self._stream_openai_chat_completion(params)
-        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
-
-    async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
-        # watsonx.ai sometimes adds usage data to the stream
-        include_usage = False
-        if params.get("stream_options", None):
-            include_usage = params["stream_options"].get("include_usage", False)
-        stream = await self._get_openai_client().chat.completions.create(**params)
-
-        seen_finish_reason = False
-        async for chunk in stream:
-            # Final usage chunk with no choices that the user didn't request, so discard
-            if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
-                break
-            yield chunk
-            for choice in chunk.choices:
-                if choice.finish_reason:
-                    seen_finish_reason = True
-                    break
+    # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
+    # So we need to implement our own method to list models by calling the watsonx.ai API.
+    def _get_model_specs(self) -> list[dict[str, Any]]:
+        """
+        Retrieves foundation model specifications from the watsonx.ai API.
+        """
+        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        headers = {
+            # Note that there is no authorization header. Listing models does not require authentication.
+            "Content-Type": "application/json",
+        }
+
+        response = requests.get(url, headers=headers)
+
+        # --- Process the Response ---
+        # Raise an exception for bad status codes (4xx or 5xx)
+        response.raise_for_status()
+
+        # If the request is successful, parse and return the JSON response.
+        # The response should contain a list of model specifications
+        response_data = response.json()
+        if "resources" not in response_data:
+            raise ValueError("Resources not found in response")
+        return response_data["resources"]
+
+
+# TO DO: Delete the test main method.
+if __name__ == "__main__":
+    config = WatsonXConfig(url="https://us-south.ml.cloud.ibm.com", api_key="xxx", project_id="xxx", timeout=60)
+    adapter = WatsonXInferenceAdapter(config)
+    model_specs = adapter._get_model_specs()
+    models = asyncio.run(adapter.list_models())
+    for model in models:
+        print(model.identifier)
+        print(model.provider_resource_id)
+        print(model.provider_id)
+        print(model.metadata)
+        print(model.model_type)
+        print("--------------------------------")

View file

@@ -16,6 +16,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
+from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
+from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
 
 
 def test_groq_provider_openai_client_caching():
@@ -36,6 +38,24 @@ def test_groq_provider_openai_client_caching():
             assert inference_adapter.client.api_key == api_key
 
 
+def test_watsonx_provider_openai_client_caching():
+    """Ensure the WatsonX provider does not cache api keys across client requests"""
+    config = WatsonXConfig()
+    inference_adapter = WatsonXInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context(
+            {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
+        ):
+            assert inference_adapter.client.api_key == api_key
+
+
 def test_openai_provider_openai_client_caching():
     """Ensure the OpenAI provider does not cache api keys across client requests"""