This commit is contained in:
Mateusz Świtała 2025-04-24 00:55:02 -07:00 committed by GitHub
commit 0373608c98
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 65 additions and 23 deletions

View file

@ -40,7 +40,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False,
):
api_params = _get_api_params(params=optional_params)
api_params = _get_api_params(params=optional_params, model=model)
## UPDATE HEADERS
headers = watsonx_chat_transformation.validate_environment(

View file

@ -76,9 +76,7 @@ def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str
return token
def _get_api_params(
params: dict,
) -> WatsonXAPIParams:
def _get_api_params(params: dict, model: Optional[str] = None) -> WatsonXAPIParams:
"""
Find watsonx.ai credentials in the params or environment variables and return them as WatsonXAPIParams.
"""
@ -114,10 +112,15 @@ def _get_api_params(
or get_secret_str("SPACE_ID")
)
if project_id is None:
if (
project_id is None
and space_id is None
and model is not None
and not model.startswith("deployment/")
):
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
message="Error: Watsonx project_id and space_id not set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter.",
)
return WatsonXAPIParams(
@ -280,13 +283,10 @@ class IBMWatsonXMixin:
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
payload: dict = {}
if model.startswith("deployment/"):
if api_params["space_id"] is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
payload["space_id"] = api_params["space_id"]
return payload
payload["model_id"] = model
payload["project_id"] = api_params["project_id"]
if api_params["project_id"] is not None:
payload["project_id"] = api_params["project_id"]
else:
payload["space_id"] = api_params["space_id"]
return payload

View file

@ -245,7 +245,7 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
)
extra_body_params = optional_params.pop("extra_body", {})
optional_params.update(extra_body_params)
watsonx_api_params = _get_api_params(params=optional_params)
watsonx_api_params = _get_api_params(params=optional_params, model=model)
watsonx_auth_payload = self._prepare_payload(
model=model,

View file

@ -37,7 +37,7 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
optional_params: dict,
headers: dict,
) -> dict:
watsonx_api_params = _get_api_params(params=optional_params)
watsonx_api_params = _get_api_params(params=optional_params, model=model)
watsonx_auth_payload = self._prepare_payload(
model=model,
api_params=watsonx_api_params,

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
class WatsonXAPIParams(TypedDict):
project_id: str
project_id: Optional[str]
space_id: Optional[str]
region_name: Optional[str]

View file

@ -1,17 +1,14 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion, embedding
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, Mock
import pytest
from typing import Optional
@ -188,6 +185,27 @@ def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
assert "deployment" not in mock_post.call_args.kwargs["url"]
def test_watsonx_chat_completions_endpoint_space_id(
    monkeypatch, watsonx_chat_completion_call
):
    """
    Check that a non-deployment model with only a space id configured sends
    that space id, and no project id, in the request body.
    """
    my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
    monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
    # raising=False: the variable may not be set in the environment running
    # the test; without it delenv raises KeyError and the test fails for a
    # reason unrelated to what it checks.
    monkeypatch.delenv("WATSONX_PROJECT_ID", raising=False)

    model = "watsonx/another-model"
    messages = [{"role": "user", "content": "Test message"}]

    mock_post, _ = watsonx_chat_completion_call(model=model, messages=messages)

    assert mock_post.call_count == 1
    # A regular (non-deployment) model must not be routed to a deployment URL.
    assert "deployment" not in mock_post.call_args.kwargs["url"]

    json_data = json.loads(mock_post.call_args.kwargs["data"])
    assert my_fake_space_id == json_data["space_id"]
    assert not json_data.get("project_id")
@pytest.mark.parametrize(
"model",
[
@ -206,7 +224,29 @@ def test_watsonx_deployment_space_id(monkeypatch, watsonx_chat_completion_call,
assert mock_post.call_count == 1
json_data = json.loads(mock_post.call_args.kwargs["data"])
assert my_fake_space_id == json_data["space_id"]
assert "space_id" not in json_data
@pytest.mark.parametrize(
    "model",
    [
        "watsonx/deployment/<xxxx.xxx.xxx.xxxx>",
        "watsonx_text/deployment/<xxxx.xxx.xxx.xxxx>",
    ],
)
def test_watsonx_deployment(watsonx_chat_completion_call, model):
    """Check that a deployment model sends neither project_id nor space_id."""
    chat_messages = [{"content": "Hello, how are you?", "role": "user"}]
    mock_post, _ = watsonx_chat_completion_call(model=model, messages=chat_messages)

    assert mock_post.call_count == 1

    request_body = json.loads(mock_post.call_args.kwargs["data"])
    # Neither space_id nor project_id is required by the wx.ai API when
    # running inference against a deployment.
    assert "project_id" not in request_body
    assert "space_id" not in request_body
def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_call):
@ -217,4 +257,6 @@ def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_ca
assert mock_post.call_count == 1
json_data = json.loads(mock_post.call_args.kwargs["data"])
assert my_fake_space_id == json_data["space_id"]
# neither space_id nor project_id is required by the wx.ai API when running inference against a deployment
assert "project_id" not in json_data and "space_id" not in json_data