forked from phoenix/litellm-mirror

commit 1dafb1b3b7 (parent 93a3a0cc1e)

fix(utils.py): improved predibase exception mapping

adds unit testing + better coverage for predibase errors

11 changed files with 220 additions and 46 deletions
@@ -20,7 +20,7 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -32,8 +32,14 @@ class AuthenticationError(openai.AuthenticationError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -60,7 +66,7 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -72,8 +78,14 @@ class NotFoundError(openai.NotFoundError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -262,7 +274,7 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -274,8 +286,18 @@ class RateLimitError(openai.RateLimitError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=429,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -421,7 +443,7 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -433,8 +455,18 @@ class ServiceUnavailableError(openai.APIStatusError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -460,7 +492,7 @@ class InternalServerError(openai.InternalServerError):  # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -472,8 +504,18 @@ class InternalServerError(openai.InternalServerError):  # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
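The hunks above make the httpx response argument optional on each exception class and, when the caller has none, back it with a mock httpx.Response so downstream code can always read one. A minimal, self-contained sketch of that pattern (DemoAuthenticationError is a hypothetical stand-in, not litellm's actual class):

from typing import Optional

import httpx
import openai


class DemoAuthenticationError(openai.AuthenticationError):  # hypothetical subclass
    def __init__(self, message: str, response: Optional[httpx.Response] = None):
        self.status_code = 401
        self.message = message
        # fall back to a mock response (with a mock request) when none is supplied
        self.response = response or httpx.Response(
            status_code=self.status_code,
            request=httpx.Request(method="GET", url="https://litellm.ai"),
        )
        super().__init__(self.message, response=self.response, body=None)


err = DemoAuthenticationError("bad api key")
print(err.response.status_code)  # 401, even though no real HTTP response existed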
litellm/integrations/test_httpx.py (new file, 0 lines changed)
@@ -3,6 +3,7 @@

 from functools import partial
 import os, types
+import traceback
 import json
 from enum import Enum
 import requests, copy  # type: ignore
@@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
                 "details" in completion_response
                 and "tokens" in completion_response["details"]
             ):
-                model_response.choices[0].finish_reason = completion_response[
-                    "details"
-                ]["finish_reason"]
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    completion_response["details"]["finish_reason"]
+                )
                 sum_logprob = 0
                 for token in completion_response["details"]["tokens"]:
-                    if token["logprob"] != None:
+                    if token["logprob"] is not None:
                         sum_logprob += token["logprob"]
                 model_response["choices"][0][
                     "message"
@@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
                     ):
                         sum_logprob = 0
                         for token in item["tokens"]:
-                            if token["logprob"] != None:
+                            if token["logprob"] is not None:
                                 sum_logprob += token["logprob"]
                         if len(item["generated_text"]) > 0:
                             message_obj = Message(
@@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
                         else:
                             message_obj = Message(content=None)
                         choice_obj = Choices(
-                            finish_reason=item["finish_reason"],
+                            finish_reason=map_finish_reason(item["finish_reason"]),
                             index=idx + 1,
                             message=message_obj,
                         )
@@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
         ## CALCULATING USAGE
         prompt_tokens = 0
         try:
-            prompt_tokens = len(
-                encoding.encode(model_response["choices"][0]["message"]["content"])
-            )  ##[TODO] use a model-specific tokenizer here
-        except:
+            prompt_tokens = litellm.token_counter(messages=messages)
+        except Exception:
             # this should remain non blocking we should not block a response returning if calculating usage fails
             pass
         output_text = model_response["choices"][0]["message"].get("content", "")
@@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj,
         optional_params: dict,
         tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
+
         if "https" in model:
             completion_url = model
         elif api_base:
@@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):

             completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

-        if optional_params.get("stream", False) == True:
+        if optional_params.get("stream", False) is True:
             completion_url += "/generate_stream"
         else:
             completion_url += "/generate"
@@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion == True:
+        if acompletion is True:
             ### ASYNC STREAMING
-            if stream == True:
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore
             else:
                 ### ASYNC COMPLETION
@@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore

         ### SYNC STREAMING
-        if stream == True:
+        if stream is True:
             response = requests.post(
                 completion_url,
                 headers=headers,
@@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
                 headers=headers,
                 data=json.dumps(data),
             )
-
         return self.process_response(
             model=model,
             response=response,
@@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
         stream,
         data: dict,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         litellm_params=None,
         logger_fn=None,
         headers={},
     ) -> ModelResponse:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
         try:
-            response = await self.async_handler.post(
+            response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
            )
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
-                status_code=e.response.status_code, message=e.response.text
+                status_code=e.response.status_code,
+                message="HTTPStatusError - {}".format(e.response.text),
             )
         except Exception as e:
-            raise PredibaseError(status_code=500, message=str(e))
+            raise PredibaseError(
+                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+            )
         return self.process_response(
             model=model,
             response=response,
@@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         data: dict,
+        timeout: Union[float, httpx.Timeout],
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
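The async_completion hunk above wraps the httpx call so that HTTP failures keep their real status code and body, while anything else is surfaced as a 500 with a traceback. A rough, self-contained sketch of that error-wrapping pattern using plain httpx (ProviderError and post_completion are illustrative names, not the litellm implementation):

import asyncio
import json
import traceback

import httpx


class ProviderError(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(message)


async def post_completion(url: str, headers: dict, data: dict, timeout: float) -> dict:
    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=timeout)) as client:
        try:
            response = await client.post(url, headers=headers, content=json.dumps(data))
            response.raise_for_status()  # turn 4xx/5xx into HTTPStatusError
        except httpx.HTTPStatusError as e:
            # HTTP error: preserve the provider's status code and body
            raise ProviderError(
                status_code=e.response.status_code,
                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
            # anything else: report a 500 with the full traceback for debugging
            raise ProviderError(
                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
            )
        return response.json()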
@@ -432,9 +432,9 @@ def mock_completion(
         if isinstance(mock_response, openai.APIError):
             raise mock_response
         raise litellm.APIError(
-            status_code=500,  # type: ignore
-            message=str(mock_response),
-            llm_provider="openai",  # type: ignore
+            status_code=getattr(mock_response, "status_code", 500),  # type: ignore
+            message=getattr(mock_response, "text", str(mock_response)),
+            llm_provider=getattr(mock_response, "llm_provider", "openai"),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
@@ -1949,7 +1949,8 @@ def completion(
             )

             api_base = (
-                optional_params.pop("api_base", None)
+                api_base
+                or optional_params.pop("api_base", None)
                 or optional_params.pop("base_url", None)
                 or litellm.api_base
                 or get_secret("PREDIBASE_API_BASE")
@@ -1977,12 +1978,13 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 api_key=api_key,
                 tenant_id=tenant_id,
+                timeout=timeout,
             )

             if (
                 "stream" in optional_params
-                and optional_params["stream"] == True
-                and acompletion == False
+                and optional_params["stream"] is True
+                and acompletion is False
             ):
                 return _model_response
             response = _model_response
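The mock_completion hunk above switches to getattr with defaults, so a bare Exception carrying status_code, text, and llm_provider drives the raised litellm.APIError, while anything missing falls back to the old hard-coded values. A short sketch of that pattern (build_mock_error is a hypothetical helper, not litellm API):

import httpx
import litellm


def build_mock_error(mock_response: Exception, model: str) -> litellm.APIError:
    # forward attributes set on the mock, default to the previous constants
    return litellm.APIError(
        status_code=getattr(mock_response, "status_code", 500),
        message=getattr(mock_response, "text", str(mock_response)),
        llm_provider=getattr(mock_response, "llm_provider", "openai"),
        model=model,
        request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
    )


err = Exception()
err.status_code = 429  # attributes on the mock drive the mapped error
err.text = "This is an error message"
err.llm_provider = "predibase"
print(type(build_mock_error(err, model="predibase/test-model")).__name__)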
3 file diffs suppressed because one or more lines are too long
@@ -8,6 +8,17 @@ model_list:
   - model_name: llama3-70b-8192
     litellm_params:
       model: groq/llama3-70b-8192
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: predibase/llama-3-8b-instruct
+      api_base: "http://0.0.0.0:8081"
+      api_key: os.environ/PREDIBASE_API_KEY
+      tenant_id: os.environ/PREDIBASE_TENANT_ID
+      max_retries: 0
+      temperature: 0.1
+      max_new_tokens: 256
+      return_full_text: false
+
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
 #     api_key: os.environ/AZURE_EUROPE_API_KEY
@@ -56,10 +67,11 @@ router_settings:

 litellm_settings:
   success_callback: ["langfuse"]
+  failure_callback: ["langfuse"]

-general_settings:
-  alerting: ["email"]
-  key_management_system: "aws_kms"
-  key_management_settings:
-    hosted_keys: ["LITELLM_MASTER_KEY"]
+# general_settings:
+#   alerting: ["email"]
+#   key_management_system: "aws_kms"
+#   key_management_settings:
+#     hosted_keys: ["LITELLM_MASTER_KEY"]

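The new fake-openai-endpoint entry above is just a set of litellm_params that the proxy forwards to litellm.completion. Roughly the equivalent direct call, assuming PREDIBASE_API_KEY and PREDIBASE_TENANT_ID are set and a deployment is reachable at the placeholder api_base:

import os

import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_base="http://0.0.0.0:8081",  # placeholder from the config above
    api_key=os.environ["PREDIBASE_API_KEY"],
    tenant_id=os.environ["PREDIBASE_TENANT_ID"],
    max_retries=0,
    temperature=0.1,
    max_new_tokens=256,
    return_full_text=False,
)
print(response["choices"][0]["message"]["content"])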
@@ -3,6 +3,7 @@ import os
 import sys
 import traceback
 import subprocess, asyncio
+from typing import Any

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,6 +20,7 @@ from litellm import (
 )
 from concurrent.futures import ThreadPoolExecutor
 import pytest
+from unittest.mock import patch, MagicMock

 litellm.vertex_project = "pathrise-convert-1606954137718"
 litellm.vertex_location = "us-central1"
@@ -655,3 +657,47 @@ def test_litellm_predibase_exception():

 # accuracy_score = counts[True]/(counts[True] + counts[False])
 # print(f"accuracy_score: {accuracy_score}")
+
+
+@pytest.mark.parametrize("provider", ["predibase"])
+def test_exception_mapping(provider):
+    """
+    For predibase, run through a set of mock exceptions
+
+    assert that they are being mapped correctly
+    """
+    litellm.set_verbose = True
+    error_map = {
+        400: litellm.BadRequestError,
+        401: litellm.AuthenticationError,
+        404: litellm.NotFoundError,
+        408: litellm.Timeout,
+        429: litellm.RateLimitError,
+        500: litellm.InternalServerError,
+        503: litellm.ServiceUnavailableError,
+    }
+
+    for code, expected_exception in error_map.items():
+        mock_response = Exception()
+        setattr(mock_response, "text", "This is an error message")
+        setattr(mock_response, "llm_provider", provider)
+        setattr(mock_response, "status_code", code)
+
+        response: Any = None
+        try:
+            response = completion(
+                model="{}/test-model".format(provider),
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response=mock_response,
+            )
+        except expected_exception:
+            continue
+        except Exception as e:
+            response = "{}\n{}".format(str(e), traceback.format_exc())
+            pytest.fail(
+                "Did not raise expected exception. Expected={}, Return={},".format(
+                    expected_exception, response
+                )
+            )
+
+    pass
@@ -8978,6 +8978,75 @@ def exception_type(
                         response=original_exception.response,
                         litellm_debug_info=extra_information,
                     )
+                elif hasattr(original_exception, "status_code"):
+                    if original_exception.status_code == 500:
+                        exception_mapping_worked = True
+                        raise litellm.InternalServerError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            llm_provider="predibase",
+                            model=model,
+                        )
+                    elif original_exception.status_code == 401:
+                        exception_mapping_worked = True
+                        raise AuthenticationError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            llm_provider="predibase",
+                            model=model,
+                        )
+                    elif original_exception.status_code == 400:
+                        exception_mapping_worked = True
+                        raise BadRequestError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            llm_provider="predibase",
+                            model=model,
+                        )
+                    elif original_exception.status_code == 404:
+                        exception_mapping_worked = True
+                        raise NotFoundError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            llm_provider="predibase",
+                            model=model,
+                        )
+                    elif original_exception.status_code == 408:
+                        exception_mapping_worked = True
+                        raise Timeout(
+                            message=f"PredibaseException - {original_exception.message}",
+                            model=model,
+                            llm_provider=custom_llm_provider,
+                            litellm_debug_info=extra_information,
+                        )
+                    elif original_exception.status_code == 422:
+                        exception_mapping_worked = True
+                        raise BadRequestError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            model=model,
+                            llm_provider=custom_llm_provider,
+                            litellm_debug_info=extra_information,
+                        )
+                    elif original_exception.status_code == 429:
+                        exception_mapping_worked = True
+                        raise RateLimitError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            model=model,
+                            llm_provider=custom_llm_provider,
+                            litellm_debug_info=extra_information,
+                        )
+                    elif original_exception.status_code == 503:
+                        exception_mapping_worked = True
+                        raise ServiceUnavailableError(
+                            message=f"PredibaseException - {original_exception.message}",
+                            model=model,
+                            llm_provider=custom_llm_provider,
+                            litellm_debug_info=extra_information,
+                        )
+                    elif original_exception.status_code == 504:  # gateway timeout error
+                        exception_mapping_worked = True
+                        raise Timeout(
+                            message=f"PredibaseException - {original_exception.message}",
+                            model=model,
+                            llm_provider=custom_llm_provider,
+                            litellm_debug_info=extra_information,
+                        )
             elif custom_llm_provider == "bedrock":
                 if (
                     "too many tokens" in error_str
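The elif chain added above pairs each Predibase status code with a litellm exception type, mirroring the error_map used in the new test. The same mapping condensed into a lookup table, as an illustrative sketch only (the shipped code keeps the explicit chain, which also sets exception_mapping_worked and attaches litellm_debug_info for some codes):

import litellm

# status code -> litellm exception class, as exercised by the new unit test
PREDIBASE_STATUS_TO_EXCEPTION = {
    400: litellm.BadRequestError,
    401: litellm.AuthenticationError,
    404: litellm.NotFoundError,
    408: litellm.Timeout,
    422: litellm.BadRequestError,
    429: litellm.RateLimitError,
    500: litellm.InternalServerError,
    503: litellm.ServiceUnavailableError,
    504: litellm.Timeout,  # gateway timeout
}


def map_predibase_exception(original_exception: Exception, model: str):
    exc_cls = PREDIBASE_STATUS_TO_EXCEPTION.get(
        getattr(original_exception, "status_code", None)
    )
    if exc_cls is None:
        raise original_exception  # fall through to generic handling
    raise exc_cls(
        message=f"PredibaseException - {original_exception}",
        llm_provider="predibase",
        model=model,
    )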
@@ -1,3 +1,3 @@
-ignore = ["F405"]
+ignore = ["F405", "E402"]
 extend-select = ["E501"]
 line-length = 120