Merge pull request #4477 from BerriAI/litellm_fix_exception_mapping

[Fix] - Error str in OpenAI, Azure exception
Ishaan Jaff 2024-06-29 17:37:26 -07:00 committed by GitHub
commit db721e7137
4 changed files with 101 additions and 31 deletions

View file

@@ -0,0 +1,40 @@
+import json
+from typing import Optional
+
+
+def get_error_message(error_obj) -> Optional[str]:
+    """
+    OpenAI returns an error message that is nested; this extracts the message.
+
+    Example:
+    {
+        'request': "<Request('POST', 'https://api.openai.com/v1/chat/completions')>",
+        'message': "Error code: 400 - {'error': {'message': \"Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.\", 'type': 'invalid_request_error', 'param': 'temperature', 'code': 'decimal_above_max_value'}}",
+        'body': {
+            'message': "Invalid 'temperature': decimal above maximum value. Expected a value <= 2, but got 200 instead.",
+            'type': 'invalid_request_error',
+            'param': 'temperature',
+            'code': 'decimal_above_max_value'
+        },
+        'code': 'decimal_above_max_value',
+        'param': 'temperature',
+        'type': 'invalid_request_error',
+        'response': "<Response [400 Bad Request]>",
+        'status_code': 400,
+        'request_id': 'req_f287898caa6364cd42bc01355f74dd2a'
+    }
+    """
+    try:
+        # First, try to access the message directly from the 'body' key
+        if error_obj is None:
+            return None
+
+        if hasattr(error_obj, "body"):
+            _error_obj_body = getattr(error_obj, "body")
+            if isinstance(_error_obj_body, dict):
+                return _error_obj_body.get("message")
+
+        # If all else fails, return None
+        return None
+    except Exception:
+        return None
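A minimal usage sketch of the new helper (not part of the commit; SimpleNamespace stands in for the openai-python v1 error object shaped like the one in the docstring, since only the `body` attribute is inspected):

from types import SimpleNamespace

from litellm.litellm_core_utils.exception_mapping_utils import get_error_message

# Stand-in for an openai-python v1 error: only `body` matters here.
fake_error = SimpleNamespace(
    body={
        "message": "Invalid 'temperature': decimal above maximum value.",
        "type": "invalid_request_error",
        "param": "temperature",
        "code": "decimal_above_max_value",
    }
)

print(get_error_message(fake_error))  # Invalid 'temperature': decimal above maximum value.
print(get_error_message(None))        # None
print(get_error_message(object()))    # None - no `body` attribute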

View file

@@ -1,25 +1,31 @@
 # test that the proxy actually does exception mapping to the OpenAI format
-import sys, os
-from unittest import mock
 import json
+import os
+import sys
+from unittest import mock
+
 from dotenv import load_dotenv
 
 load_dotenv()
-import os, io, asyncio
+import asyncio
+import io
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+import openai
 import pytest
-import litellm, openai
-from fastapi.testclient import TestClient
 from fastapi import Response
-from litellm.proxy.proxy_server import (
+from fastapi.testclient import TestClient
+
+import litellm
+from litellm.proxy.proxy_server import (  # Replace with the actual module where your FastAPI router is defined
+    initialize,
     router,
     save_worker_config,
-    initialize,
-)  # Replace with the actual module where your FastAPI router is defined
+)
 
 invalid_authentication_error_response = Response(
     status_code=401,
@@ -66,6 +72,12 @@ def test_chat_completion_exception(client):
     json_response = response.json()
     print("keys in json response", json_response.keys())
     assert json_response.keys() == {"error"}
+    print("ERROR=", json_response["error"])
+    assert isinstance(json_response["error"]["message"], str)
+    assert (
+        json_response["error"]["message"]
+        == "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys."
+    )
 
     # make an openai client to call _make_status_error_from_response
     openai_client = openai.OpenAI(api_key="anything")
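The new assertions pin the proxy's error body to the OpenAI error format and to the flattened message produced by get_error_message. A self-contained sketch of what they check (the message string is copied from the test above; other fields omitted):

import json

raw = json.dumps(
    {
        "error": {
            "message": "litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys."
        }
    }
)
json_response = json.loads(raw)
assert json_response.keys() == {"error"}
assert isinstance(json_response["error"]["message"], str)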

View file

@ -1,16 +1,23 @@
import sys, os, time import asyncio
import traceback, asyncio import os
import sys
import time
import traceback
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import litellm, asyncio, logging import asyncio
import logging
import litellm
from litellm import Router from litellm import Router
# this tests debug logs from litellm router and litellm proxy server # this tests debug logs from litellm router and litellm proxy server
from litellm._logging import verbose_router_logger, verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger, verbose_router_logger
# this tests debug logs from litellm router and litellm proxy server # this tests debug logs from litellm router and litellm proxy server
@ -81,7 +88,7 @@ def test_async_fallbacks(caplog):
# Define the expected log messages # Define the expected log messages
# - error request, falling back notice, success notice # - error request, falling back notice, success notice
expected_logs = [ expected_logs = [
"litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}\x1b[0m", "litellm.acompletion(model=gpt-3.5-turbo)\x1b[31m Exception litellm.AuthenticationError: AuthenticationError: OpenAIException - Incorrect API key provided: bad-key. You can find your API key at https://platform.openai.com/account/api-keys.\x1b[0m",
"Falling back to model_group = azure/gpt-3.5-turbo", "Falling back to model_group = azure/gpt-3.5-turbo",
"litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m", "litellm.acompletion(model=azure/chatgpt-v-2)\x1b[32m 200 OK\x1b[0m",
"Successful fallback b/w models.", "Successful fallback b/w models.",

View file

@@ -50,6 +50,7 @@ import litellm._service_logger  # for storing API inputs, outputs, and metadata
 import litellm.litellm_core_utils
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
 from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
 from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_logging,
@@ -5815,10 +5816,13 @@ def exception_type(
             or custom_llm_provider in litellm.openai_compatible_providers
         ):
             # custom_llm_provider is openai, make it OpenAI
-            if hasattr(original_exception, "message"):
-                message = original_exception.message
-            else:
-                message = str(original_exception)
+            message = get_error_message(error_obj=original_exception)
+            if message is None:
+                if hasattr(original_exception, "message"):
+                    message = original_exception.message
+                else:
+                    message = str(original_exception)
+
             if message is not None and isinstance(message, str):
                 message = message.replace("OPENAI", custom_llm_provider.upper())
                 message = message.replace("openai", custom_llm_provider)
@@ -7271,10 +7275,17 @@ def exception_type(
                     request=original_exception.request,
                 )
         elif custom_llm_provider == "azure":
+            message = get_error_message(error_obj=original_exception)
+            if message is None:
+                if hasattr(original_exception, "message"):
+                    message = original_exception.message
+                else:
+                    message = str(original_exception)
+
             if "Internal server error" in error_str:
                 exception_mapping_worked = True
                 raise litellm.InternalServerError(
-                    message=f"AzureException Internal server error - {original_exception.message}",
+                    message=f"AzureException Internal server error - {message}",
                     llm_provider="azure",
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7287,7 +7298,7 @@ def exception_type(
             elif "This model's maximum context length is" in error_str:
                 exception_mapping_worked = True
                 raise ContextWindowExceededError(
-                    message=f"AzureException ContextWindowExceededError - {original_exception.message}",
+                    message=f"AzureException ContextWindowExceededError - {message}",
                     llm_provider="azure",
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7296,7 +7307,7 @@ def exception_type(
             elif "DeploymentNotFound" in error_str:
                 exception_mapping_worked = True
                 raise NotFoundError(
-                    message=f"AzureException NotFoundError - {original_exception.message}",
+                    message=f"AzureException NotFoundError - {message}",
                     llm_provider="azure",
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7316,7 +7327,7 @@ def exception_type(
             ):
                 exception_mapping_worked = True
                 raise ContentPolicyViolationError(
-                    message=f"litellm.ContentPolicyViolationError: AzureException - {original_exception.message}",
+                    message=f"litellm.ContentPolicyViolationError: AzureException - {message}",
                     llm_provider="azure",
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7325,7 +7336,7 @@ def exception_type(
             elif "invalid_request_error" in error_str:
                 exception_mapping_worked = True
                 raise BadRequestError(
-                    message=f"AzureException BadRequestError - {original_exception.message}",
+                    message=f"AzureException BadRequestError - {message}",
                     llm_provider="azure",
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7337,7 +7348,7 @@ def exception_type(
             ):
                 exception_mapping_worked = True
                 raise AuthenticationError(
-                    message=f"{exception_provider} AuthenticationError - {original_exception.message}",
+                    message=f"{exception_provider} AuthenticationError - {message}",
                     llm_provider=custom_llm_provider,
                     model=model,
                     litellm_debug_info=extra_information,
@@ -7348,7 +7359,7 @@ def exception_type(
                 if original_exception.status_code == 400:
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"AzureException - {original_exception.message}",
+                        message=f"AzureException - {message}",
                         llm_provider="azure",
                         model=model,
                         litellm_debug_info=extra_information,
@@ -7357,7 +7368,7 @@ def exception_type(
                 elif original_exception.status_code == 401:
                     exception_mapping_worked = True
                     raise AuthenticationError(
-                        message=f"AzureException AuthenticationError - {original_exception.message}",
+                        message=f"AzureException AuthenticationError - {message}",
                         llm_provider="azure",
                         model=model,
                         litellm_debug_info=extra_information,
@@ -7366,7 +7377,7 @@ def exception_type(
                 elif original_exception.status_code == 408:
                     exception_mapping_worked = True
                     raise Timeout(
-                        message=f"AzureException Timeout - {original_exception.message}",
+                        message=f"AzureException Timeout - {message}",
                         model=model,
                         litellm_debug_info=extra_information,
                         llm_provider="azure",
@@ -7374,7 +7385,7 @@ def exception_type(
                 elif original_exception.status_code == 422:
                     exception_mapping_worked = True
                     raise BadRequestError(
-                        message=f"AzureException BadRequestError - {original_exception.message}",
+                        message=f"AzureException BadRequestError - {message}",
                         model=model,
                         llm_provider="azure",
                         litellm_debug_info=extra_information,
@@ -7383,7 +7394,7 @@ def exception_type(
                 elif original_exception.status_code == 429:
                     exception_mapping_worked = True
                     raise RateLimitError(
-                        message=f"AzureException RateLimitError - {original_exception.message}",
+                        message=f"AzureException RateLimitError - {message}",
                         model=model,
                         llm_provider="azure",
                         litellm_debug_info=extra_information,
@@ -7392,7 +7403,7 @@ def exception_type(
                 elif original_exception.status_code == 503:
                     exception_mapping_worked = True
                     raise ServiceUnavailableError(
-                        message=f"AzureException ServiceUnavailableError - {original_exception.message}",
+                        message=f"AzureException ServiceUnavailableError - {message}",
                         model=model,
                         llm_provider="azure",
                         litellm_debug_info=extra_information,
@@ -7401,7 +7412,7 @@ def exception_type(
                 elif original_exception.status_code == 504:  # gateway timeout error
                     exception_mapping_worked = True
                     raise Timeout(
-                        message=f"AzureException Timeout - {original_exception.message}",
+                        message=f"AzureException Timeout - {message}",
                         model=model,
                         litellm_debug_info=extra_information,
                         llm_provider="azure",
@@ -7410,7 +7421,7 @@ def exception_type(
                     exception_mapping_worked = True
                     raise APIError(
                         status_code=original_exception.status_code,
-                        message=f"AzureException APIError - {original_exception.message}",
+                        message=f"AzureException APIError - {message}",
                         llm_provider="azure",
                         litellm_debug_info=extra_information,
                         model=model,
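A side effect worth noting: every Azure branch now interpolates the precomputed `message` instead of `original_exception.message`, so exception mapping no longer assumes a `.message` attribute exists on the original exception. A sketch of the degradation path (the NoMessageError class is hypothetical, used only to illustrate):

from litellm.litellm_core_utils.exception_mapping_utils import get_error_message

class NoMessageError(Exception):  # hypothetical exception without `.message`
    pass

exc = NoMessageError("boom")
message = get_error_message(error_obj=exc)  # None: no `body` attribute
if message is None:
    message = exc.message if hasattr(exc, "message") else str(exc)  # falls back to "boom"
print(f"AzureException BadRequestError - {message}")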