diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py index d6f296c608..299fc843b3 100644 --- a/litellm/llms/watsonx/common_utils.py +++ b/litellm/llms/watsonx/common_utils.py @@ -114,10 +114,10 @@ def _get_api_params( or get_secret_str("SPACE_ID") ) - if project_id is None: + if project_id is None and space_id is None: raise WatsonXAIError( status_code=401, - message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", + message="Error: At least one of project_id or space_id must be set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter.", ) return WatsonXAPIParams( @@ -279,13 +279,13 @@ class IBMWatsonXMixin: def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict: payload: dict = {} + payload["space_id"] = api_params["space_id"] if model.startswith("deployment/"): if api_params["space_id"] is None: raise WatsonXAIError( status_code=401, message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", ) - payload["space_id"] = api_params["space_id"] return payload payload["model_id"] = model payload["project_id"] = api_params["project_id"] diff --git a/tests/litellm/llms/watsonx/test_common_util.py b/tests/litellm/llms/watsonx/test_common_util.py new file mode 100644 index 0000000000..d248542deb --- /dev/null +++ b/tests/litellm/llms/watsonx/test_common_util.py @@ -0,0 +1,70 @@ +from unittest.mock import patch +import pytest + +from litellm.llms.watsonx.common_utils import WatsonXAIError + + +@pytest.mark.parametrize( + "value_from", + [ + {'params': {'space_id': 'space-id'}, 'secrets': {}}, + {'params': {}, 'secrets': {'WATSONX_DEPLOYMENT_SPACE_ID': 'space-id'}}, + {'params': {}, 'secrets': {'WATSONX_SPACE_ID': 'space-id'}}, + {'params': {}, 'secrets': {'SPACE_ID': 'space-id'}}, + ], + ids=[ + "field", + "WATSONX_DEPLOYMENT_SPACE_ID", + "WATSONX_SPACE_ID", + "SPACE_ID" + ] +) +def test_watsonx_chat_handler_with_space_id(value_from): + with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str: + get_secret_str.return_value = '' + get_secret_str.side_effect = lambda key: value_from['secrets'][key] if key in value_from['secrets'] else None + + from litellm.llms.watsonx import common_utils + + _get_api_params = getattr(common_utils, '_get_api_params') + params = _get_api_params(value_from['params']) + assert params['space_id'] == 'space-id' + + +@pytest.mark.parametrize( + "value_from", + [ + {'params': {'project_id': 'project-id'}, 'secrets': {}}, + {'params': {}, 'secrets': {'WATSONX_PROJECT_ID': 'project-id'}}, + {'params': {}, 'secrets': {'WX_PROJECT_ID': 'project-id'}}, + {'params': {}, 'secrets': {'PROJECT_ID': 'project-id'}}, + ], + ids=[ + "field", + "WATSONX_PROJECT_ID", + "WX_PROJECT_ID", + "PROJECT_ID" + ] +) +def test_watsonx_chat_handler_with_project_id(value_from): + with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str: + get_secret_str.return_value = '' + get_secret_str.side_effect = lambda key: value_from['secrets'][key] if key in value_from['secrets'] else None + + from litellm.llms.watsonx import common_utils + + _get_api_params = getattr(common_utils, '_get_api_params') + params = _get_api_params(value_from['params']) + assert params['project_id'] == 'project-id' + + +def test_watsonx_chat_handler_with_no_space_id_or_project_id(): + with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str: + get_secret_str.return_value = None + + from litellm.llms.watsonx import common_utils + + _get_api_params = getattr(common_utils, '_get_api_params') + with pytest.raises(WatsonXAIError) as error: + params = _get_api_params({}) + assert 'At least one of project_id or space_id must be set.' in error.value.message