fix: Update watsonx.ai provider to use LiteLLM mixin and list all models (#3674)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Python Package Build Test / build (3.13) (push) Failing after 2s
Python Package Build Test / build (3.12) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 7s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 12s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 5s
UI Tests / ui-tests (22) (push) Successful in 32s
Pre-commit / pre-commit (push) Successful in 1m29s

# What does this PR do?

- The watsonx.ai provider now uses the LiteLLM mixin instead of using
IBM's library, which does not seem to be working (see #3165 for
context).
- The watsonx.ai provider now lists all the models available by calling
the watsonx.ai server instead of having a hard coded list of known
models. (That list gets out of date quickly)
- An edge case in
[llama_stack/core/routers/inference.py](https://github.com/llamastack/llama-stack/pull/3674/files#diff-a34bc966ed9befd9f13d4883c23705dff49be0ad6211c850438cdda6113f3455)
is addressed that was causing my manual tests to fail.
- Fixes `b64_encode_openai_embeddings_response` which was trying to
enumerate over a dictionary and then reference elements of the
dictionary using .field instead of ["field"]. That method is called by
the LiteLLM mixin for embedding models, so it is needed to get the
watsonx.ai embedding models to work.
- A unit test along the lines of the one in #3348 is added. A more
comprehensive plan for automatically testing the end-to-end
functionality for inference providers would be a good idea, but is out
of scope for this PR.
- Updates to the watsonx distribution. Some were in response to the
switch to LiteLLM (e.g., updating the Python packages needed). Others
seem to be things that were already broken that I found along the way
(e.g., a reference to a watsonx specific doc template that doesn't seem
to exist).

Closes #3165

Also it is related to a line-item in #3387 but doesn't really address
that goal (because it uses the LiteLLM mixin, not the OpenAI one). I
tried the OpenAI one and it doesn't work with watsonx.ai, presumably
because the watsonx.ai service is not OpenAI compatible. It works with
LiteLLM because LiteLLM has a provider implementation for watsonx.ai.

## Test Plan

The test script below goes back and forth between the OpenAI and watsonx
providers. The idea is that the OpenAI provider shows how it should work
and then the watsonx provider output shows that it is also working with
watsonx. Note that the result from the MCP test is not as good (the
Llama 3.3 70b model does not choose tools as wisely as gpt-4o), but it
is still working and providing a valid response. For more details on
setup and the MCP server being used for testing, see [the AI Alliance
sample
notebook](https://github.com/The-AI-Alliance/llama-stack-examples/blob/main/notebooks/01-responses/)
that these examples are drawn from.

```python
#!/usr/bin/env python3

import json
from llama_stack_client import LlamaStackClient
from litellm import completion
import http.client


def print_response(response):
    """Print response in a nicely formatted way"""
    print(f"ID: {response.id}")
    print(f"Status: {response.status}")
    print(f"Model: {response.model}")
    print(f"Created at: {response.created_at}")
    print(f"Output items: {len(response.output)}")
    
    for i, output_item in enumerate(response.output):
        if len(response.output) > 1:
            print(f"\n--- Output Item {i+1} ---")
        print(f"Output type: {output_item.type}")
        
        if output_item.type in ("text", "message"):
            print(f"Response content: {output_item.content[0].text}")
        elif output_item.type == "file_search_call":
            print(f"  Tool Call ID: {output_item.id}")
            print(f"  Tool Status: {output_item.status}")
            # 'queries' is a list, so we join it for clean printing
            print(f"  Queries: {', '.join(output_item.queries)}")
            # Display results if they exist, otherwise note they are empty
            print(f"  Results: {output_item.results if output_item.results else 'None'}")
        elif output_item.type == "mcp_list_tools":
            print_mcp_list_tools(output_item)
        elif output_item.type == "mcp_call":
            print_mcp_call(output_item)
        else:
            print(f"Response content: {output_item.content}")


def print_mcp_call(mcp_call):
    """Print MCP call in a nicely formatted way"""
    print(f"\n🛠️  MCP Tool Call: {mcp_call.name}")
    print(f"   Server: {mcp_call.server_label}")
    print(f"   ID: {mcp_call.id}")
    print(f"   Arguments: {mcp_call.arguments}")
    
    if mcp_call.error:
        print("Error: {mcp_call.error}")
    elif mcp_call.output:
        print("Output:")
        # Try to format JSON output nicely
        try:
            parsed_output = json.loads(mcp_call.output)
            print(json.dumps(parsed_output, indent=4))
        except:
            # If not valid JSON, print as-is
            print(f"   {mcp_call.output}")
    else:
        print("    No output yet")


def print_mcp_list_tools(mcp_list_tools):
    """Print MCP list tools in a nicely formatted way"""
    print(f"\n🔧 MCP Server: {mcp_list_tools.server_label}")
    print(f"   ID: {mcp_list_tools.id}")
    print(f"   Available Tools: {len(mcp_list_tools.tools)}")
    print("=" * 80)
    
    for i, tool in enumerate(mcp_list_tools.tools, 1):
        print(f"\n{i}. {tool.name}")
        print(f"   Description: {tool.description}")
        
        # Parse and display input schema
        schema = tool.input_schema
        if schema and 'properties' in schema:
            properties = schema['properties']
            required = schema.get('required', [])
            
            print("   Parameters:")
            for param_name, param_info in properties.items():
                param_type = param_info.get('type', 'unknown')
                param_desc = param_info.get('description', 'No description')
                required_marker = " (required)" if param_name in required else " (optional)"
                print(f"     • {param_name} ({param_type}){required_marker}")
                if param_desc:
                    print(f"       {param_desc}")
        
        if i < len(mcp_list_tools.tools):
            print("-" * 40)


def main():
    """Main function to run all the tests"""
    
    # Configuration
    LLAMA_STACK_URL = "http://localhost:8321/"
    LLAMA_STACK_MODEL_IDS = [
        "openai/gpt-3.5-turbo",
        "openai/gpt-4o",
        "llama-openai-compat/Llama-3.3-70B-Instruct",
        "watsonx/meta-llama/llama-3-3-70b-instruct"
    ]
    
    # Using gpt-4o for this demo, but feel free to try one of the others or add more to run.yaml.
    OPENAI_MODEL_ID = LLAMA_STACK_MODEL_IDS[1]
    WATSONX_MODEL_ID = LLAMA_STACK_MODEL_IDS[-1]
    NPS_MCP_URL = "http://localhost:3005/sse/"
    
    print("=== Llama Stack Testing Script ===")
    print(f"Using OpenAI model: {OPENAI_MODEL_ID}")
    print(f"Using WatsonX model: {WATSONX_MODEL_ID}")
    print(f"MCP URL: {NPS_MCP_URL}")
    print()
    
    # Initialize client
    print("Initializing LlamaStackClient...")
    client = LlamaStackClient(base_url="http://localhost:8321")
    
    # Test 1: List models
    print("\n=== Test 1: List Models ===")
    try:
        models = client.models.list()
        print(f"Found {len(models)} models")
    except Exception as e:
        print(f"Error listing models: {e}")
        raise e
    
    # Test 2: Basic chat completion with OpenAI
    print("\n=== Test 2: Basic Chat Completion (OpenAI) ===")
    try:
        chat_completion_response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}]
        )
        
        print("OpenAI Response:")
        for chunk in chat_completion_response.choices[0].message.content:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with OpenAI chat completion: {e}")
        raise e
    
    # Test 3: Basic chat completion with WatsonX
    print("\n=== Test 3: Basic Chat Completion (WatsonX) ===")
    try:
        chat_completion_response_wxai = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
        )
        
        print("WatsonX Response:")
        for chunk in chat_completion_response_wxai.choices[0].message.content:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with WatsonX chat completion: {e}")
        raise e
    
    # Test 4: Tool calling with OpenAI
    print("\n=== Test 4: Tool Calling (OpenAI) ===")
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a specific location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        },
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    
    messages = [
        {"role": "user", "content": "What's the weather like in Boston, MA?"}
    ]
    
    try:
        print("--- Initial API Call ---")
        response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("OpenAI tool calling response received")
    except Exception as e:
        print(f"Error with OpenAI tool calling: {e}")
        raise e
    
    # Test 5: Tool calling with WatsonX
    print("\n=== Test 5: Tool Calling (WatsonX) ===")
    try:
        wxai_response = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("WatsonX tool calling response received")
    except Exception as e:
        print(f"Error with WatsonX tool calling: {e}")
        raise e
    
    # Test 6: Streaming with WatsonX
    print("\n=== Test 6: Streaming Response (WatsonX) ===")
    try:
        chat_completion_response_wxai_stream = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
            stream=True
        )
        print("Model response: ", end="")
        for chunk in chat_completion_response_wxai_stream:
            # Each 'chunk' is a ChatCompletionChunk object.
            # We want the content from the 'delta' attribute.
            if hasattr(chunk, 'choices') and chunk.choices is not None:
                content = chunk.choices[0].delta.content
                # The first few chunks might have None content, so we check for it.
                if content is not None:
                    print(content, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with streaming: {e}")
        raise e
    
    # Test 7: MCP with OpenAI
    print("\n=== Test 7: MCP Integration (OpenAI) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=OPENAI_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (OpenAI): {e}")
        raise e
    
    # Test 8: MCP with WatsonX
    print("\n=== Test 8: MCP Integration (WatsonX) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="What is the capital of France?"
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (WatsonX): {e}")
        raise e
    
    # Test 9: MCP with Llama 3.3
    print("\n=== Test 9: MCP Integration (Llama 3.3) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (Llama 3.3): {e}")
        raise e
    
    # Test 10: Embeddings
    print("\n=== Test 10: Embeddings ===")
    try:
        conn = http.client.HTTPConnection("localhost:8321")
        payload = json.dumps({
            "model": "watsonx/ibm/granite-embedding-278m-multilingual",
            "input": "Hello, world!",
        })
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        conn.request("POST", "/v1/openai/v1/embeddings", payload, headers)
        res = conn.getresponse()
        data = res.read()
        print(data.decode("utf-8"))
    except Exception as e:
        print(f"Error with Embeddings: {e}")
        raise e

    print("\n=== Testing Complete ===")


if __name__ == "__main__":
    main()
```

---------

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Bill Murdock 2025-10-08 07:29:43 -04:00 committed by GitHub
parent 62bac0aad4
commit 5d711d4bcb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 203 additions and 476 deletions

View file

@ -17,8 +17,8 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider | | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
| `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai | | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx.ai API key |
| `project_id` | `str \| None` | No | | The Project ID key | | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests | | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
## Sample Configuration ## Sample Configuration

View file

@ -611,7 +611,7 @@ class InferenceRouter(Inference):
completion_text += "".join(choice_data["content_parts"]) completion_text += "".join(choice_data["content_parts"])
# Add metrics to the chunk # Add metrics to the chunk
if self.telemetry and chunk.usage: if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
metrics = self._construct_metrics( metrics = self._construct_metrics(
prompt_tokens=chunk.usage.prompt_tokens, prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens, completion_tokens=chunk.usage.completion_tokens,

View file

@ -3,3 +3,5 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .watsonx import get_distribution_template # noqa: F401

View file

@ -3,44 +3,33 @@ distribution_spec:
description: Use watsonx for running LLM inference description: Use watsonx for running LLM inference
providers: providers:
inference: inference:
- provider_id: watsonx - provider_type: remote::watsonx
provider_type: remote::watsonx - provider_type: inline::sentence-transformers
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- provider_id: faiss - provider_type: inline::faiss
provider_type: inline::faiss
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
eval: eval:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- provider_id: huggingface - provider_type: remote::huggingface
provider_type: remote::huggingface - provider_type: inline::localfs
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic - provider_type: inline::llm-as-judge
- provider_id: llm-as-judge - provider_type: inline::braintrust
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- provider_type: remote::brave-search - provider_type: remote::brave-search
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime - provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio] - sqlalchemy[asyncio]
- aiosqlite
- aiosqlite

View file

@ -4,13 +4,13 @@ apis:
- agents - agents
- datasetio - datasetio
- eval - eval
- files
- inference - inference
- safety - safety
- scoring - scoring
- telemetry - telemetry
- tool_runtime - tool_runtime
- vector_io - vector_io
- files
providers: providers:
inference: inference:
- provider_id: watsonx - provider_id: watsonx
@ -19,8 +19,6 @@ providers:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com} url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=} api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=} project_id: ${env.WATSONX_PROJECT_ID:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- provider_id: faiss - provider_id: faiss
provider_type: inline::faiss provider_type: inline::faiss
@ -48,7 +46,7 @@ providers:
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console,sqlite} sinks: ${env.TELEMETRY_SINKS:=sqlite}
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
eval: eval:
@ -109,102 +107,7 @@ metadata_store:
inference_store: inference_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
models: models: []
- metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-2-13b-chat
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-2-13b
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-8b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-11b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-1b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-3b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-90b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-guard-3-11b-vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: [] shields: []
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []

View file

@ -4,17 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models import ModelType from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate: def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
config=WatsonXConfig.sample_run_config(), config=WatsonXConfig.sample_run_config(),
) )
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
available_models = {
"watsonx": MODEL_ENTRIES,
}
default_tool_groups = [ default_tool_groups = [
ToolGroupInput( ToolGroupInput(
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
), ),
] ]
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
files_provider = Provider( files_provider = Provider(
provider_id="meta-reference-files", provider_id="meta-reference-files",
provider_type="inline::localfs", provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
) )
default_models, _ = get_model_registry(available_models)
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,
distro_type="remote_hosted", distro_type="remote_hosted",
description="Use watsonx for running LLM inference", description="Use watsonx for running LLM inference",
container_image=None, container_image=None,
template_path=Path(__file__).parent / "doc_template.md", template_path=None,
providers=providers, providers=providers,
available_models_by_provider=available_models,
run_configs={ run_configs={
"run.yaml": RunConfigSettings( "run.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={
"inference": [inference_provider, embedding_provider], "inference": [inference_provider],
"files": [files_provider], "files": [files_provider],
}, },
default_models=default_models + [embedding_model], default_models=[],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
), ),
}, },

View file

@ -268,7 +268,7 @@ Available Models:
api=Api.inference, api=Api.inference,
adapter_type="watsonx", adapter_type="watsonx",
provider_type="remote::watsonx", provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"], pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx", module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",

View file

@ -4,19 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference: async def get_adapter_impl(config: WatsonXConfig, _deps):
# import dynamically so `llama stack build` does not fail due to missing dependencies # import dynamically so the import is used only when it is needed
from .watsonx import WatsonXInferenceAdapter from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config) adapter = WatsonXInferenceAdapter(config)
return adapter return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]

View file

@ -7,16 +7,18 @@
import os import os
from typing import Any from typing import Any
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, ConfigDict, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel): class WatsonXProviderDataValidator(BaseModel):
url: str model_config = ConfigDict(
api_key: str from_attributes=True,
project_id: str extra="forbid",
)
watsonx_api_key: str | None
@json_schema_type @json_schema_type
@ -25,13 +27,17 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai", description="A base url for accessing the watsonx.ai",
) )
# This seems like it should be required, but none of the other remote inference
# providers require it, so this is optional here too for consistency.
# The OpenAIConfig uses default=None instead, so this is following that precedent.
api_key: SecretStr | None = Field( api_key: SecretStr | None = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"), default=None,
description="The watsonx API key", description="The watsonx.ai API key",
) )
# As above, this is optional here too for consistency.
project_id: str | None = Field( project_id: str | None = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), default=None,
description="The Project ID key", description="The watsonx.ai project ID",
) )
timeout: int = Field( timeout: int = Field(
default=60, default=60,

View file

@ -1,47 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-2-13b-chat",
CoreModelId.llama2_13b.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]

View file

@ -4,240 +4,120 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
from ibm_watsonx_ai.foundation_models import Model import requests
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.inference import ( from llama_stack.apis.inference import ChatCompletionRequest
ChatCompletionRequest, from llama_stack.apis.models import Model
CompletionRequest, from llama_stack.apis.models.models import ModelType
GreedySamplingStrategy, from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
Inference, from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
from . import WatsonXConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::watsonx")
# Note on structured output class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
# WatsonX returns responses with a json embedded into a string. _model_cache: dict[str, Model] = {}
# Examples:
# ChatCompletionResponse(completion_message=CompletionMessage(content='```json\n{\n def __init__(self, config: WatsonXConfig):
# "first_name": "Michael",\n "last_name": "Jordan",\n'...) LiteLLMOpenAIMixin.__init__(
# Not even a valid JSON, but we can still extract the JSON from the content self,
litellm_provider_name="watsonx",
api_key_from_config=config.api_key.get_secret_value() if config.api_key else None,
provider_data_api_key_field="watsonx_api_key",
)
self.available_models = None
self.config = config
# CompletionResponse(content=' \nThe best answer is $\\boxed{\\{"name": "Michael Jordan", def get_base_url(self) -> str:
# "year_born": "1963", "year_retired": "2003"\\}}$') return self.config.url
# Find the start of the boxed content
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): # Add watsonx.ai specific parameters
def __init__(self, config: WatsonXConfig) -> None: params["project_id"] = self.config.project_id
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) params["time_limit"] = self.config.timeout
logger.info(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._openai_client: AsyncOpenAI | None = None
self._project_id = self._config.project_id
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self._config.url}/openai/v1",
api_key=self._config.api_key,
)
return self._openai_client
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,
}
return params return params
async def openai_embeddings( # Copied from OpenAIMixin
self, async def check_model_availability(self, model: str) -> bool:
model: str, """
input: str | list[str], Check if a specific model is available from the provider's /v1/models.
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion( :param model: The model identifier to check.
self, :return: True if the model is available dynamically, False otherwise.
model: str, """
prompt: str | list[str] | list[int] | list[list[int]], if not self._model_cache:
best_of: int | None = None, await self.list_models()
echo: bool | None = None, return model in self._model_cache
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion( async def list_models(self) -> list[Model] | None:
self, self._model_cache = {}
model: str, models = []
messages: list[OpenAIMessageParam], for model_spec in self._get_model_specs():
frequency_penalty: float | None = None, functions = [f["id"] for f in model_spec.get("functions", [])]
function_call: str | dict[str, Any] | None = None, # Format: {"embedding_dimension": 1536, "context_length": 8192}
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: # Example of an embedding model:
# watsonx.ai sometimes adds usage data to the stream # {'model_id': 'ibm/granite-embedding-278m-multilingual',
include_usage = False # 'label': 'granite-embedding-278m-multilingual',
if params.get("stream_options", None): # 'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
include_usage = params["stream_options"].get("include_usage", False) # ...
stream = await self._get_openai_client().chat.completions.create(**params) provider_resource_id = f"{self.__provider_id__}/{model_spec['model_id']}"
if "embedding" in functions:
embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
context_length = model_spec["model_limits"]["max_sequence_length"]
embedding_metadata = {
"embedding_dimension": embedding_dimension,
"context_length": context_length,
}
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata=embedding_metadata,
model_type=ModelType.embedding,
)
self._model_cache[provider_resource_id] = model
models.append(model)
if "text_chat" in functions:
model = Model(
identifier=model_spec["model_id"],
provider_resource_id=provider_resource_id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
# In theory, I guess it is possible that a model could be both an embedding model and a text chat model.
# In that case, the cache will record the generator Model object, and the list which we return will have
# both the generator Model object and the text chat Model object. That's fine because the cache is
# only used for check_model_availability() anyway.
self._model_cache[provider_resource_id] = model
models.append(model)
return models
seen_finish_reason = False # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
async for chunk in stream: # So we need to implement our own method to list models by calling the watsonx.ai API.
# Final usage chunk with no choices that the user didn't request, so discard def _get_model_specs(self) -> list[dict[str, Any]]:
if not include_usage and seen_finish_reason and len(chunk.choices) == 0: """
break Retrieves foundation model specifications from the watsonx.ai API.
yield chunk """
for choice in chunk.choices: url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
if choice.finish_reason: headers = {
seen_finish_reason = True # Note that there is no authorization header. Listing models does not require authentication.
break "Content-Type": "application/json",
}
response = requests.get(url, headers=headers)
# --- Process the Response ---
# Raise an exception for bad status codes (4xx or 5xx)
response.raise_for_status()
# If the request is successful, parse and return the JSON response.
# The response should contain a list of model specifications
response_data = response.json()
if "resources" not in response_data:
raise ValueError("Resources not found in response")
return response_data["resources"]

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import struct
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any from typing import Any
@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAICompletion, OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
OpenAIMessageParam, OpenAIMessageParam,
@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
b64_encode_openai_embeddings_response,
convert_message_to_openai_dict_new, convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool, convert_tooldef_to_openai_tool,
get_sampling_options, get_sampling_options,
@ -349,3 +351,28 @@ class LiteLLMOpenAIMixin(
return False return False
return model in litellm.models_by_provider[self.litellm_provider_name] return model in litellm.models_by_provider[self.litellm_provider_name]
def b64_encode_openai_embeddings_response(
response_data: list[dict], encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data["embedding"]:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data["embedding"]
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data

View file

@ -3,9 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import base64
import json import json
import struct
import time import time
import uuid import uuid
import warnings import warnings
@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
Message, Message,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIEmbeddingData,
OpenAIMessageParam, OpenAIMessageParam,
OpenAIResponseFormatParam, OpenAIResponseFormatParam,
SamplingParams, SamplingParams,
@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
params["user"] = user params["user"] = user
return params return params
def b64_encode_openai_embeddings_response(
response_data: dict, encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
"""
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
"""
data = []
for i, embedding_data in enumerate(response_data):
if encoding_format == "base64":
byte_array = bytearray()
for embedding_value in embedding_data.embedding:
byte_array.extend(struct.pack("f", float(embedding_value)))
response_embedding = base64.b64encode(byte_array).decode("utf-8")
else:
response_embedding = embedding_data.embedding
data.append(
OpenAIEmbeddingData(
embedding=response_embedding,
index=i,
)
)
return data

View file

@ -18,6 +18,8 @@ from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInferenceAdapter
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -58,3 +60,29 @@ def test_openai_provider_data_used(config_cls, adapter_cls, provider_data_valida
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
): ):
assert inference_adapter.client.api_key == api_key assert inference_adapter.client.api_key == api_key
@pytest.mark.parametrize(
"config_cls,adapter_cls,provider_data_validator",
[
(
WatsonXConfig,
WatsonXInferenceAdapter,
"llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
),
],
)
def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_validator: str):
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
assert inference_adapter.get_api_key() == api_key