Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
Litellm dev 12 25 2024 p1 (#7411)
* test(test_watsonx.py): e2e unit test for watsonx custom header; covers https://github.com/BerriAI/litellm/issues/7408
* fix(common_utils.py): handle auth token already present in headers (watsonx + openai-like base handler); fixes https://github.com/BerriAI/litellm/issues/7408
* fix(watsonx/chat): fix chat route; fixes https://github.com/BerriAI/litellm/issues/7408
* fix(huggingface/chat/handler.py): fix huggingface async completion calls
* Correct handling of max_retries=0 to disable AzureOpenAI retries (#7379)
* test: fix test

Co-authored-by: Minh Duc <phamminhduc0711@gmail.com>
commit 9237357bcc (parent 157810fcbf)
9 changed files with 299 additions and 14 deletions
@@ -855,7 +855,8 @@ class AzureChatCompletion(BaseLLM):
             self._client_session = self.create_client_session()
         try:
             data = {"model": model, "input": input, **optional_params}
-            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
+            if max_retries is None:
+                max_retries = litellm.DEFAULT_MAX_RETRIES
             if not isinstance(max_retries, int):
                 raise AzureOpenAIError(
                     status_code=422, message="max retries must be an int"
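
Why the `is None` check matters: `max_retries or litellm.DEFAULT_MAX_RETRIES` treats an explicit `max_retries=0` as falsy and silently re-enables the default, so retries could never be turned off. A minimal standalone sketch of the pitfall (the default value here is a stand-in, not litellm's actual constant):

    DEFAULT_MAX_RETRIES = 2  # stand-in for litellm.DEFAULT_MAX_RETRIES

    def resolve_retries_buggy(max_retries=None):
        # `or` swallows 0, so retries can never be disabled
        return max_retries or DEFAULT_MAX_RETRIES

    def resolve_retries_fixed(max_retries=None):
        # substitute the default only when the caller passed nothing
        if max_retries is None:
            max_retries = DEFAULT_MAX_RETRIES
        return max_retries

    assert resolve_retries_buggy(0) == 2  # bug: 0 becomes the default
    assert resolve_retries_fixed(0) == 0  # fix: 0 disables retries
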
@@ -203,7 +203,26 @@ class Huggingface(BaseLLM):
                 return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages)  # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, model=model, optional_params=optional_params, timeout=timeout, litellm_params=litellm_params)  # type: ignore
+                return self.acompletion(
+                    api_base=completion_url,
+                    data=data,
+                    headers=headers,
+                    model_response=model_response,
+                    encoding=encoding,
+                    model=model,
+                    optional_params=optional_params,
+                    timeout=timeout,
+                    litellm_params=litellm_params,
+                    logging_obj=logging_obj,
+                    api_key=api_key,
+                    messages=messages,
+                    client=(
+                        client
+                        if client is not None
+                        and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
+                )
         if client is None or not isinstance(client, HTTPHandler):
             client = _get_httpx_client()
         ### SYNC STREAMING
@@ -267,14 +286,16 @@ class Huggingface(BaseLLM):
         logging_obj: LiteLLMLoggingObj,
         api_key: str,
         messages: List[AllMessageValues],
+        client: Optional[AsyncHTTPHandler] = None,
     ):
         response: Optional[httpx.Response] = None
         try:
-            http_client = get_async_httpx_client(
+            if client is None:
+                client = get_async_httpx_client(
                 llm_provider=litellm.LlmProviders.HUGGINGFACE
             )
             ### ASYNC COMPLETION
-            http_response = await http_client.post(
+            http_response = await client.post(
                 url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
             )

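
Both Huggingface hunks apply the same dependency-injection pattern: `acompletion` now accepts an optional `client` and only builds one when the caller did not pass a usable `AsyncHTTPHandler`, which is what lets the new `test_huggingface.py` below patch a handler it owns. A minimal sketch of the pattern, with a stand-in handler class rather than litellm's real one:

    import asyncio
    from typing import Optional

    class AsyncHTTPHandler:  # stand-in for litellm's AsyncHTTPHandler
        async def post(self, url: str, **kwargs) -> dict:
            return {"url": url}

    async def acompletion_sketch(
        api_base: str, client: Optional[AsyncHTTPHandler] = None
    ) -> dict:
        # reuse the injected client (connection reuse, easy mocking);
        # fall back to a fresh handler only when none was provided
        if client is None:
            client = AsyncHTTPHandler()
        return await client.post(url=api_base)

    print(asyncio.run(acompletion_sketch("https://example.invalid")))
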
@@ -365,7 +365,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
             client = HTTPHandler(timeout=timeout)  # type: ignore
         try:
             response = client.post(
-                api_base, headers=headers, data=json.dumps(data)
+                url=api_base, headers=headers, data=json.dumps(data)
             )
             response.raise_for_status()

@@ -43,7 +43,9 @@ class OpenAILikeBase:
             "Content-Type": "application/json",
         }

-        if api_key is not None:
+        if (
+            api_key is not None and "Authorization" not in headers
+        ):  # [TODO] remove 'validate_environment' from OpenAI base. should use llm providers config for this only.
             headers.update({"Authorization": "Bearer {}".format(api_key)})

         if not custom_endpoint:
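
This guard is the heart of the #7408 fix for OpenAI-like providers: the `Bearer <api_key>` header is injected only when the caller has not already supplied `Authorization`, so a custom auth header is never clobbered. The precedence rule in isolation (a hypothetical helper, not litellm's API):

    from typing import Dict, Optional

    def build_headers(api_key: Optional[str], headers: Dict[str, str]) -> Dict[str, str]:
        out = {"Content-Type": "application/json", **headers}
        # a caller-supplied Authorization header wins over api_key
        if api_key is not None and "Authorization" not in out:
            out["Authorization"] = "Bearer {}".format(api_key)
        return out

    assert build_headers("sk-123", {})["Authorization"] == "Bearer sk-123"
    assert (
        build_headers("sk-123", {"Authorization": "Bearer custom"})["Authorization"]
        == "Bearer custom"
    )
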
@@ -102,9 +102,9 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
             endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+                WatsonXAIEndpoint.CHAT_STREAM.value
                 if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+                else WatsonXAIEndpoint.CHAT.value
             )
         url = url.rstrip("/") + endpoint

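
The endpoint swap above is the watsonx chat-route fix: requests for ordinary (non-deployment) models now go to the generic chat endpoints rather than the deployment-scoped ones. The last test in `test_watsonx.py` below locks this in by asserting the request URL contains no "deployment" segment.
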
@@ -53,7 +53,9 @@ def generate_iam_token(api_key=None, **params) -> str:
         headers,
         data,
     )
-    response = httpx.post(iam_token_url, data=data, headers=headers)
+    response = litellm.module_level_client.post(
+        url=iam_token_url, data=data, headers=headers
+    )
     response.raise_for_status()
     json_data = response.json()

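
Swapping the bare `httpx.post` for `litellm.module_level_client.post` gives the IAM token request a single shared, patchable client; `test_watsonx.py` below stubs exactly that seam to return a fake token. The general shape of the idea, with a hypothetical module and client (httpx used here only for illustration):

    import httpx

    # one client at module scope: connection reuse in production,
    # a single patch target (mymodule.module_level_client) in tests
    module_level_client = httpx.Client()

    def fetch_iam_token(iam_token_url: str, data: dict, headers: dict) -> str:
        response = module_level_client.post(url=iam_token_url, data=data, headers=headers)
        response.raise_for_status()
        return response.json()["access_token"]
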
@@ -165,10 +167,13 @@ class IBMWatsonXMixin:
         optional_params: Dict,
         api_key: Optional[str] = None,
     ) -> Dict:
-        headers = {
+        default_headers = {
             "Content-Type": "application/json",
             "Accept": "application/json",
         }
+
+        if "Authorization" in headers:
+            return {**default_headers, **headers}
         token = cast(Optional[str], optional_params.get("token"))
         if token:
             headers["Authorization"] = f"Bearer {token}"
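
The early return relies on dict-unpacking precedence: in `{**default_headers, **headers}` the right-hand mapping is unpacked last, so caller-supplied headers override the defaults while any missing defaults are still filled in. For example:

    default_headers = {"Content-Type": "application/json", "Accept": "application/json"}
    headers = {"Authorization": "Bearer my-token", "Accept": "text/event-stream"}

    merged = {**default_headers, **headers}
    assert merged["Accept"] == "text/event-stream"       # caller value wins
    assert merged["Content-Type"] == "application/json"  # default preserved
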
@@ -176,7 +181,7 @@ class IBMWatsonXMixin:
             token = _generate_watsonx_token(api_key=api_key, token=token)
             # build auth headers
             headers["Authorization"] = f"Bearer {token}"
-        return headers
+        return {**default_headers, **headers}

     def _get_base_url(self, api_base: Optional[str]) -> str:
         url = (
@@ -278,7 +278,7 @@ def test_completions_with_sync_http_handler(monkeypatch):
         assert response.to_dict() == expected_response_json

         mock_post.assert_called_once_with(
-            f"{base_url}/chat/completions",
+            url=f"{base_url}/chat/completions",
             headers={
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
tests/llm_translation/test_huggingface.py (new file, 169 lines)
"""
Unit Tests Huggingface route
"""

import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import completion, acompletion
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest


def tgi_mock_post(url, **kwargs):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        {
            "generated_text": "<|assistant|>\nI'm",
            "details": {
                "finish_reason": "length",
                "generated_tokens": 10,
                "seed": None,
                "prefill": [],
                "tokens": [
                    {
                        "id": 28789,
                        "text": "<",
                        "logprob": -0.025222778,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.000003695488,
                        "special": False,
                    },
                    {
                        "id": 489,
                        "text": "ass",
                        "logprob": -0.0000019073486,
                        "special": False,
                    },
                    {
                        "id": 11143,
                        "text": "istant",
                        "logprob": -0.000002026558,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.0000015497208,
                        "special": False,
                    },
                    {
                        "id": 28767,
                        "text": ">",
                        "logprob": -0.0000011920929,
                        "special": False,
                    },
                    {
                        "id": 13,
                        "text": "\n",
                        "logprob": -0.00009703636,
                        "special": False,
                    },
                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
                    {
                        "id": 28742,
                        "text": "'",
                        "logprob": -0.88183594,
                        "special": False,
                    },
                    {
                        "id": 28719,
                        "text": "m",
                        "logprob": -0.00032639503,
                        "special": False,
                    },
                ],
            },
        }
    ]
    return mock_response


@pytest.fixture
def huggingface_chat_completion_call():
    def _call(
        model="huggingface/my-test-model",
        messages=None,
        api_key="test_api_key",
        headers=None,
        client=None,
    ):
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()

        mock_response = Mock()

        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
            completion(
                model=model,
                messages=messages,
                api_key=api_key,
                headers=headers or {},
                client=client,
            )

        return mock_post

    return _call


@pytest.fixture
def async_huggingface_chat_completion_call():
    async def _call(
        model="huggingface/my-test-model",
        messages=None,
        api_key="test_api_key",
        headers=None,
        client=None,
    ):
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = AsyncHTTPHandler()

        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
            await acompletion(
                model=model,
                messages=messages,
                api_key=api_key,
                headers=headers or {},
                client=client,
            )

        return mock_post

    return _call


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_huggingface_chat_completions_endpoint(
    sync_mode, huggingface_chat_completion_call, async_huggingface_chat_completion_call
):
    model = "huggingface/another-model"
    messages = [{"role": "user", "content": "Test message"}]

    if sync_mode:
        mock_post = huggingface_chat_completion_call(model=model, messages=messages)
    else:
        mock_post = await async_huggingface_chat_completion_call(
            model=model, messages=messages
        )

    assert mock_post.call_count == 1
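
These tests never touch the network: `patch.object(client, "post", side_effect=tgi_mock_post)` replaces the handler's `post` so every call is routed to the TGI-style stub while the mock still records arguments and call counts. How `side_effect` behaves in isolation (standard `unittest.mock`, with a hypothetical client class):

    from unittest.mock import MagicMock, patch

    class Client:  # hypothetical stand-in for an HTTP handler
        def post(self, url, **kwargs):
            raise RuntimeError("would hit the network")

    def fake_post(url, **kwargs):
        response = MagicMock()
        response.status_code = 200
        return response

    client = Client()
    with patch.object(client, "post", side_effect=fake_post) as mock_post:
        client.post("https://example.invalid", json={})
    assert mock_post.call_count == 1
    assert mock_post.call_args.args == ("https://example.invalid",)

Note that the async path additionally assumes a pytest asyncio plugin (e.g. pytest-asyncio) is configured for the `@pytest.mark.asyncio` marker.
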
tests/llm_translation/test_watsonx.py (new file, 87 lines)
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import completion
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest


@pytest.fixture
def watsonx_chat_completion_call():
    def _call(
        model="watsonx/my-test-model",
        messages=None,
        api_key="test_api_key",
        headers=None,
        client=None,
    ):
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()

        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "mock_access_token",
            "expires_in": 3600,
        }
        mock_response.raise_for_status = Mock()  # No-op to simulate no exception

        with patch.object(client, "post") as mock_post, patch.object(
            litellm.module_level_client, "post", return_value=mock_response
        ) as mock_get:
            completion(
                model=model,
                messages=messages,
                api_key=api_key,
                headers=headers or {},
                client=client,
            )

        return mock_post, mock_get

    return _call


@pytest.mark.parametrize("with_custom_auth_header", [True, False])
def test_watsonx_custom_auth_header(
    with_custom_auth_header, watsonx_chat_completion_call
):
    headers = (
        {"Authorization": "Bearer my-custom-auth-header"}
        if with_custom_auth_header
        else {}
    )

    mock_post, _ = watsonx_chat_completion_call(headers=headers)

    assert mock_post.call_count == 1
    if with_custom_auth_header:
        assert (
            mock_post.call_args[1]["headers"]["Authorization"]
            == "Bearer my-custom-auth-header"
        )
    else:
        assert (
            mock_post.call_args[1]["headers"]["Authorization"]
            == "Bearer mock_access_token"
        )


def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
    model = "watsonx/another-model"
    messages = [{"role": "user", "content": "Test message"}]

    mock_post, _ = watsonx_chat_completion_call(model=model, messages=messages)

    assert mock_post.call_count == 1
    assert "deployment" not in mock_post.call_args.kwargs["url"]
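
The double `patch.object` in the fixture is what makes `test_watsonx_custom_auth_header` meaningful: the patch on `litellm.module_level_client.post` intercepts the IAM token request (the `generate_iam_token` hunk above) and hands back "mock_access_token", while the patch on `client.post` captures the chat request, so the assertions can check whether the caller's custom `Authorization` header survived or the generated bearer token was applied. Despite its name, `mock_get` holds a patched `post`.
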