From 5d711d4bcb1e7e84ca5325f88f3943b1fb274d79 Mon Sep 17 00:00:00 2001 From: Bill Murdock Date: Wed, 8 Oct 2025 07:29:43 -0400 Subject: [PATCH] fix: Update watsonx.ai provider to use LiteLLM mixin and list all models (#3674) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - The watsonx.ai provider now uses the LiteLLM mixin instead of using IBM's library, which does not seem to be working (see #3165 for context). - The watsonx.ai provider now lists all the models available by calling the watsonx.ai server instead of having a hard coded list of known models. (That list gets out of date quickly) - An edge case in [llama_stack/core/routers/inference.py](https://github.com/llamastack/llama-stack/pull/3674/files#diff-a34bc966ed9befd9f13d4883c23705dff49be0ad6211c850438cdda6113f3455) is addressed that was causing my manual tests to fail. - Fixes `b64_encode_openai_embeddings_response` which was trying to enumerate over a dictionary and then reference elements of the dictionary using .field instead of ["field"]. That method is called by the LiteLLM mixin for embedding models, so it is needed to get the watsonx.ai embedding models to work. - A unit test along the lines of the one in #3348 is added. A more comprehensive plan for automatically testing the end-to-end functionality for inference providers would be a good idea, but is out of scope for this PR. - Updates to the watsonx distribution. Some were in response to the switch to LiteLLM (e.g., updating the Python packages needed). Others seem to be things that were already broken that I found along the way (e.g., a reference to a watsonx specific doc template that doesn't seem to exist). Closes #3165 Also it is related to a line-item in #3387 but doesn't really address that goal (because it uses the LiteLLM mixin, not the OpenAI one). I tried the OpenAI one and it doesn't work with watsonx.ai, presumably because the watsonx.ai service is not OpenAI compatible. It works with LiteLLM because LiteLLM has a provider implementation for watsonx.ai. ## Test Plan The test script below goes back and forth between the OpenAI and watsonx providers. The idea is that the OpenAI provider shows how it should work and then the watsonx provider output shows that it is also working with watsonx. Note that the result from the MCP test is not as good (the Llama 3.3 70b model does not choose tools as wisely as gpt-4o), but it is still working and providing a valid response. For more details on setup and the MCP server being used for testing, see [the AI Alliance sample notebook](https://github.com/The-AI-Alliance/llama-stack-examples/blob/main/notebooks/01-responses/) that these examples are drawn from. ```python #!/usr/bin/env python3 import json from llama_stack_client import LlamaStackClient from litellm import completion import http.client def print_response(response): """Print response in a nicely formatted way""" print(f"ID: {response.id}") print(f"Status: {response.status}") print(f"Model: {response.model}") print(f"Created at: {response.created_at}") print(f"Output items: {len(response.output)}") for i, output_item in enumerate(response.output): if len(response.output) > 1: print(f"\n--- Output Item {i+1} ---") print(f"Output type: {output_item.type}") if output_item.type in ("text", "message"): print(f"Response content: {output_item.content[0].text}") elif output_item.type == "file_search_call": print(f" Tool Call ID: {output_item.id}") print(f" Tool Status: {output_item.status}") # 'queries' is a list, so we join it for clean printing print(f" Queries: {', '.join(output_item.queries)}") # Display results if they exist, otherwise note they are empty print(f" Results: {output_item.results if output_item.results else 'None'}") elif output_item.type == "mcp_list_tools": print_mcp_list_tools(output_item) elif output_item.type == "mcp_call": print_mcp_call(output_item) else: print(f"Response content: {output_item.content}") def print_mcp_call(mcp_call): """Print MCP call in a nicely formatted way""" print(f"\n🛠️ MCP Tool Call: {mcp_call.name}") print(f" Server: {mcp_call.server_label}") print(f" ID: {mcp_call.id}") print(f" Arguments: {mcp_call.arguments}") if mcp_call.error: print("Error: {mcp_call.error}") elif mcp_call.output: print("Output:") # Try to format JSON output nicely try: parsed_output = json.loads(mcp_call.output) print(json.dumps(parsed_output, indent=4)) except: # If not valid JSON, print as-is print(f" {mcp_call.output}") else: print(" ⏳ No output yet") def print_mcp_list_tools(mcp_list_tools): """Print MCP list tools in a nicely formatted way""" print(f"\n🔧 MCP Server: {mcp_list_tools.server_label}") print(f" ID: {mcp_list_tools.id}") print(f" Available Tools: {len(mcp_list_tools.tools)}") print("=" * 80) for i, tool in enumerate(mcp_list_tools.tools, 1): print(f"\n{i}. {tool.name}") print(f" Description: {tool.description}") # Parse and display input schema schema = tool.input_schema if schema and 'properties' in schema: properties = schema['properties'] required = schema.get('required', []) print(" Parameters:") for param_name, param_info in properties.items(): param_type = param_info.get('type', 'unknown') param_desc = param_info.get('description', 'No description') required_marker = " (required)" if param_name in required else " (optional)" print(f" • {param_name} ({param_type}){required_marker}") if param_desc: print(f" {param_desc}") if i < len(mcp_list_tools.tools): print("-" * 40) def main(): """Main function to run all the tests""" # Configuration LLAMA_STACK_URL = "http://localhost:8321/" LLAMA_STACK_MODEL_IDS = [ "openai/gpt-3.5-turbo", "openai/gpt-4o", "llama-openai-compat/Llama-3.3-70B-Instruct", "watsonx/meta-llama/llama-3-3-70b-instruct" ] # Using gpt-4o for this demo, but feel free to try one of the others or add more to run.yaml. OPENAI_MODEL_ID = LLAMA_STACK_MODEL_IDS[1] WATSONX_MODEL_ID = LLAMA_STACK_MODEL_IDS[-1] NPS_MCP_URL = "http://localhost:3005/sse/" print("=== Llama Stack Testing Script ===") print(f"Using OpenAI model: {OPENAI_MODEL_ID}") print(f"Using WatsonX model: {WATSONX_MODEL_ID}") print(f"MCP URL: {NPS_MCP_URL}") print() # Initialize client print("Initializing LlamaStackClient...") client = LlamaStackClient(base_url="http://localhost:8321") # Test 1: List models print("\n=== Test 1: List Models ===") try: models = client.models.list() print(f"Found {len(models)} models") except Exception as e: print(f"Error listing models: {e}") raise e # Test 2: Basic chat completion with OpenAI print("\n=== Test 2: Basic Chat Completion (OpenAI) ===") try: chat_completion_response = client.chat.completions.create( model=OPENAI_MODEL_ID, messages=[{"role": "user", "content": "What is the capital of France?"}] ) print("OpenAI Response:") for chunk in chat_completion_response.choices[0].message.content: print(chunk, end="", flush=True) print() except Exception as e: print(f"Error with OpenAI chat completion: {e}") raise e # Test 3: Basic chat completion with WatsonX print("\n=== Test 3: Basic Chat Completion (WatsonX) ===") try: chat_completion_response_wxai = client.chat.completions.create( model=WATSONX_MODEL_ID, messages=[{"role": "user", "content": "What is the capital of France?"}], ) print("WatsonX Response:") for chunk in chat_completion_response_wxai.choices[0].message.content: print(chunk, end="", flush=True) print() except Exception as e: print(f"Error with WatsonX chat completion: {e}") raise e # Test 4: Tool calling with OpenAI print("\n=== Test 4: Tool Calling (OpenAI) ===") tools = [ { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather for a specific location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g., San Francisco, CA", }, "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] }, }, "required": ["location"], }, }, } ] messages = [ {"role": "user", "content": "What's the weather like in Boston, MA?"} ] try: print("--- Initial API Call ---") response = client.chat.completions.create( model=OPENAI_MODEL_ID, messages=messages, tools=tools, tool_choice="auto", # "auto" is the default ) print("OpenAI tool calling response received") except Exception as e: print(f"Error with OpenAI tool calling: {e}") raise e # Test 5: Tool calling with WatsonX print("\n=== Test 5: Tool Calling (WatsonX) ===") try: wxai_response = client.chat.completions.create( model=WATSONX_MODEL_ID, messages=messages, tools=tools, tool_choice="auto", # "auto" is the default ) print("WatsonX tool calling response received") except Exception as e: print(f"Error with WatsonX tool calling: {e}") raise e # Test 6: Streaming with WatsonX print("\n=== Test 6: Streaming Response (WatsonX) ===") try: chat_completion_response_wxai_stream = client.chat.completions.create( model=WATSONX_MODEL_ID, messages=[{"role": "user", "content": "What is the capital of France?"}], stream=True ) print("Model response: ", end="") for chunk in chat_completion_response_wxai_stream: # Each 'chunk' is a ChatCompletionChunk object. # We want the content from the 'delta' attribute. if hasattr(chunk, 'choices') and chunk.choices is not None: content = chunk.choices[0].delta.content # The first few chunks might have None content, so we check for it. if content is not None: print(content, end="", flush=True) print() except Exception as e: print(f"Error with streaming: {e}") raise e # Test 7: MCP with OpenAI print("\n=== Test 7: MCP Integration (OpenAI) ===") try: mcp_llama_stack_client_response = client.responses.create( model=OPENAI_MODEL_ID, input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.", tools=[ { "type": "mcp", "server_url": NPS_MCP_URL, "server_label": "National Parks Service tools", "allowed_tools": ["search_parks", "get_park_events"], } ] ) print_response(mcp_llama_stack_client_response) except Exception as e: print(f"Error with MCP (OpenAI): {e}") raise e # Test 8: MCP with WatsonX print("\n=== Test 8: MCP Integration (WatsonX) ===") try: mcp_llama_stack_client_response = client.responses.create( model=WATSONX_MODEL_ID, input="What is the capital of France?" ) print_response(mcp_llama_stack_client_response) except Exception as e: print(f"Error with MCP (WatsonX): {e}") raise e # Test 9: MCP with Llama 3.3 print("\n=== Test 9: MCP Integration (Llama 3.3) ===") try: mcp_llama_stack_client_response = client.responses.create( model=WATSONX_MODEL_ID, input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.", tools=[ { "type": "mcp", "server_url": NPS_MCP_URL, "server_label": "National Parks Service tools", "allowed_tools": ["search_parks", "get_park_events"], } ] ) print_response(mcp_llama_stack_client_response) except Exception as e: print(f"Error with MCP (Llama 3.3): {e}") raise e # Test 10: Embeddings print("\n=== Test 10: Embeddings ===") try: conn = http.client.HTTPConnection("localhost:8321") payload = json.dumps({ "model": "watsonx/ibm/granite-embedding-278m-multilingual", "input": "Hello, world!", }) headers = { 'Content-Type': 'application/json', 'Accept': 'application/json' } conn.request("POST", "/v1/openai/v1/embeddings", payload, headers) res = conn.getresponse() data = res.read() print(data.decode("utf-8")) except Exception as e: print(f"Error with Embeddings: {e}") raise e print("\n=== Testing Complete ===") if __name__ == "__main__": main() ``` --------- Signed-off-by: Bill Murdock Co-authored-by: github-actions[bot] --- .../providers/inference/remote_watsonx.mdx | 4 +- llama_stack/core/routers/inference.py | 2 +- llama_stack/distributions/watsonx/__init__.py | 2 + llama_stack/distributions/watsonx/build.yaml | 41 +-- llama_stack/distributions/watsonx/run.yaml | 103 +----- llama_stack/distributions/watsonx/watsonx.py | 36 +- llama_stack/providers/registry/inference.py | 2 +- .../remote/inference/watsonx/__init__.py | 11 +- .../remote/inference/watsonx/config.py | 22 +- .../remote/inference/watsonx/models.py | 47 --- .../remote/inference/watsonx/watsonx.py | 324 ++++++------------ .../utils/inference/litellm_openai_mixin.py | 29 +- .../utils/inference/openai_compat.py | 28 -- .../test_inference_client_caching.py | 28 ++ 14 files changed, 203 insertions(+), 476 deletions(-) delete mode 100644 llama_stack/providers/remote/inference/watsonx/models.py diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx index 33bc5bbc3..f081703ab 100644 --- a/docs/docs/providers/inference/remote_watsonx.mdx +++ b/docs/docs/providers/inference/remote_watsonx.mdx @@ -17,8 +17,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider | | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai | -| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key | -| `project_id` | `str \| None` | No | | The Project ID key | +| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key | +| `project_id` | `str \| None` | No | | The watsonx.ai project ID | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | ## Sample Configuration diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index c4338e614..847f6a2d2 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -611,7 +611,7 @@ class InferenceRouter(Inference): completion_text += "".join(choice_data["content_parts"]) # Add metrics to the chunk - if self.telemetry and chunk.usage: + if self.telemetry and hasattr(chunk, "usage") and chunk.usage: metrics = self._construct_metrics( prompt_tokens=chunk.usage.prompt_tokens, completion_tokens=chunk.usage.completion_tokens, diff --git a/llama_stack/distributions/watsonx/__init__.py b/llama_stack/distributions/watsonx/__init__.py index 756f351d8..078d86144 100644 --- a/llama_stack/distributions/watsonx/__init__.py +++ b/llama_stack/distributions/watsonx/__init__.py @@ -3,3 +3,5 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +from .watsonx import get_distribution_template # noqa: F401 diff --git a/llama_stack/distributions/watsonx/build.yaml b/llama_stack/distributions/watsonx/build.yaml index bf4be7eaf..06349a741 100644 --- a/llama_stack/distributions/watsonx/build.yaml +++ b/llama_stack/distributions/watsonx/build.yaml @@ -3,44 +3,33 @@ distribution_spec: description: Use watsonx for running LLM inference providers: inference: - - provider_id: watsonx - provider_type: remote::watsonx - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers + - provider_type: remote::watsonx + - provider_type: inline::sentence-transformers vector_io: - - provider_id: faiss - provider_type: inline::faiss + - provider_type: inline::faiss safety: - - provider_id: llama-guard - provider_type: inline::llama-guard + - provider_type: inline::llama-guard agents: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_type: inline::meta-reference datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - - provider_id: localfs - provider_type: inline::localfs + - provider_type: remote::huggingface + - provider_type: inline::localfs scoring: - - provider_id: basic - provider_type: inline::basic - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - - provider_id: braintrust - provider_type: inline::braintrust + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust tool_runtime: - provider_type: remote::brave-search - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol + files: + - provider_type: inline::localfs image_type: venv additional_pip_packages: +- aiosqlite - sqlalchemy[asyncio] -- aiosqlite -- aiosqlite diff --git a/llama_stack/distributions/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml index 92f367910..e0c337f9d 100644 --- a/llama_stack/distributions/watsonx/run.yaml +++ b/llama_stack/distributions/watsonx/run.yaml @@ -4,13 +4,13 @@ apis: - agents - datasetio - eval +- files - inference - safety - scoring - telemetry - tool_runtime - vector_io -- files providers: inference: - provider_id: watsonx @@ -19,8 +19,6 @@ providers: url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com} api_key: ${env.WATSONX_API_KEY:=} project_id: ${env.WATSONX_PROJECT_ID:=} - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers vector_io: - provider_id: faiss provider_type: inline::faiss @@ -48,7 +46,7 @@ providers: provider_type: inline::meta-reference config: service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" - sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sinks: ${env.TELEMETRY_SINKS:=sqlite} sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} eval: @@ -109,102 +107,7 @@ metadata_store: inference_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db -models: -- metadata: {} - model_id: meta-llama/llama-3-3-70b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.3-70B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-3-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-2-13b-chat - provider_id: watsonx - provider_model_id: meta-llama/llama-2-13b-chat - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-2-13b - provider_id: watsonx - provider_model_id: meta-llama/llama-2-13b-chat - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-1-70b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-70B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-70b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-1-8b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.1-8B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-1-8b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-11b-vision-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-11B-Vision-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-11b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-1b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-1B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-1b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-3b-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-3B-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-3b-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-3-2-90b-vision-instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-3.2-90B-Vision-Instruct - provider_id: watsonx - provider_model_id: meta-llama/llama-3-2-90b-vision-instruct - model_type: llm -- metadata: {} - model_id: meta-llama/llama-guard-3-11b-vision - provider_id: watsonx - provider_model_id: meta-llama/llama-guard-3-11b-vision - model_type: llm -- metadata: {} - model_id: meta-llama/Llama-Guard-3-11B-Vision - provider_id: watsonx - provider_model_id: meta-llama/llama-guard-3-11b-vision - model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: sentence-transformers - model_type: embedding +models: [] shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/distributions/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py index c3cab5d1b..645770612 100644 --- a/llama_stack/distributions/watsonx/watsonx.py +++ b/llama_stack/distributions/watsonx/watsonx.py @@ -4,17 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pathlib import Path -from llama_stack.apis.models import ModelType -from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput -from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry +from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig -from llama_stack.providers.inline.inference.sentence_transformers import ( - SentenceTransformersInferenceConfig, -) from llama_stack.providers.remote.inference.watsonx import WatsonXConfig -from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: @@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: config=WatsonXConfig.sample_run_config(), ) - embedding_provider = Provider( - provider_id="sentence-transformers", - provider_type="inline::sentence-transformers", - config=SentenceTransformersInferenceConfig.sample_run_config(), - ) - - available_models = { - "watsonx": MODEL_ENTRIES, - } default_tool_groups = [ ToolGroupInput( toolgroup_id="builtin::websearch", @@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: ), ] - embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", - provider_id="sentence-transformers", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 384, - }, - ) - files_provider = Provider( provider_id="meta-reference-files", provider_type="inline::localfs", config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), ) - default_models, _ = get_model_registry(available_models) return DistributionTemplate( name=name, distro_type="remote_hosted", description="Use watsonx for running LLM inference", container_image=None, - template_path=Path(__file__).parent / "doc_template.md", + template_path=None, providers=providers, - available_models_by_provider=available_models, run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider, embedding_provider], + "inference": [inference_provider], "files": [files_provider], }, - default_models=default_models + [embedding_model], + default_models=[], default_tool_groups=default_tool_groups, ), }, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index bf6a09b6c..f89565892 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -268,7 +268,7 @@ Available Models: api=Api.inference, adapter_type="watsonx", provider_type="remote::watsonx", - pip_packages=["ibm_watsonx_ai"], + pip_packages=["litellm"], module="llama_stack.providers.remote.inference.watsonx", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py index e59e873b6..35e74a720 100644 --- a/llama_stack/providers/remote/inference/watsonx/__init__.py +++ b/llama_stack/providers/remote/inference/watsonx/__init__.py @@ -4,19 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.inference import Inference - from .config import WatsonXConfig -async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference: - # import dynamically so `llama stack build` does not fail due to missing dependencies +async def get_adapter_impl(config: WatsonXConfig, _deps): + # import dynamically so the import is used only when it is needed from .watsonx import WatsonXInferenceAdapter - if not isinstance(config, WatsonXConfig): - raise RuntimeError(f"Unexpected config type: {type(config)}") adapter = WatsonXInferenceAdapter(config) return adapter - - -__all__ = ["get_adapter_impl", "WatsonXConfig"] diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py index 4bc0173c4..9e98d4003 100644 --- a/llama_stack/providers/remote/inference/watsonx/config.py +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -7,16 +7,18 @@ import os from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, ConfigDict, Field, SecretStr from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type class WatsonXProviderDataValidator(BaseModel): - url: str - api_key: str - project_id: str + model_config = ConfigDict( + from_attributes=True, + extra="forbid", + ) + watsonx_api_key: str | None @json_schema_type @@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig): default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), description="A base url for accessing the watsonx.ai", ) + # This seems like it should be required, but none of the other remote inference + # providers require it, so this is optional here too for consistency. + # The OpenAIConfig uses default=None instead, so this is following that precedent. api_key: SecretStr | None = Field( - default_factory=lambda: os.getenv("WATSONX_API_KEY"), - description="The watsonx API key", + default=None, + description="The watsonx.ai API key", ) + # As above, this is optional here too for consistency. project_id: str | None = Field( - default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), - description="The Project ID key", + default=None, + description="The watsonx.ai project ID", ) timeout: int = Field( default=60, diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py deleted file mode 100644 index d98f0510a..000000000 --- a/llama_stack/providers/remote/inference/watsonx/models.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry - -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "meta-llama/llama-3-3-70b-instruct", - CoreModelId.llama3_3_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-2-13b-chat", - CoreModelId.llama2_13b.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-1-70b-instruct", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-1-8b-instruct", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-11b-vision-instruct", - CoreModelId.llama3_2_11b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-1b-instruct", - CoreModelId.llama3_2_1b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-3b-instruct", - CoreModelId.llama3_2_3b_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-3-2-90b-vision-instruct", - CoreModelId.llama3_2_90b_vision_instruct.value, - ), - build_hf_repo_model_entry( - "meta-llama/llama-guard-3-11b-vision", - CoreModelId.llama_guard_3_11b_vision.value, - ), -] diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index fc58691e2..d04472936 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -4,240 +4,120 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import AsyncGenerator, AsyncIterator from typing import Any -from ibm_watsonx_ai.foundation_models import Model -from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams -from openai import AsyncOpenAI +import requests -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionRequest, - GreedySamplingStrategy, - Inference, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, - OpenAIEmbeddingsResponse, - OpenAIMessageParam, - OpenAIResponseFormatParam, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_stack.log import get_logger -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper -from llama_stack.providers.utils.inference.openai_compat import ( - prepare_openai_completion_params, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_prompt, - completion_request_to_prompt, - request_has_media, -) - -from . import WatsonXConfig -from .models import MODEL_ENTRIES - -logger = get_logger(name=__name__, category="inference::watsonx") +from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.apis.models import Model +from llama_stack.apis.models.models import ModelType +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -# Note on structured output -# WatsonX returns responses with a json embedded into a string. -# Examples: +class WatsonXInferenceAdapter(LiteLLMOpenAIMixin): + _model_cache: dict[str, Model] = {} -# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n -# "first_name": "Michael",\n "last_name": "Jordan",\n'...) -# Not even a valid JSON, but we can still extract the JSON from the content + def __init__(self, config: WatsonXConfig): + LiteLLMOpenAIMixin.__init__( + self, + litellm_provider_name="watsonx", + api_key_from_config=config.api_key.get_secret_value() if config.api_key else None, + provider_data_api_key_field="watsonx_api_key", + ) + self.available_models = None + self.config = config -# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan", -# "year_born": "1963", "year_retired": "2003"\\}}$') -# Find the start of the boxed content + def get_base_url(self) -> str: + return self.config.url + async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: + # Get base parameters from parent + params = await super()._get_params(request) -class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): - def __init__(self, config: WatsonXConfig) -> None: - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) - - logger.info(f"Initializing watsonx InferenceAdapter({config.url})...") - self._config = config - self._openai_client: AsyncOpenAI | None = None - - self._project_id = self._config.project_id - - def _get_client(self, model_id) -> Model: - config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None - config_url = self._config.url - project_id = self._config.project_id - credentials = {"url": config_url, "apikey": config_api_key} - - return Model(model_id=model_id, credentials=credentials, project_id=project_id) - - def _get_openai_client(self) -> AsyncOpenAI: - if not self._openai_client: - self._openai_client = AsyncOpenAI( - base_url=f"{self._config.url}/openai/v1", - api_key=self._config.api_key, - ) - return self._openai_client - - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: - input_dict = {"params": {}} - media_present = request_has_media(request) - llama_model = self.get_llama_model(request.model) - if isinstance(request, ChatCompletionRequest): - input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) - else: - assert not media_present, "Together does not support media for Completion requests" - input_dict["prompt"] = await completion_request_to_prompt(request) - if request.sampling_params: - if request.sampling_params.strategy: - input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type - if request.sampling_params.max_tokens: - input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens - if request.sampling_params.repetition_penalty: - input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty - - if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): - input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p - input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature - if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): - input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k - if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): - input_dict["params"][GenParams.TEMPERATURE] = 0.0 - - input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] - - params = { - **input_dict, - } + # Add watsonx.ai specific parameters + params["project_id"] = self.config.project_id + params["time_limit"] = self.config.timeout return params - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() + # Copied from OpenAIMixin + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from the provider's /v1/models. - async def openai_completion( - self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, - ) -> OpenAICompletion: - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - ) - return await self._get_openai_client().completions.create(**params) # type: ignore + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + if not self._model_cache: + await self.list_models() + return model in self._model_cache - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - if params.get("stream", False): - return self._stream_openai_chat_completion(params) - return await self._get_openai_client().chat.completions.create(**params) # type: ignore + async def list_models(self) -> list[Model] | None: + self._model_cache = {} + models = [] + for model_spec in self._get_model_specs(): + functions = [f["id"] for f in model_spec.get("functions", [])] + # Format: {"embedding_dimension": 1536, "context_length": 8192} - async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: - # watsonx.ai sometimes adds usage data to the stream - include_usage = False - if params.get("stream_options", None): - include_usage = params["stream_options"].get("include_usage", False) - stream = await self._get_openai_client().chat.completions.create(**params) + # Example of an embedding model: + # {'model_id': 'ibm/granite-embedding-278m-multilingual', + # 'label': 'granite-embedding-278m-multilingual', + # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768}, + # ... + provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}" + if "embedding" in functions: + embedding_dimension = model_spec["model_limits"]["embedding_dimension"] + context_length = model_spec["model_limits"]["max_sequence_length"] + embedding_metadata = { + "embedding_dimension": embedding_dimension, + "context_length": context_length, + } + model = Model( + identifier=model_spec["model_id"], + provider_resource_id=provider_resource_id, + provider_id=self.__provider_id__, + metadata=embedding_metadata, + model_type=ModelType.embedding, + ) + self._model_cache[provider_resource_id] = model + models.append(model) + if "text_chat" in functions: + model = Model( + identifier=model_spec["model_id"], + provider_resource_id=provider_resource_id, + provider_id=self.__provider_id__, + metadata={}, + model_type=ModelType.llm, + ) + # In theory, I guess it is possible that a model could be both an embedding model and a text chat model. + # In that case, the cache will record the generator Model object, and the list which we return will have + # both the generator Model object and the text chat Model object. That's fine because the cache is + # only used for check_model_availability() anyway. + self._model_cache[provider_resource_id] = model + models.append(model) + return models - seen_finish_reason = False - async for chunk in stream: - # Final usage chunk with no choices that the user didn't request, so discard - if not include_usage and seen_finish_reason and len(chunk.choices) == 0: - break - yield chunk - for choice in chunk.choices: - if choice.finish_reason: - seen_finish_reason = True - break + # LiteLLM provides methods to list models for many providers, but not for watsonx.ai. + # So we need to implement our own method to list models by calling the watsonx.ai API. + def _get_model_specs(self) -> list[dict[str, Any]]: + """ + Retrieves foundation model specifications from the watsonx.ai API. + """ + url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25" + headers = { + # Note that there is no authorization header. Listing models does not require authentication. + "Content-Type": "application/json", + } + + response = requests.get(url, headers=headers) + + # --- Process the Response --- + # Raise an exception for bad status codes (4xx or 5xx) + response.raise_for_status() + + # If the request is successful, parse and return the JSON response. + # The response should contain a list of model specifications + response_data = response.json() + if "resources" not in response_data: + raise ValueError("Resources not found in response") + return response_data["resources"] diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 6c8f61c3b..6bef97dd5 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 +import struct from collections.abc import AsyncIterator from typing import Any @@ -16,6 +18,7 @@ from llama_stack.apis.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, OpenAIMessageParam, @@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.openai_compat import ( - b64_encode_openai_embeddings_response, convert_message_to_openai_dict_new, convert_tooldef_to_openai_tool, get_sampling_options, @@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin( return False return model in litellm.models_by_provider[self.litellm_provider_name] + + +def b64_encode_openai_embeddings_response( + response_data: list[dict], encoding_format: str | None = "float" +) -> list[OpenAIEmbeddingData]: + """ + Process the OpenAI embeddings response to encode the embeddings in base64 format if specified. + """ + data = [] + for i, embedding_data in enumerate(response_data): + if encoding_format == "base64": + byte_array = bytearray() + for embedding_value in embedding_data["embedding"]: + byte_array.extend(struct.pack("f", float(embedding_value))) + + response_embedding = base64.b64encode(byte_array).decode("utf-8") + else: + response_embedding = embedding_data["embedding"] + data.append( + OpenAIEmbeddingData( + embedding=response_embedding, + index=i, + ) + ) + return data diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index d863eb53a..7e465a14c 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -3,9 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import base64 import json -import struct import time import uuid import warnings @@ -103,7 +101,6 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, Message, OpenAIChatCompletion, - OpenAIEmbeddingData, OpenAIMessageParam, OpenAIResponseFormatParam, SamplingParams, @@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params( params["user"] = user return params - - -def b64_encode_openai_embeddings_response( - response_data: dict, encoding_format: str | None = "float" -) -> list[OpenAIEmbeddingData]: - """ - Process the OpenAI embeddings response to encode the embeddings in base64 format if specified. - """ - data = [] - for i, embedding_data in enumerate(response_data): - if encoding_format == "base64": - byte_array = bytearray() - for embedding_value in embedding_data.embedding: - byte_array.extend(struct.pack("f", float(embedding_value))) - - response_embedding = base64.b64encode(byte_array).decode("utf-8") - else: - response_embedding = embedding_data.embedding - data.append( - OpenAIEmbeddingData( - embedding=response_embedding, - index=i, - ) - ) - return data diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index d30b5b12a..55a6793c2 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter +from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter @pytest.mark.parametrize( @@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): assert inference_adapter.client.api_key == api_key + + +@pytest.mark.parametrize( + "config_cls,adapter_cls,provider_data_validator", + [ + ( + WatsonXConfig, + WatsonXInferenceAdapter, + "llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator", + ), + ], +) +def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str): + """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the + assumption that there is an OpenAI-compatible client object.""" + + inference_adapter = adapter_cls(config=config_cls()) + + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator + + for api_key in ["test1", "test2"]: + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} + ): + assert inference_adapter.get_api_key() == api_key