mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 04f2c5f723
into b82af5b826
This commit is contained in:
commit
0373608c98
6 changed files with 65 additions and 23 deletions
|
@ -40,7 +40,7 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
|||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
fake_stream: bool = False,
|
||||
):
|
||||
api_params = _get_api_params(params=optional_params)
|
||||
api_params = _get_api_params(params=optional_params, model=model)
|
||||
|
||||
## UPDATE HEADERS
|
||||
headers = watsonx_chat_transformation.validate_environment(
|
||||
|
|
|
@ -76,9 +76,7 @@ def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str
|
|||
return token
|
||||
|
||||
|
||||
def _get_api_params(
|
||||
params: dict,
|
||||
) -> WatsonXAPIParams:
|
||||
def _get_api_params(params: dict, model: Optional[str] = None) -> WatsonXAPIParams:
|
||||
"""
|
||||
Find watsonx.ai credentials and identifiers in the params or environment variables and return them as WatsonXAPIParams.
|
||||
"""
|
||||
|
@ -114,10 +112,15 @@ def _get_api_params(
|
|||
or get_secret_str("SPACE_ID")
|
||||
)
|
||||
|
||||
if project_id is None:
|
||||
if (
|
||||
project_id is None
|
||||
and space_id is None
|
||||
and model is not None
|
||||
and not model.startswith("deployment/")
|
||||
):
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
|
||||
message="Error: Watsonx project_id and space_id not set. Set WX_PROJECT_ID or WX_SPACE_ID in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
return WatsonXAPIParams(
|
||||
|
@ -280,13 +283,10 @@ class IBMWatsonXMixin:
|
|||
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
|
||||
payload: dict = {}
|
||||
if model.startswith("deployment/"):
|
||||
if api_params["space_id"] is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
payload["space_id"] = api_params["space_id"]
|
||||
return payload
|
||||
payload["model_id"] = model
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
if api_params["project_id"] is not None:
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
else:
|
||||
payload["space_id"] = api_params["space_id"]
|
||||
return payload
|
||||
|
|
|
@ -245,7 +245,7 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
|
|||
)
|
||||
extra_body_params = optional_params.pop("extra_body", {})
|
||||
optional_params.update(extra_body_params)
|
||||
watsonx_api_params = _get_api_params(params=optional_params)
|
||||
watsonx_api_params = _get_api_params(params=optional_params, model=model)
|
||||
|
||||
watsonx_auth_payload = self._prepare_payload(
|
||||
model=model,
|
||||
|
|
|
@ -37,7 +37,7 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
|
|||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
watsonx_api_params = _get_api_params(params=optional_params)
|
||||
watsonx_api_params = _get_api_params(params=optional_params, model=model)
|
||||
watsonx_auth_payload = self._prepare_payload(
|
||||
model=model,
|
||||
api_params=watsonx_api_params,
|
||||
|
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class WatsonXAPIParams(TypedDict):
|
||||
project_id: str
|
||||
project_id: Optional[str]
|
||||
space_id: Optional[str]
|
||||
region_name: Optional[str]
|
||||
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm import completion, embedding
|
||||
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
from unittest.mock import patch, MagicMock, AsyncMock, Mock
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from unittest.mock import patch, Mock
|
||||
import pytest
|
||||
from typing import Optional
|
||||
|
||||
|
@ -188,6 +185,27 @@ def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
|
|||
assert "deployment" not in mock_post.call_args.kwargs["url"]
|
||||
|
||||
|
||||
def test_watsonx_chat_completions_endpoint_space_id(
|
||||
monkeypatch, watsonx_chat_completion_call
|
||||
):
|
||||
my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
|
||||
monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
|
||||
|
||||
monkeypatch.delenv("WATSONX_PROJECT_ID")
|
||||
|
||||
model = "watsonx/another-model"
|
||||
messages = [{"role": "user", "content": "Test message"}]
|
||||
|
||||
mock_post, _ = watsonx_chat_completion_call(model=model, messages=messages)
|
||||
|
||||
assert mock_post.call_count == 1
|
||||
assert "deployment" not in mock_post.call_args.kwargs["url"]
|
||||
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert my_fake_space_id == json_data["space_id"]
|
||||
assert not json_data.get("project_id")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
|
@ -206,7 +224,29 @@ def test_watsonx_deployment_space_id(monkeypatch, watsonx_chat_completion_call,
|
|||
|
||||
assert mock_post.call_count == 1
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert my_fake_space_id == json_data["space_id"]
|
||||
assert "space_id" not in json_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"watsonx/deployment/<xxxx.xxx.xxx.xxxx>",
|
||||
"watsonx_text/deployment/<xxxx.xxx.xxx.xxxx>",
|
||||
],
|
||||
)
|
||||
def test_watsonx_deployment(watsonx_chat_completion_call, model):
|
||||
|
||||
messages = [{"content": "Hello, how are you?", "role": "user"}]
|
||||
mock_post, _ = watsonx_chat_completion_call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
assert mock_post.call_count == 1
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
|
||||
# neither space_id nor project_id is required by the wx.ai API when running inference against a deployment
|
||||
assert "project_id" not in json_data and "space_id" not in json_data
|
||||
|
||||
|
||||
def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_call):
|
||||
|
@ -217,4 +257,6 @@ def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_ca
|
|||
|
||||
assert mock_post.call_count == 1
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert my_fake_space_id == json_data["space_id"]
|
||||
|
||||
# neither space_id nor project_id is required by the wx.ai API when running inference against a deployment
|
||||
assert "project_id" not in json_data and "space_id" not in json_data
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue