fix(watsonx.ai): Allows calling model with only space_id

Fixes a bug with the implementation of the watsonx.ai provider where
a user is required to provide a project_id if they already provide
a space_id.

If a space_id is provided, no project_id should be required.
This commit is contained in:
Eric Marcoux 2025-04-13 11:40:59 -04:00
parent 64bb89c70f
commit 97b8faf4f9
2 changed files with 73 additions and 3 deletions

View file

@ -114,10 +114,10 @@ def _get_api_params(
or get_secret_str("SPACE_ID")
)
if project_id is None:
if project_id is None and space_id is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
message="Error: At least one of project_id or space_id must be set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter.",
)
return WatsonXAPIParams(
@ -279,13 +279,13 @@ class IBMWatsonXMixin:
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
payload: dict = {}
payload["space_id"] = api_params["space_id"]
if model.startswith("deployment/"):
if api_params["space_id"] is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
payload["space_id"] = api_params["space_id"]
return payload
payload["model_id"] = model
payload["project_id"] = api_params["project_id"]

View file

@ -0,0 +1,70 @@
from unittest.mock import patch
import pytest
from litellm.llms.watsonx.common_utils import WatsonXAIError
@pytest.mark.parametrize(
"value_from",
[
{'params': {'space_id': 'space-id'}, 'secrets': {}},
{'params': {}, 'secrets': {'WATSONX_DEPLOYMENT_SPACE_ID': 'space-id'}},
{'params': {}, 'secrets': {'WATSONX_SPACE_ID': 'space-id'}},
{'params': {}, 'secrets': {'SPACE_ID': 'space-id'}},
],
ids=[
"field",
"WATSONX_DEPLOYMENT_SPACE_ID",
"WATSONX_SPACE_ID",
"SPACE_ID"
]
)
def test_watsonx_chat_handler_with_space_id(value_from):
with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str:
get_secret_str.return_value = ''
get_secret_str.side_effect = lambda key: value_from['secrets'][key] if key in value_from['secrets'] else None
from litellm.llms.watsonx import common_utils
_get_api_params = getattr(common_utils, '_get_api_params')
params = _get_api_params(value_from['params'])
assert params['space_id'] == 'space-id'
@pytest.mark.parametrize(
"value_from",
[
{'params': {'project_id': 'project-id'}, 'secrets': {}},
{'params': {}, 'secrets': {'WATSONX_PROJECT_ID': 'project-id'}},
{'params': {}, 'secrets': {'WX_PROJECT_ID': 'project-id'}},
{'params': {}, 'secrets': {'PROJECT_ID': 'project-id'}},
],
ids=[
"field",
"WATSONX_PROJECT_ID",
"WX_PROJECT_ID",
"PROJECT_ID"
]
)
def test_watsonx_chat_handler_with_project_id(value_from):
with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str:
get_secret_str.return_value = ''
get_secret_str.side_effect = lambda key: value_from['secrets'][key] if key in value_from['secrets'] else None
from litellm.llms.watsonx import common_utils
_get_api_params = getattr(common_utils, '_get_api_params')
params = _get_api_params(value_from['params'])
assert params['project_id'] == 'project-id'
def test_watsonx_chat_handler_with_no_space_id_or_project_id():
with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str:
get_secret_str.return_value = None
from litellm.llms.watsonx import common_utils
_get_api_params = getattr(common_utils, '_get_api_params')
with pytest.raises(WatsonXAIError) as error:
params = _get_api_params({})
assert 'At least one of project_id or space_id must be set.' in error.value.message