mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 97b8faf4f9
into b82af5b826
This commit is contained in:
commit
743a1903a5
2 changed files with 73 additions and 3 deletions
|
@ -114,10 +114,10 @@ def _get_api_params(
|
|||
or get_secret_str("SPACE_ID")
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
if project_id is None and space_id is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
|
||||
message="Error: At least one of project_id or space_id must be set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
return WatsonXAPIParams(
|
||||
|
@ -279,13 +279,13 @@ class IBMWatsonXMixin:
|
|||
|
||||
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
|
||||
payload: dict = {}
|
||||
payload["space_id"] = api_params["space_id"]
|
||||
if model.startswith("deployment/"):
|
||||
if api_params["space_id"] is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
payload["space_id"] = api_params["space_id"]
|
||||
return payload
|
||||
payload["model_id"] = model
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
|
|
70
tests/litellm/llms/watsonx/test_common_util.py
Normal file
70
tests/litellm/llms/watsonx/test_common_util.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from litellm.llms.watsonx.common_utils import WatsonXAIError
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value_from",
|
||||
[
|
||||
{'params': {'space_id': 'space-id'}, 'secrets': {}},
|
||||
{'params': {}, 'secrets': {'WATSONX_DEPLOYMENT_SPACE_ID': 'space-id'}},
|
||||
{'params': {}, 'secrets': {'WATSONX_SPACE_ID': 'space-id'}},
|
||||
{'params': {}, 'secrets': {'SPACE_ID': 'space-id'}},
|
||||
],
|
||||
ids=[
|
||||
"field",
|
||||
"WATSONX_DEPLOYMENT_SPACE_ID",
|
||||
"WATSONX_SPACE_ID",
|
||||
"SPACE_ID"
|
||||
]
|
||||
)
|
||||
def test_watsonx_chat_handler_with_space_id(value_from):
|
||||
with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str:
|
||||
get_secret_str.return_value = ''
|
||||
get_secret_str.side_effect = lambda key: value_from['secrets'][key] if key in value_from['secrets'] else None
|
||||
|
||||
from litellm.llms.watsonx import common_utils
|
||||
|
||||
_get_api_params = getattr(common_utils, '_get_api_params')
|
||||
params = _get_api_params(value_from['params'])
|
||||
assert params['space_id'] == 'space-id'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value_from",
|
||||
[
|
||||
{'params': {'project_id': 'project-id'}, 'secrets': {}},
|
||||
{'params': {}, 'secrets': {'WATSONX_PROJECT_ID': 'project-id'}},
|
||||
{'params': {}, 'secrets': {'WX_PROJECT_ID': 'project-id'}},
|
||||
{'params': {}, 'secrets': {'PROJECT_ID': 'project-id'}},
|
||||
],
|
||||
ids=[
|
||||
"field",
|
||||
"WATSONX_PROJECT_ID",
|
||||
"WX_PROJECT_ID",
|
||||
"PROJECT_ID"
|
||||
]
|
||||
)
|
||||
def test_watsonx_chat_handler_with_project_id(value_from):
|
||||
with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str:
|
||||
get_secret_str.return_value = ''
|
||||
get_secret_str.side_effect = lambda key: value_from['secrets'][key] if key in value_from['secrets'] else None
|
||||
|
||||
from litellm.llms.watsonx import common_utils
|
||||
|
||||
_get_api_params = getattr(common_utils, '_get_api_params')
|
||||
params = _get_api_params(value_from['params'])
|
||||
assert params['project_id'] == 'project-id'
|
||||
|
||||
|
||||
def test_watsonx_chat_handler_with_no_space_id_or_project_id():
|
||||
with patch('litellm.llms.watsonx.common_utils.get_secret_str') as get_secret_str:
|
||||
get_secret_str.return_value = None
|
||||
|
||||
from litellm.llms.watsonx import common_utils
|
||||
|
||||
_get_api_params = getattr(common_utils, '_get_api_params')
|
||||
with pytest.raises(WatsonXAIError) as error:
|
||||
params = _get_api_params({})
|
||||
assert 'At least one of project_id or space_id must be set.' in error.value.message
|
Loading…
Add table
Add a link
Reference in a new issue