From b961f96b35ed16e2f685109619ca2d08afbb55c0 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Wed, 25 Dec 2024 17:36:30 -0800
Subject: [PATCH] Litellm dev 12 25 2024 p1 (#7411)

* test(test_watsonx.py): e2e unit test for watsonx custom header

covers https://github.com/BerriAI/litellm/issues/7408

* fix(common_utils.py): handle auth token already present in headers (watsonx + openai-like base handler)

Fixes https://github.com/BerriAI/litellm/issues/7408

* fix(watsonx/chat): fix chat route

Fixes https://github.com/BerriAI/litellm/issues/7408

* fix(huggingface/chat/handler.py): fix huggingface async completion calls

* Correct handling of max_retries=0 to disable AzureOpenAI retries (#7379)

* test: fix test

---------

Co-authored-by: Minh Duc
---
 litellm/llms/azure/azure.py                 |   3 +-
 litellm/llms/huggingface/chat/handler.py    |  31 +++-
 litellm/llms/openai_like/chat/handler.py    |   2 +-
 litellm/llms/openai_like/common_utils.py    |   4 +-
 litellm/llms/watsonx/chat/transformation.py |   4 +-
 litellm/llms/watsonx/common_utils.py        |  11 +-
 tests/llm_translation/test_databricks.py    |   2 +-
 tests/llm_translation/test_huggingface.py   | 169 ++++++++++++++++++++
 tests/llm_translation/test_watsonx.py       |  87 ++++++++++
 9 files changed, 299 insertions(+), 14 deletions(-)
 create mode 100644 tests/llm_translation/test_huggingface.py
 create mode 100644 tests/llm_translation/test_watsonx.py

diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py
index 837d425b82..f7110210d3 100644
--- a/litellm/llms/azure/azure.py
+++ b/litellm/llms/azure/azure.py
@@ -855,7 +855,8 @@ class AzureChatCompletion(BaseLLM):
             self._client_session = self.create_client_session()
         try:
             data = {"model": model, "input": input, **optional_params}
-            max_retries = max_retries or litellm.DEFAULT_MAX_RETRIES
+            if max_retries is None:
+                max_retries = litellm.DEFAULT_MAX_RETRIES
             if not isinstance(max_retries, int):
                 raise AzureOpenAIError(
                     status_code=422, message="max retries must be an int"
diff --git a/litellm/llms/huggingface/chat/handler.py b/litellm/llms/huggingface/chat/handler.py
index d357edf329..df3140e104 100644
--- a/litellm/llms/huggingface/chat/handler.py
+++ b/litellm/llms/huggingface/chat/handler.py
@@ -203,7 +203,26 @@ class Huggingface(BaseLLM):
                 return self.async_streaming(logging_obj=logging_obj, api_base=completion_url, data=data, headers=headers, model_response=model_response, model=model, timeout=timeout, messages=messages)  # type: ignore
             else:
                 ### ASYNC COMPLETION
-                return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, model=model, optional_params=optional_params, timeout=timeout, litellm_params=litellm_params)  # type: ignore
+                return self.acompletion(
+                    api_base=completion_url,
+                    data=data,
+                    headers=headers,
+                    model_response=model_response,
+                    encoding=encoding,
+                    model=model,
+                    optional_params=optional_params,
+                    timeout=timeout,
+                    litellm_params=litellm_params,
+                    logging_obj=logging_obj,
+                    api_key=api_key,
+                    messages=messages,
+                    client=(
+                        client
+                        if client is not None
+                        and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
+                )
         if client is None or not isinstance(client, HTTPHandler):
             client = _get_httpx_client()
         ### SYNC STREAMING
@@ -267,14 +286,16 @@ class Huggingface(BaseLLM):
         logging_obj: LiteLLMLoggingObj,
         api_key: str,
         messages: List[AllMessageValues],
+        client: Optional[AsyncHTTPHandler] = None,
     ):
         response: Optional[httpx.Response] = None
         try:
-            http_client = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders.HUGGINGFACE
-            )
+            if client is None:
+                client = get_async_httpx_client(
+                    llm_provider=litellm.LlmProviders.HUGGINGFACE
+                )
             ### ASYNC COMPLETION
-            http_response = await http_client.post(
+            http_response = await client.post(
                 url=api_base, headers=headers, data=json.dumps(data), timeout=timeout
             )

diff --git a/litellm/llms/openai_like/chat/handler.py b/litellm/llms/openai_like/chat/handler.py
index f190d37455..bd9635b086 100644
--- a/litellm/llms/openai_like/chat/handler.py
+++ b/litellm/llms/openai_like/chat/handler.py
@@ -365,7 +365,7 @@ class OpenAILikeChatHandler(OpenAILikeBase):
                 client = HTTPHandler(timeout=timeout)  # type: ignore
             try:
                 response = client.post(
-                    api_base, headers=headers, data=json.dumps(data)
+                    url=api_base, headers=headers, data=json.dumps(data)
                 )
                 response.raise_for_status()

diff --git a/litellm/llms/openai_like/common_utils.py b/litellm/llms/openai_like/common_utils.py
index 3051618d48..116277b6dd 100644
--- a/litellm/llms/openai_like/common_utils.py
+++ b/litellm/llms/openai_like/common_utils.py
@@ -43,7 +43,9 @@ class OpenAILikeBase:
             "Content-Type": "application/json",
         }

-        if api_key is not None:
+        if (
+            api_key is not None and "Authorization" not in headers
+        ):  # [TODO] remove 'validate_environment' from OpenAI base. should use llm providers config for this only.
             headers.update({"Authorization": "Bearer {}".format(api_key)})

         if not custom_endpoint:
diff --git a/litellm/llms/watsonx/chat/transformation.py b/litellm/llms/watsonx/chat/transformation.py
index 19e98daaaa..5d0c432c56 100644
--- a/litellm/llms/watsonx/chat/transformation.py
+++ b/litellm/llms/watsonx/chat/transformation.py
@@ -102,9 +102,9 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
             endpoint = endpoint.format(deployment_id=deployment_id)
         else:
             endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+                WatsonXAIEndpoint.CHAT_STREAM.value
                 if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+                else WatsonXAIEndpoint.CHAT.value
             )
         url = url.rstrip("/") + endpoint

diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py
index ce1a16ca55..50fefc4da8 100644
--- a/litellm/llms/watsonx/common_utils.py
+++ b/litellm/llms/watsonx/common_utils.py
@@ -53,7 +53,9 @@ def generate_iam_token(api_key=None, **params) -> str:
         headers,
         data,
     )
-    response = httpx.post(iam_token_url, data=data, headers=headers)
+    response = litellm.module_level_client.post(
+        url=iam_token_url, data=data, headers=headers
+    )
     response.raise_for_status()
     json_data = response.json()

@@ -165,10 +167,13 @@ class IBMWatsonXMixin:
         optional_params: Dict,
         api_key: Optional[str] = None,
     ) -> Dict:
-        headers = {
+        default_headers = {
             "Content-Type": "application/json",
             "Accept": "application/json",
         }
+
+        if "Authorization" in headers:
+            return {**default_headers, **headers}
         token = cast(Optional[str], optional_params.get("token"))
         if token:
             headers["Authorization"] = f"Bearer {token}"
@@ -176,7 +181,7 @@ class IBMWatsonXMixin:
             token = _generate_watsonx_token(api_key=api_key, token=token)
             # build auth headers
             headers["Authorization"] = f"Bearer {token}"
-        return headers
+        return {**default_headers, **headers}

     def _get_base_url(self, api_base: Optional[str]) -> str:
         url = (
diff --git a/tests/llm_translation/test_databricks.py b/tests/llm_translation/test_databricks.py
index 9ea6b6f576..1b3a58fe70 100644
--- a/tests/llm_translation/test_databricks.py
+++ b/tests/llm_translation/test_databricks.py
@@ -278,7 +278,7 @@ def test_completions_with_sync_http_handler(monkeypatch):
         assert response.to_dict() == expected_response_json

         mock_post.assert_called_once_with(
-            f"{base_url}/chat/completions",
+            url=f"{base_url}/chat/completions",
             headers={
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
diff --git a/tests/llm_translation/test_huggingface.py b/tests/llm_translation/test_huggingface.py
new file mode 100644
index 0000000000..b99803d59a
--- /dev/null
+++ b/tests/llm_translation/test_huggingface.py
@@ -0,0 +1,169 @@
+"""
+Unit Tests Huggingface route
+"""
+
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import completion, acompletion
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+from unittest.mock import patch, MagicMock, AsyncMock, Mock
+import pytest
+
+
+def tgi_mock_post(url, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        {
+            "generated_text": "<|assistant|>\nI'm",
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 10,
+                "seed": None,
+                "prefill": [],
+                "tokens": [
+                    {
+                        "id": 28789,
+                        "text": "<",
+                        "logprob": -0.025222778,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.000003695488,
+                        "special": False,
+                    },
+                    {
+                        "id": 489,
+                        "text": "ass",
+                        "logprob": -0.0000019073486,
+                        "special": False,
+                    },
+                    {
+                        "id": 11143,
+                        "text": "istant",
+                        "logprob": -0.000002026558,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.0000015497208,
+                        "special": False,
+                    },
+                    {
+                        "id": 28767,
+                        "text": ">",
+                        "logprob": -0.0000011920929,
+                        "special": False,
+                    },
+                    {
+                        "id": 13,
+                        "text": "\n",
+                        "logprob": -0.00009703636,
+                        "special": False,
+                    },
+                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
+                    {
+                        "id": 28742,
+                        "text": "'",
+                        "logprob": -0.88183594,
+                        "special": False,
+                    },
+                    {
+                        "id": 28719,
+                        "text": "m",
+                        "logprob": -0.00032639503,
+                        "special": False,
+                    },
+                ],
+            },
+        }
+    ]
+    return mock_response
+
+
+@pytest.fixture
+def huggingface_chat_completion_call():
+    def _call(
+        model="huggingface/my-test-model",
+        messages=None,
+        api_key="test_api_key",
+        headers=None,
+        client=None,
+    ):
+        if messages is None:
+            messages = [{"role": "user", "content": "Hello, how are you?"}]
+        if client is None:
+            client = HTTPHandler()
+
+        mock_response = Mock()
+
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
+            completion(
+                model=model,
+                messages=messages,
+                api_key=api_key,
+                headers=headers or {},
+                client=client,
+            )
+
+        return mock_post
+
+    return _call
+
+
+@pytest.fixture
+def async_huggingface_chat_completion_call():
+    async def _call(
+        model="huggingface/my-test-model",
+        messages=None,
+        api_key="test_api_key",
+        headers=None,
+        client=None,
+    ):
+        if messages is None:
+            messages = [{"role": "user", "content": "Hello, how are you?"}]
+        if client is None:
+            client = AsyncHTTPHandler()
+
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_post:
+            await acompletion(
+                model=model,
+                messages=messages,
+                api_key=api_key,
+                headers=headers or {},
+                client=client,
+            )
+
+        return mock_post
+
+    return _call
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_huggingface_chat_completions_endpoint(
+    sync_mode, huggingface_chat_completion_call, async_huggingface_chat_completion_call
+):
+    model = "huggingface/another-model"
+    messages = [{"role": "user", "content": "Test message"}]
+
+    if sync_mode:
+        mock_post = huggingface_chat_completion_call(model=model, messages=messages)
+    else:
+        mock_post = await async_huggingface_chat_completion_call(
+            model=model, messages=messages
+        )
+
+    assert mock_post.call_count == 1
diff --git a/tests/llm_translation/test_watsonx.py b/tests/llm_translation/test_watsonx.py
new file mode 100644
index 0000000000..9efe28ceec
--- /dev/null
+++ b/tests/llm_translation/test_watsonx.py
@@ -0,0 +1,87 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import completion
+from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
+from unittest.mock import patch, MagicMock, AsyncMock, Mock
+import pytest
+
+
+@pytest.fixture
+def watsonx_chat_completion_call():
+    def _call(
+        model="watsonx/my-test-model",
+        messages=None,
+        api_key="test_api_key",
+        headers=None,
+        client=None,
+    ):
+        if messages is None:
+            messages = [{"role": "user", "content": "Hello, how are you?"}]
+        if client is None:
+            client = HTTPHandler()
+
+        mock_response = Mock()
+        mock_response.json.return_value = {
+            "access_token": "mock_access_token",
+            "expires_in": 3600,
+        }
+        mock_response.raise_for_status = Mock()  # No-op to simulate no exception
+
+        with patch.object(client, "post") as mock_post, patch.object(
+            litellm.module_level_client, "post", return_value=mock_response
+        ) as mock_get:
+            completion(
+                model=model,
+                messages=messages,
+                api_key=api_key,
+                headers=headers or {},
+                client=client,
+            )
+
+        return mock_post, mock_get
+
+    return _call
+
+
+@pytest.mark.parametrize("with_custom_auth_header", [True, False])
+def test_watsonx_custom_auth_header(
+    with_custom_auth_header, watsonx_chat_completion_call
+):
+    headers = (
+        {"Authorization": "Bearer my-custom-auth-header"}
+        if with_custom_auth_header
+        else {}
+    )
+
+    mock_post, _ = watsonx_chat_completion_call(headers=headers)
+
+    assert mock_post.call_count == 1
+    if with_custom_auth_header:
+        assert (
+            mock_post.call_args[1]["headers"]["Authorization"]
+            == "Bearer my-custom-auth-header"
+        )
+    else:
+        assert (
+            mock_post.call_args[1]["headers"]["Authorization"]
+            == "Bearer mock_access_token"
+        )
+
+
+def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
+    model = "watsonx/another-model"
+    messages = [{"role": "user", "content": "Test message"}]
+
+    mock_post, _ = watsonx_chat_completion_call(model=model, messages=messages)
+
+    assert mock_post.call_count == 1
+    assert "deployment" not in mock_post.call_args.kwargs["url"]
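
A minimal caller-side sketch of the custom Authorization header passthrough that
test_watsonx_custom_auth_header covers end to end. The model id and token value
below are placeholders, and the usual watsonx api_base / project settings are
assumed to be configured elsewhere (e.g. via environment variables):

    import litellm

    # With this change, a caller-supplied Authorization header is forwarded as-is,
    # so LiteLLM skips generating an IBM IAM token for the request.
    response = litellm.completion(
        model="watsonx/my-deployed-model",  # hypothetical model id
        messages=[{"role": "user", "content": "Hello"}],
        headers={"Authorization": "Bearer my-precomputed-token"},  # custom auth header
    )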