fix(utils.py): improved predibase exception mapping

adds unit tests + better exception-mapping coverage for predibase errors
Krrish Dholakia 2024-06-08 14:32:43 -07:00
parent 93a3a0cc1e
commit 1dafb1b3b7
11 changed files with 220 additions and 46 deletions

@@ -20,7 +20,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
message,
llm_provider,
model,
response: httpx.Response,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@@ -32,8 +32,14 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
self.message, response=response, body=None
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
@@ -60,7 +66,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
message,
model,
llm_provider,
response: httpx.Response,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@@ -72,8 +78,14 @@ class NotFoundError(openai.NotFoundError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
self.message, response=response, body=None
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
@@ -262,7 +274,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
message,
llm_provider,
model,
response: httpx.Response,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@@ -274,8 +286,18 @@ class RateLimitError(openai.RateLimitError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if response is None:
self.response = httpx.Response(
status_code=429,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
else:
self.response = response
super().__init__(
self.message, response=response, body=None
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
@@ -421,7 +443,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
message,
llm_provider,
model,
response: httpx.Response,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@@ -433,8 +455,18 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if response is None:
self.response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
else:
self.response = response
super().__init__(
self.message, response=response, body=None
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
@@ -460,7 +492,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
message,
llm_provider,
model,
response: httpx.Response,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@@ -472,8 +504,18 @@ class InternalServerError(openai.InternalServerError): # type: ignore
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if response is None:
self.response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
else:
self.response = response
super().__init__(
self.message, response=response, body=None
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
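
The practical effect of making response optional, shown as a minimal sketch (not part of the diff; the model name is a placeholder): each of these exception classes now builds a placeholder httpx.Response with a mock request when the caller does not pass one, so exception_type can raise them without a real provider response in hand.

import litellm

# No httpx.Response is passed; per the constructors above, the class
# synthesizes one with a mock request object.
err = litellm.InternalServerError(
    message="PredibaseException - internal server error",
    llm_provider="predibase",
    model="llama-3-8b-instruct",  # placeholder model name
)
assert err.response is not None  # the synthetic response built in __init__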

@@ -3,6 +3,7 @@
from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy # type: ignore
@@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
"details" in completion_response
and "tokens" in completion_response["details"]
):
model_response.choices[0].finish_reason = completion_response[
"details"
]["finish_reason"]
model_response.choices[0].finish_reason = map_finish_reason(
completion_response["details"]["finish_reason"]
)
sum_logprob = 0
for token in completion_response["details"]["tokens"]:
if token["logprob"] != None:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
model_response["choices"][0][
"message"
@@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
@@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
finish_reason=map_finish_reason(item["finish_reason"]),
index=idx + 1,
message=message_obj,
)
@@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
## CALCULATING USAGE
prompt_tokens = 0
try:
prompt_tokens = len(
encoding.encode(model_response["choices"][0]["message"]["content"])
) ##[TODO] use a model-specific tokenizer here
except:
prompt_tokens = litellm.token_counter(messages=messages)
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
@@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj,
optional_params: dict,
tenant_id: str,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
@@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
completion_url = ""
input_text = ""
base_url = "https://serving.app.predibase.com"
if "https" in model:
completion_url = model
elif api_base:
@@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
if optional_params.get("stream", False) == True:
if optional_params.get("stream", False) is True:
completion_url += "/generate_stream"
else:
completion_url += "/generate"
@@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
},
)
## COMPLETION CALL
if acompletion == True:
if acompletion is True:
### ASYNC STREAMING
if stream == True:
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
@@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
else:
### ASYNC COMPLETION
@@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
### SYNC STREAMING
if stream == True:
if stream is True:
response = requests.post(
completion_url,
headers=headers,
@@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
headers=headers,
data=json.dumps(data),
)
return self.process_response(
model=model,
response=response,
@@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
stream,
data: dict,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params=None,
logger_fn=None,
headers={},
) -> ModelResponse:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
try:
response = await self.async_handler.post(
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
except httpx.HTTPStatusError as e:
raise PredibaseError(
status_code=e.response.status_code, message=e.response.text
status_code=e.response.status_code,
message="HTTPStatusError - {}".format(e.response.text),
)
except Exception as e:
raise PredibaseError(status_code=500, message=str(e))
raise PredibaseError(
status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
)
return self.process_response(
model=model,
response=response,
@@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
api_key,
logging_obj,
data: dict,
timeout: Union[float, httpx.Timeout],
optional_params=None,
litellm_params=None,
logger_fn=None,

@@ -432,9 +432,9 @@ def mock_completion(
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.APIError(
status_code=500, # type: ignore
message=str(mock_response),
llm_provider="openai", # type: ignore
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(mock_response, "llm_provider", "openai"), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
@@ -1949,7 +1949,8 @@ def completion(
)
api_base = (
optional_params.pop("api_base", None)
api_base
or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or get_secret("PREDIBASE_API_BASE")
@@ -1977,12 +1978,13 @@ def completion(
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
tenant_id=tenant_id,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and acompletion == False
and optional_params["stream"] is True
and acompletion is False
):
return _model_response
response = _model_response
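
A minimal usage sketch of what the two completion() hunks above change (the endpoint, key, and tenant values are placeholders): an api_base passed explicitly to completion() now takes precedence over api_base/base_url buried in optional_params, and the caller's timeout is forwarded to the Predibase handler (previously the async handler hard-coded a 600 s timeout).

import os
import litellm

response = litellm.completion(
    model="predibase/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_base="http://0.0.0.0:8081",                    # explicit api_base now wins
    api_key=os.environ.get("PREDIBASE_API_KEY"),
    tenant_id=os.environ.get("PREDIBASE_TENANT_ID"),   # passed through to the handler
    timeout=30,                                        # forwarded instead of hard-coded
)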

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@@ -8,6 +8,17 @@ model_list:
- model_name: llama3-70b-8192
litellm_params:
model: groq/llama3-70b-8192
- model_name: fake-openai-endpoint
litellm_params:
model: predibase/llama-3-8b-instruct
api_base: "http://0.0.0.0:8081"
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
max_retries: 0
temperature: 0.1
max_new_tokens: 256
return_full_text: false
# - litellm_params:
# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
# api_key: os.environ/AZURE_EUROPE_API_KEY
@@ -56,10 +67,11 @@ router_settings:
litellm_settings:
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
general_settings:
alerting: ["email"]
key_management_system: "aws_kms"
key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"]
# general_settings:
# alerting: ["email"]
# key_management_system: "aws_kms"
# key_management_settings:
# hosted_keys: ["LITELLM_MASTER_KEY"]
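
For completeness, a sketch of exercising the new fake-openai-endpoint entry through the proxy with the OpenAI client (the proxy URL and key below are assumptions about how the proxy is started, not values from this config):

import openai

client = openai.OpenAI(
    api_key="sk-1234",               # hypothetical proxy key
    base_url="http://0.0.0.0:4000",  # assumes the proxy's default port
)
resp = client.chat.completions.create(
    model="fake-openai-endpoint",    # model_name from the config above
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(resp.choices[0].message.content)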

@@ -3,6 +3,7 @@ import os
import sys
import traceback
import subprocess, asyncio
from typing import Any
sys.path.insert(
0, os.path.abspath("../..")
@@ -19,6 +20,7 @@ from litellm import (
)
from concurrent.futures import ThreadPoolExecutor
import pytest
from unittest.mock import patch, MagicMock
litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1"
@@ -655,3 +657,47 @@ def test_litellm_predibase_exception():
# accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}")
@pytest.mark.parametrize("provider", ["predibase"])
def test_exception_mapping(provider):
"""
For predibase, run through a set of mock exceptions
assert that they are being mapped correctly
"""
litellm.set_verbose = True
error_map = {
400: litellm.BadRequestError,
401: litellm.AuthenticationError,
404: litellm.NotFoundError,
408: litellm.Timeout,
429: litellm.RateLimitError,
500: litellm.InternalServerError,
503: litellm.ServiceUnavailableError,
}
for code, expected_exception in error_map.items():
mock_response = Exception()
setattr(mock_response, "text", "This is an error message")
setattr(mock_response, "llm_provider", provider)
setattr(mock_response, "status_code", code)
response: Any = None
try:
response = completion(
model="{}/test-model".format(provider),
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response=mock_response,
)
except expected_exception:
continue
except Exception as e:
response = "{}\n{}".format(str(e), traceback.format_exc())
pytest.fail(
"Did not raise expected exception. Expected={}, Return={},".format(
expected_exception, response
)
)
pass
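
To make the new test's flow concrete, one case run by hand (a sketch; names mirror the test above): mock_completion now copies the mock exception's status_code/text/llm_provider into the litellm.APIError it raises, and exception_type then remaps that to the Predibase-specific exception class.

import litellm

mock_err = Exception()
setattr(mock_err, "status_code", 429)
setattr(mock_err, "text", "This is an error message")
setattr(mock_err, "llm_provider", "predibase")

try:
    litellm.completion(
        model="predibase/test-model",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        mock_response=mock_err,
    )
except litellm.RateLimitError as e:
    # the mock 429 surfaces as litellm.RateLimitError
    print(type(e).__name__, str(e))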

@@ -8978,6 +8978,75 @@ def exception_type(
response=original_exception.response,
litellm_debug_info=extra_information,
)
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
exception_mapping_worked = True
raise litellm.InternalServerError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
raise NotFoundError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif custom_llm_provider == "bedrock":
if (
"too many tokens" in error_str

@@ -1,3 +1,3 @@
ignore = ["F405"]
ignore = ["F405", "E402"]
extend-select = ["E501"]
line-length = 120