fix(utils.py): improved predibase exception mapping

Adds unit tests and improves exception-mapping coverage for Predibase errors.
This commit is contained in:
Krrish Dholakia 2024-06-08 14:32:43 -07:00
parent 93a3a0cc1e
commit 1dafb1b3b7
11 changed files with 220 additions and 46 deletions

View file

@ -20,7 +20,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
message, message,
llm_provider, llm_provider,
model, model,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
@ -32,8 +32,14 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries self.max_retries = max_retries
self.num_retries = num_retries self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -60,7 +66,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
message, message,
model, model,
llm_provider, llm_provider,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
@ -72,8 +78,14 @@ class NotFoundError(openai.NotFoundError): # type: ignore
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries self.max_retries = max_retries
self.num_retries = num_retries self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -262,7 +274,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
message, message,
llm_provider, llm_provider,
model, model,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
@ -274,8 +286,18 @@ class RateLimitError(openai.RateLimitError): # type: ignore
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries self.max_retries = max_retries
self.num_retries = num_retries self.num_retries = num_retries
if response is None:
self.response = httpx.Response(
status_code=429,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
else:
self.response = response
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -421,7 +443,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
message, message,
llm_provider, llm_provider,
model, model,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
@ -433,8 +455,18 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries self.max_retries = max_retries
self.num_retries = num_retries self.num_retries = num_retries
if response is None:
self.response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
else:
self.response = response
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):
@ -460,7 +492,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
message, message,
llm_provider, llm_provider,
model, model,
response: httpx.Response, response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None, litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
num_retries: Optional[int] = None, num_retries: Optional[int] = None,
@ -472,8 +504,18 @@ class InternalServerError(openai.InternalServerError): # type: ignore
self.litellm_debug_info = litellm_debug_info self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries self.max_retries = max_retries
self.num_retries = num_retries self.num_retries = num_retries
if response is None:
self.response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
else:
self.response = response
super().__init__( super().__init__(
self.message, response=response, body=None self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
def __str__(self): def __str__(self):

View file

View file

@ -3,6 +3,7 @@
from functools import partial from functools import partial
import os, types import os, types
import traceback
import json import json
from enum import Enum from enum import Enum
import requests, copy # type: ignore import requests, copy # type: ignore
@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
"details" in completion_response "details" in completion_response
and "tokens" in completion_response["details"] and "tokens" in completion_response["details"]
): ):
model_response.choices[0].finish_reason = completion_response[ model_response.choices[0].finish_reason = map_finish_reason(
"details" completion_response["details"]["finish_reason"]
]["finish_reason"] )
sum_logprob = 0 sum_logprob = 0
for token in completion_response["details"]["tokens"]: for token in completion_response["details"]["tokens"]:
if token["logprob"] != None: if token["logprob"] is not None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
model_response["choices"][0][ model_response["choices"][0][
"message" "message"
@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
): ):
sum_logprob = 0 sum_logprob = 0
for token in item["tokens"]: for token in item["tokens"]:
if token["logprob"] != None: if token["logprob"] is not None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0: if len(item["generated_text"]) > 0:
message_obj = Message( message_obj = Message(
@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
else: else:
message_obj = Message(content=None) message_obj = Message(content=None)
choice_obj = Choices( choice_obj = Choices(
finish_reason=item["finish_reason"], finish_reason=map_finish_reason(item["finish_reason"]),
index=idx + 1, index=idx + 1,
message=message_obj, message=message_obj,
) )
@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
## CALCULATING USAGE ## CALCULATING USAGE
prompt_tokens = 0 prompt_tokens = 0
try: try:
prompt_tokens = len( prompt_tokens = litellm.token_counter(messages=messages)
encoding.encode(model_response["choices"][0]["message"]["content"]) except Exception:
) ##[TODO] use a model-specific tokenizer here
except:
# this should remain non blocking we should not block a response returning if calculating usage fails # this should remain non blocking we should not block a response returning if calculating usage fails
pass pass
output_text = model_response["choices"][0]["message"].get("content", "") output_text = model_response["choices"][0]["message"].get("content", "")
@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
logging_obj, logging_obj,
optional_params: dict, optional_params: dict,
tenant_id: str, tenant_id: str,
timeout: Union[float, httpx.Timeout],
acompletion=None, acompletion=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
completion_url = "" completion_url = ""
input_text = "" input_text = ""
base_url = "https://serving.app.predibase.com" base_url = "https://serving.app.predibase.com"
if "https" in model: if "https" in model:
completion_url = model completion_url = model
elif api_base: elif api_base:
@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):
completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}" completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) is True:
completion_url += "/generate_stream" completion_url += "/generate_stream"
else: else:
completion_url += "/generate" completion_url += "/generate"
@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
}, },
) )
## COMPLETION CALL ## COMPLETION CALL
if acompletion == True: if acompletion is True:
### ASYNC STREAMING ### ASYNC STREAMING
if stream == True: if stream is True:
return self.async_streaming( return self.async_streaming(
model=model, model=model,
messages=messages, messages=messages,
@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
headers=headers, headers=headers,
timeout=timeout,
) # type: ignore ) # type: ignore
else: else:
### ASYNC COMPLETION ### ASYNC COMPLETION
@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
litellm_params=litellm_params, litellm_params=litellm_params,
logger_fn=logger_fn, logger_fn=logger_fn,
headers=headers, headers=headers,
timeout=timeout,
) # type: ignore ) # type: ignore
### SYNC STREAMING ### SYNC STREAMING
if stream == True: if stream is True:
response = requests.post( response = requests.post(
completion_url, completion_url,
headers=headers, headers=headers,
@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
headers=headers, headers=headers,
data=json.dumps(data), data=json.dumps(data),
) )
return self.process_response( return self.process_response(
model=model, model=model,
response=response, response=response,
@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
stream, stream,
data: dict, data: dict,
optional_params: dict, optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
headers={}, headers={},
) -> ModelResponse: ) -> ModelResponse:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0) async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
)
try: try:
response = await self.async_handler.post( response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data) api_base, headers=headers, data=json.dumps(data)
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
raise PredibaseError( raise PredibaseError(
status_code=e.response.status_code, message=e.response.text status_code=e.response.status_code,
message="HTTPStatusError - {}".format(e.response.text),
) )
except Exception as e: except Exception as e:
raise PredibaseError(status_code=500, message=str(e)) raise PredibaseError(
status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
)
return self.process_response( return self.process_response(
model=model, model=model,
response=response, response=response,
@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
api_key, api_key,
logging_obj, logging_obj,
data: dict, data: dict,
timeout: Union[float, httpx.Timeout],
optional_params=None, optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,

View file

@ -432,9 +432,9 @@ def mock_completion(
if isinstance(mock_response, openai.APIError): if isinstance(mock_response, openai.APIError):
raise mock_response raise mock_response
raise litellm.APIError( raise litellm.APIError(
status_code=500, # type: ignore status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=str(mock_response), message=getattr(mock_response, "text", str(mock_response)),
llm_provider="openai", # type: ignore llm_provider=getattr(mock_response, "llm_provider", "openai"), # type: ignore
model=model, # type: ignore model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"), request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
) )
@ -1949,7 +1949,8 @@ def completion(
) )
api_base = ( api_base = (
optional_params.pop("api_base", None) api_base
or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None) or optional_params.pop("base_url", None)
or litellm.api_base or litellm.api_base
or get_secret("PREDIBASE_API_BASE") or get_secret("PREDIBASE_API_BASE")
@ -1977,12 +1978,13 @@ def completion(
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
api_key=api_key, api_key=api_key,
tenant_id=tenant_id, tenant_id=tenant_id,
timeout=timeout,
) )
if ( if (
"stream" in optional_params "stream" in optional_params
and optional_params["stream"] == True and optional_params["stream"] is True
and acompletion == False and acompletion is False
): ):
return _model_response return _model_response
response = _model_response response = _model_response

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -8,6 +8,17 @@ model_list:
- model_name: llama3-70b-8192 - model_name: llama3-70b-8192
litellm_params: litellm_params:
model: groq/llama3-70b-8192 model: groq/llama3-70b-8192
- model_name: fake-openai-endpoint
litellm_params:
model: predibase/llama-3-8b-instruct
api_base: "http://0.0.0.0:8081"
api_key: os.environ/PREDIBASE_API_KEY
tenant_id: os.environ/PREDIBASE_TENANT_ID
max_retries: 0
temperature: 0.1
max_new_tokens: 256
return_full_text: false
# - litellm_params: # - litellm_params:
# api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ # api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
# api_key: os.environ/AZURE_EUROPE_API_KEY # api_key: os.environ/AZURE_EUROPE_API_KEY
@ -56,10 +67,11 @@ router_settings:
litellm_settings: litellm_settings:
success_callback: ["langfuse"] success_callback: ["langfuse"]
failure_callback: ["langfuse"]
general_settings: # general_settings:
alerting: ["email"] # alerting: ["email"]
key_management_system: "aws_kms" # key_management_system: "aws_kms"
key_management_settings: # key_management_settings:
hosted_keys: ["LITELLM_MASTER_KEY"] # hosted_keys: ["LITELLM_MASTER_KEY"]

View file

@ -3,6 +3,7 @@ import os
import sys import sys
import traceback import traceback
import subprocess, asyncio import subprocess, asyncio
from typing import Any
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -19,6 +20,7 @@ from litellm import (
) )
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import pytest import pytest
from unittest.mock import patch, MagicMock
litellm.vertex_project = "pathrise-convert-1606954137718" litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1" litellm.vertex_location = "us-central1"
@ -655,3 +657,47 @@ def test_litellm_predibase_exception():
# accuracy_score = counts[True]/(counts[True] + counts[False]) # accuracy_score = counts[True]/(counts[True] + counts[False])
# print(f"accuracy_score: {accuracy_score}") # print(f"accuracy_score: {accuracy_score}")
@pytest.mark.parametrize("provider", ["predibase"])
def test_exception_mapping(provider):
    """
    Check that mocked provider errors surface as the matching litellm exception.

    For every status code in the table below, a plain ``Exception`` carrying
    ``text``, ``llm_provider`` and ``status_code`` attributes is handed to
    ``completion`` as ``mock_response``. The test fails when the call returns
    normally or raises anything other than the exception class paired with
    that status code.
    """
    litellm.set_verbose = True

    # HTTP status code -> litellm exception class the call is expected to raise.
    expected_by_status = {
        400: litellm.BadRequestError,
        401: litellm.AuthenticationError,
        404: litellm.NotFoundError,
        408: litellm.Timeout,
        429: litellm.RateLimitError,
        500: litellm.InternalServerError,
        503: litellm.ServiceUnavailableError,
    }

    for status, exc_cls in expected_by_status.items():
        # Build the fake provider error; setattr is used because a bare
        # Exception does not declare these attributes.
        fake_error = Exception()
        for attr_name, attr_value in (
            ("text", "This is an error message"),
            ("llm_provider", provider),
            ("status_code", status),
        ):
            setattr(fake_error, attr_name, attr_value)

        outcome: Any = None
        try:
            outcome = completion(
                model="{}/test-model".format(provider),
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                mock_response=fake_error,
            )
        except exc_cls:
            # Mapped as expected; move on to the next status code.
            continue
        except Exception as err:
            # Wrong exception type: keep its text and traceback for the report.
            outcome = "{}\n{}".format(str(err), traceback.format_exc())

        # Reached only when the expected exception was not raised.
        pytest.fail(
            "Did not raise expected exception. Expected={}, Return={},".format(
                exc_cls, outcome
            )
        )

View file

@ -8978,6 +8978,75 @@ def exception_type(
response=original_exception.response, response=original_exception.response,
litellm_debug_info=extra_information, litellm_debug_info=extra_information,
) )
elif hasattr(original_exception, "status_code"):
if original_exception.status_code == 500:
exception_mapping_worked = True
raise litellm.InternalServerError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 401:
exception_mapping_worked = True
raise AuthenticationError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 400:
exception_mapping_worked = True
raise BadRequestError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 404:
exception_mapping_worked = True
raise NotFoundError(
message=f"PredibaseException - {original_exception.message}",
llm_provider="predibase",
model=model,
)
elif original_exception.status_code == 408:
exception_mapping_worked = True
raise Timeout(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise BadRequestError(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 429:
exception_mapping_worked = True
raise RateLimitError(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 503:
exception_mapping_worked = True
raise ServiceUnavailableError(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif original_exception.status_code == 504: # gateway timeout error
exception_mapping_worked = True
raise Timeout(
message=f"PredibaseException - {original_exception.message}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif custom_llm_provider == "bedrock": elif custom_llm_provider == "bedrock":
if ( if (
"too many tokens" in error_str "too many tokens" in error_str

View file

@ -1,3 +1,3 @@
ignore = ["F405"] ignore = ["F405", "E402"]
extend-select = ["E501"] extend-select = ["E501"]
line-length = 120 line-length = 120