From 1dafb1b3b7a4c5d070efb4d3f55c86787d139183 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 8 Jun 2024 14:32:43 -0700
Subject: [PATCH] fix(utils.py): improved predibase exception mapping

adds unit testing + better coverage for predibase errors
---
 litellm/exceptions.py                           | 62 ++++++++++++++---
 litellm/integrations/test_httpx.py              |  0
 litellm/llms/predibase.py                       | 48 +++++++------
 litellm/main.py                                 | 14 ++--
 litellm/proxy/_experimental/out/404.html        |  1 -
 .../proxy/_experimental/out/model_hub.html      |  1 -
 .../proxy/_experimental/out/onboarding.html     |  1 -
 litellm/proxy/_super_secret_config.yaml         | 22 ++++--
 litellm/tests/test_exceptions.py                | 46 +++++++++++++
 litellm/utils.py                                | 69 +++++++++++++++++++
 ruff.toml                                       |  2 +-
 11 files changed, 220 insertions(+), 46 deletions(-)
 create mode 100644 litellm/integrations/test_httpx.py
 delete mode 100644 litellm/proxy/_experimental/out/404.html
 delete mode 100644 litellm/proxy/_experimental/out/model_hub.html
 delete mode 100644 litellm/proxy/_experimental/out/onboarding.html

diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 484e843b6..886b5889d 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -20,7 +20,7 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -32,8 +32,14 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -60,7 +66,7 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -72,8 +78,14 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -262,7 +274,7 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -274,8 +286,18 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=429,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -421,7 +443,7 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -433,8 +455,18 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -460,7 +492,7 @@ class InternalServerError(openai.InternalServerError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -472,8 +504,18 @@ class InternalServerError(openai.InternalServerError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
diff --git a/litellm/integrations/test_httpx.py b/litellm/integrations/test_httpx.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/litellm/llms/predibase.py b/litellm/llms/predibase.py
index a3245cdac..66c28acee 100644
--- a/litellm/llms/predibase.py
+++ b/litellm/llms/predibase.py
@@ -3,6 +3,7 @@

 from functools import partial
 import os, types
+import traceback
 import json
 from enum import Enum
 import requests, copy  # type: ignore
@@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
                 "details" in completion_response
                 and "tokens" in completion_response["details"]
             ):
-                model_response.choices[0].finish_reason = completion_response[
-                    "details"
-                ]["finish_reason"]
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    completion_response["details"]["finish_reason"]
+                )
                 sum_logprob = 0
                 for token in completion_response["details"]["tokens"]:
-                    if token["logprob"] != None:
+                    if token["logprob"] is not None:
                         sum_logprob += token["logprob"]
                 model_response["choices"][0][
                     "message"
@@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
                         ):
                             sum_logprob = 0
                             for token in item["tokens"]:
-                                if token["logprob"] != None:
+                                if token["logprob"] is not None:
                                     sum_logprob += token["logprob"]
                             if len(item["generated_text"]) > 0:
                                 message_obj = Message(
@@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
                             else:
                                 message_obj = Message(content=None)
                             choice_obj = Choices(
-                                finish_reason=item["finish_reason"],
+                                finish_reason=map_finish_reason(item["finish_reason"]),
                                 index=idx + 1,
                                 message=message_obj,
                             )
@@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
         ## CALCULATING USAGE
         prompt_tokens = 0
         try:
-            prompt_tokens = len(
-                encoding.encode(model_response["choices"][0]["message"]["content"])
-            )  ##[TODO] use a model-specific tokenizer here
-        except:
+            prompt_tokens = litellm.token_counter(messages=messages)
+        except Exception:
             # this should remain non blocking we should not block a response returning if calculating usage fails
             pass
         output_text = model_response["choices"][0]["message"].get("content", "")
@@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj,
         optional_params: dict,
         tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
+
         if "https" in model:
             completion_url = model
         elif api_base:
@@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):

             completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

-        if optional_params.get("stream", False) == True:
+        if optional_params.get("stream", False) is True:
             completion_url += "/generate_stream"
         else:
             completion_url += "/generate"
@@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion == True:
-            ### ASYNC STREAMING
-            if stream == True:
+        if acompletion is True:
+            ### ASYNC STREAMING
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore
             else:
                 ### ASYNC COMPLETION
@@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                timeout=timeout,
             )  # type: ignore

         ### SYNC STREAMING
-        if stream == True:
+        if stream is True:
             response = requests.post(
                 completion_url,
                 headers=headers,
@@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
                 headers=headers,
                 data=json.dumps(data),
             )
-
         return self.process_response(
             model=model,
             response=response,
@@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
         stream,
         data: dict,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         litellm_params=None,
         logger_fn=None,
         headers={},
     ) -> ModelResponse:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+
+        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
         try:
-            response = await self.async_handler.post(
+            response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
-                status_code=e.response.status_code, message=e.response.text
+                status_code=e.response.status_code,
+                message="HTTPStatusError - {}".format(e.response.text),
             )
         except Exception as e:
-            raise PredibaseError(status_code=500, message=str(e))
+            raise PredibaseError(
+                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+            )
         return self.process_response(
             model=model,
             response=response,
@@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         data: dict,
+        timeout: Union[float, httpx.Timeout],
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
diff --git a/litellm/main.py b/litellm/main.py
index dd1fdb9f9..2c906e990 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -432,9 +432,9 @@ def mock_completion(
         if isinstance(mock_response, openai.APIError):
             raise mock_response
         raise litellm.APIError(
-            status_code=500,  # type: ignore
-            message=str(mock_response),
-            llm_provider="openai",  # type: ignore
+            status_code=getattr(mock_response, "status_code", 500),  # type: ignore
+            message=getattr(mock_response, "text", str(mock_response)),
+            llm_provider=getattr(mock_response, "llm_provider", "openai"),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
@@ -1949,7 +1949,8 @@ def completion(
            )

            api_base = (
-                optional_params.pop("api_base", None)
+                api_base
+                or optional_params.pop("api_base", None)
                 or optional_params.pop("base_url", None)
                 or litellm.api_base
                 or get_secret("PREDIBASE_API_BASE")
@@ -1977,12 +1978,13 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 api_key=api_key,
                 tenant_id=tenant_id,
+                timeout=timeout,
             )

             if (
                 "stream" in optional_params
-                and optional_params["stream"] == True
-                and acompletion == False
+                and optional_params["stream"] is True
+                and acompletion is False
             ):
                 return _model_response
             response = _model_response
diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404.html
deleted file mode 100644
index 4e1a8fa4c..000000000
--- a/litellm/proxy/_experimental/out/404.html
+++ /dev/null
@@ -1 +0,0 @@
-404: This page could not be found.LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub.html
deleted file mode 100644
index 4da591c1c..000000000
--- a/litellm/proxy/_experimental/out/model_hub.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 197ab5c31..000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 7fa1bbc19..2b2054756 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -8,6 +8,17 @@ model_list:
 - model_name: llama3-70b-8192
   litellm_params:
     model: groq/llama3-70b-8192
+- model_name: fake-openai-endpoint
+  litellm_params:
+    model: predibase/llama-3-8b-instruct
+    api_base: "http://0.0.0.0:8081"
+    api_key: os.environ/PREDIBASE_API_KEY
+    tenant_id: os.environ/PREDIBASE_TENANT_ID
+    max_retries: 0
+    temperature: 0.1
+    max_new_tokens: 256
+    return_full_text: false
+
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
 #     api_key: os.environ/AZURE_EUROPE_API_KEY
@@ -56,10 +67,11 @@ router_settings:

 litellm_settings:
   success_callback: ["langfuse"]
+  failure_callback: ["langfuse"]

-general_settings:
-  alerting: ["email"]
-  key_management_system: "aws_kms"
-  key_management_settings:
-    hosted_keys: ["LITELLM_MASTER_KEY"]
+# general_settings:
+#   alerting: ["email"]
+#   key_management_system: "aws_kms"
+#   key_management_settings:
+#     hosted_keys: ["LITELLM_MASTER_KEY"]
diff --git a/litellm/tests/test_exceptions.py b/litellm/tests/test_exceptions.py
index ee695dcd7..1082dd2f8 100644
--- a/litellm/tests/test_exceptions.py
+++ b/litellm/tests/test_exceptions.py
@@ -3,6 +3,7 @@
 import os
 import sys
 import traceback
 import subprocess, asyncio
+from typing import Any
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,6 +20,7 @@ from litellm import (
 )
 from concurrent.futures import ThreadPoolExecutor
 import pytest
+from unittest.mock import patch, MagicMock

 litellm.vertex_project = "pathrise-convert-1606954137718"
 litellm.vertex_location = "us-central1"
@@ -655,3 +657,47 @@ def test_litellm_predibase_exception():

 #     accuracy_score = counts[True]/(counts[True] + counts[False])
 #     print(f"accuracy_score: {accuracy_score}")
+
+
+@pytest.mark.parametrize("provider", ["predibase"])
+def test_exception_mapping(provider):
+    """
+    For predibase, run through a set of mock exceptions
+
+    assert that they are being mapped correctly
+    """
+    litellm.set_verbose = True
+    error_map = {
+        400: litellm.BadRequestError,
+        401: litellm.AuthenticationError,
+        404: litellm.NotFoundError,
+        408: litellm.Timeout,
+        429: litellm.RateLimitError,
+        500: litellm.InternalServerError,
+        503: litellm.ServiceUnavailableError,
+    }
+
+    for code, expected_exception in error_map.items():
+        mock_response = Exception()
+        setattr(mock_response, "text", "This is an error message")
+        setattr(mock_response, "llm_provider", provider)
+        setattr(mock_response, "status_code", code)
+
+        response: Any = None
+        try:
+            response = completion(
+                model="{}/test-model".format(provider),
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response=mock_response,
+            )
+        except expected_exception:
+            continue
+        except Exception as e:
"{}\n{}".format(str(e), traceback.format_exc()) + pytest.fail( + "Did not raise expected exception. Expected={}, Return={},".format( + expected_exception, response + ) + ) + + pass diff --git a/litellm/utils.py b/litellm/utils.py index 410f9ad88..bc55e6ef7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8978,6 +8978,75 @@ def exception_type( response=original_exception.response, litellm_debug_info=extra_information, ) + elif hasattr(original_exception, "status_code"): + if original_exception.status_code == 500: + exception_mapping_worked = True + raise litellm.InternalServerError( + message=f"PredibaseException - {original_exception.message}", + llm_provider="predibase", + model=model, + ) + elif original_exception.status_code == 401: + exception_mapping_worked = True + raise AuthenticationError( + message=f"PredibaseException - {original_exception.message}", + llm_provider="predibase", + model=model, + ) + elif original_exception.status_code == 400: + exception_mapping_worked = True + raise BadRequestError( + message=f"PredibaseException - {original_exception.message}", + llm_provider="predibase", + model=model, + ) + elif original_exception.status_code == 404: + exception_mapping_worked = True + raise NotFoundError( + message=f"PredibaseException - {original_exception.message}", + llm_provider="predibase", + model=model, + ) + elif original_exception.status_code == 408: + exception_mapping_worked = True + raise Timeout( + message=f"PredibaseException - {original_exception.message}", + model=model, + llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, + ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise BadRequestError( + message=f"PredibaseException - {original_exception.message}", + model=model, + llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, + ) + elif original_exception.status_code == 429: + exception_mapping_worked = True + raise RateLimitError( + message=f"PredibaseException - {original_exception.message}", + model=model, + llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, + ) + elif original_exception.status_code == 503: + exception_mapping_worked = True + raise ServiceUnavailableError( + message=f"PredibaseException - {original_exception.message}", + model=model, + llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, + ) + elif original_exception.status_code == 504: # gateway timeout error + exception_mapping_worked = True + raise Timeout( + message=f"PredibaseException - {original_exception.message}", + model=model, + llm_provider=custom_llm_provider, + litellm_debug_info=extra_information, + ) elif custom_llm_provider == "bedrock": if ( "too many tokens" in error_str diff --git a/ruff.toml b/ruff.toml index dfb323c1b..4894ab3fc 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,3 +1,3 @@ -ignore = ["F405"] +ignore = ["F405", "E402"] extend-select = ["E501"] line-length = 120