forked from phoenix/litellm-mirror
fix(utils.py): improved predibase exception mapping
Adds unit tests and better coverage for predibase errors.
parent 93a3a0cc1e
commit 1dafb1b3b7
11 changed files with 220 additions and 46 deletions
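In practice, the changes below mean Predibase failures surface as litellm's typed exceptions instead of a generic 500. A minimal caller-side sketch of that effect (the model name, tenant id, and handler bodies here are illustrative, not part of this commit):

import litellm

try:
    litellm.completion(
        model="predibase/llama-3-8b-instruct",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        tenant_id="my-tenant",  # hypothetical tenant id
    )
except litellm.AuthenticationError:
    pass  # mapped from a predibase 401
except litellm.RateLimitError:
    pass  # mapped from a 429
except litellm.ServiceUnavailableError:
    pass  # mapped from a 503
except litellm.InternalServerError:
    pass  # mapped from a 500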
@@ -20,7 +20,7 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -32,8 +32,14 @@ class AuthenticationError(openai.AuthenticationError): # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -60,7 +66,7 @@ class NotFoundError(openai.NotFoundError): # type: ignore
         message,
         model,
         llm_provider,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -72,8 +78,14 @@ class NotFoundError(openai.NotFoundError): # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        self.response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -262,7 +274,7 @@ class RateLimitError(openai.RateLimitError): # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -274,8 +286,18 @@ class RateLimitError(openai.RateLimitError): # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=429,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -421,7 +443,7 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -433,8 +455,18 @@ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
@@ -460,7 +492,7 @@ class InternalServerError(openai.InternalServerError): # type: ignore
         message,
         llm_provider,
         model,
-        response: httpx.Response,
+        response: Optional[httpx.Response] = None,
         litellm_debug_info: Optional[str] = None,
         max_retries: Optional[int] = None,
         num_retries: Optional[int] = None,
@@ -472,8 +504,18 @@ class InternalServerError(openai.InternalServerError): # type: ignore
         self.litellm_debug_info = litellm_debug_info
         self.max_retries = max_retries
         self.num_retries = num_retries
+        if response is None:
+            self.response = httpx.Response(
+                status_code=self.status_code,
+                request=httpx.Request(
+                    method="POST",
+                    url=" https://cloud.google.com/vertex-ai/",
+                ),
+            )
+        else:
+            self.response = response
         super().__init__(
-            self.message, response=response, body=None
+            self.message, response=self.response, body=None
         )  # Call the base class constructor with the parameters it needs

     def __str__(self):
litellm/integrations/test_httpx.py (new, empty file)
@@ -3,6 +3,7 @@
+from functools import partial
 import os, types
 import traceback
 import json
 from enum import Enum
 import requests, copy  # type: ignore
@@ -242,12 +243,12 @@ class PredibaseChatCompletion(BaseLLM):
                 "details" in completion_response
                 and "tokens" in completion_response["details"]
             ):
-                model_response.choices[0].finish_reason = completion_response[
-                    "details"
-                ]["finish_reason"]
+                model_response.choices[0].finish_reason = map_finish_reason(
+                    completion_response["details"]["finish_reason"]
+                )
                 sum_logprob = 0
                 for token in completion_response["details"]["tokens"]:
-                    if token["logprob"] != None:
+                    if token["logprob"] is not None:
                         sum_logprob += token["logprob"]
                 model_response["choices"][0][
                     "message"
@@ -265,7 +266,7 @@ class PredibaseChatCompletion(BaseLLM):
                 ):
                     sum_logprob = 0
                     for token in item["tokens"]:
-                        if token["logprob"] != None:
+                        if token["logprob"] is not None:
                             sum_logprob += token["logprob"]
                     if len(item["generated_text"]) > 0:
                         message_obj = Message(
@@ -275,7 +276,7 @@ class PredibaseChatCompletion(BaseLLM):
                     else:
                         message_obj = Message(content=None)
                     choice_obj = Choices(
-                        finish_reason=item["finish_reason"],
+                        finish_reason=map_finish_reason(item["finish_reason"]),
                         index=idx + 1,
                         message=message_obj,
                     )
@@ -285,10 +286,8 @@ class PredibaseChatCompletion(BaseLLM):
         ## CALCULATING USAGE
         prompt_tokens = 0
         try:
-            prompt_tokens = len(
-                encoding.encode(model_response["choices"][0]["message"]["content"])
-            ) ##[TODO] use a model-specific tokenizer here
-        except:
             prompt_tokens = litellm.token_counter(messages=messages)
+        except Exception:
+            # this should remain non blocking we should not block a response returning if calculating usage fails
+            pass
         output_text = model_response["choices"][0]["message"].get("content", "")
@@ -331,6 +330,7 @@ class PredibaseChatCompletion(BaseLLM):
         logging_obj,
         optional_params: dict,
         tenant_id: str,
+        timeout: Union[float, httpx.Timeout],
         acompletion=None,
         litellm_params=None,
         logger_fn=None,
@@ -340,6 +340,7 @@ class PredibaseChatCompletion(BaseLLM):
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"

         if "https" in model:
             completion_url = model
         elif api_base:
@@ -349,7 +350,7 @@ class PredibaseChatCompletion(BaseLLM):

             completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

-        if optional_params.get("stream", False) == True:
+        if optional_params.get("stream", False) is True:
             completion_url += "/generate_stream"
         else:
             completion_url += "/generate"
@@ -393,9 +394,9 @@ class PredibaseChatCompletion(BaseLLM):
             },
         )
         ## COMPLETION CALL
-        if acompletion == True:
+        if acompletion is True:
             ### ASYNC STREAMING
-            if stream == True:
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -410,6 +411,7 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore
             else:
                 ### ASYNC COMPLETION
@@ -428,10 +430,11 @@ class PredibaseChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    timeout=timeout,
                 )  # type: ignore

         ### SYNC STREAMING
-        if stream == True:
+        if stream is True:
             response = requests.post(
                 completion_url,
                 headers=headers,
@@ -452,7 +455,6 @@ class PredibaseChatCompletion(BaseLLM):
                 headers=headers,
                 data=json.dumps(data),
             )
-
         return self.process_response(
             model=model,
             response=response,
@@ -480,23 +482,26 @@ class PredibaseChatCompletion(BaseLLM):
         stream,
         data: dict,
         optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
         litellm_params=None,
         logger_fn=None,
         headers={},
     ) -> ModelResponse:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
         try:
-            response = await self.async_handler.post(
+            response = await async_handler.post(
                 api_base, headers=headers, data=json.dumps(data)
             )
         except httpx.HTTPStatusError as e:
             raise PredibaseError(
-                status_code=e.response.status_code, message=e.response.text
+                status_code=e.response.status_code,
+                message="HTTPStatusError - {}".format(e.response.text),
             )
         except Exception as e:
-            raise PredibaseError(status_code=500, message=str(e))
+            raise PredibaseError(
+                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
+            )
         return self.process_response(
             model=model,
             response=response,
@@ -522,6 +527,7 @@ class PredibaseChatCompletion(BaseLLM):
         api_key,
         logging_obj,
         data: dict,
+        timeout: Union[float, httpx.Timeout],
         optional_params=None,
         litellm_params=None,
         logger_fn=None,
@@ -432,9 +432,9 @@ def mock_completion(
         if isinstance(mock_response, openai.APIError):
             raise mock_response
         raise litellm.APIError(
-            status_code=500,  # type: ignore
-            message=str(mock_response),
-            llm_provider="openai",  # type: ignore
+            status_code=getattr(mock_response, "status_code", 500),  # type: ignore
+            message=getattr(mock_response, "text", str(mock_response)),
+            llm_provider=getattr(mock_response, "llm_provider", "openai"),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
@@ -1949,7 +1949,8 @@ def completion(
             )

             api_base = (
-                optional_params.pop("api_base", None)
+                api_base
+                or optional_params.pop("api_base", None)
                 or optional_params.pop("base_url", None)
                 or litellm.api_base
                 or get_secret("PREDIBASE_API_BASE")
@@ -1977,12 +1978,13 @@ def completion(
                 custom_prompt_dict=custom_prompt_dict,
                 api_key=api_key,
                 tenant_id=tenant_id,
+                timeout=timeout,
             )

             if (
                 "stream" in optional_params
-                and optional_params["stream"] == True
-                and acompletion == False
+                and optional_params["stream"] is True
+                and acompletion is False
             ):
                 return _model_response
             response = _model_response
File diff suppressed because one or more lines are too long (3 files)
@@ -8,6 +8,17 @@ model_list:
   - model_name: llama3-70b-8192
     litellm_params:
       model: groq/llama3-70b-8192
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: predibase/llama-3-8b-instruct
+      api_base: "http://0.0.0.0:8081"
+      api_key: os.environ/PREDIBASE_API_KEY
+      tenant_id: os.environ/PREDIBASE_TENANT_ID
+      max_retries: 0
+      temperature: 0.1
+      max_new_tokens: 256
+      return_full_text: false
+
 # - litellm_params:
 #     api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
 #     api_key: os.environ/AZURE_EUROPE_API_KEY
@@ -56,10 +67,11 @@ router_settings:

 litellm_settings:
   success_callback: ["langfuse"]
+  failure_callback: ["langfuse"]

-general_settings:
-  alerting: ["email"]
-  key_management_system: "aws_kms"
-  key_management_settings:
-    hosted_keys: ["LITELLM_MASTER_KEY"]
+# general_settings:
+#   alerting: ["email"]
+#   key_management_system: "aws_kms"
+#   key_management_settings:
+#     hosted_keys: ["LITELLM_MASTER_KEY"]
@@ -3,6 +3,7 @@ import os
 import sys
 import traceback
 import subprocess, asyncio
+from typing import Any

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,6 +20,7 @@ from litellm import (
 )
 from concurrent.futures import ThreadPoolExecutor
 import pytest
+from unittest.mock import patch, MagicMock

 litellm.vertex_project = "pathrise-convert-1606954137718"
 litellm.vertex_location = "us-central1"
@@ -655,3 +657,47 @@ def test_litellm_predibase_exception():

 # accuracy_score = counts[True]/(counts[True] + counts[False])
 # print(f"accuracy_score: {accuracy_score}")
+
+
+@pytest.mark.parametrize("provider", ["predibase"])
+def test_exception_mapping(provider):
+    """
+    For predibase, run through a set of mock exceptions
+
+    assert that they are being mapped correctly
+    """
+    litellm.set_verbose = True
+    error_map = {
+        400: litellm.BadRequestError,
+        401: litellm.AuthenticationError,
+        404: litellm.NotFoundError,
+        408: litellm.Timeout,
+        429: litellm.RateLimitError,
+        500: litellm.InternalServerError,
+        503: litellm.ServiceUnavailableError,
+    }
+
+    for code, expected_exception in error_map.items():
+        mock_response = Exception()
+        setattr(mock_response, "text", "This is an error message")
+        setattr(mock_response, "llm_provider", provider)
+        setattr(mock_response, "status_code", code)
+
+        response: Any = None
+        try:
+            response = completion(
+                model="{}/test-model".format(provider),
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                mock_response=mock_response,
+            )
+        except expected_exception:
+            continue
+        except Exception as e:
+            response = "{}\n{}".format(str(e), traceback.format_exc())
+            pytest.fail(
+                "Did not raise expected exception. Expected={}, Return={},".format(
+                    expected_exception, response
+                )
+            )

+    pass
@@ -8978,6 +8978,75 @@ def exception_type(
                     response=original_exception.response,
                     litellm_debug_info=extra_information,
                 )
+            elif hasattr(original_exception, "status_code"):
+                if original_exception.status_code == 500:
+                    exception_mapping_worked = True
+                    raise litellm.InternalServerError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        llm_provider="predibase",
+                        model=model,
+                    )
+                elif original_exception.status_code == 401:
+                    exception_mapping_worked = True
+                    raise AuthenticationError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        llm_provider="predibase",
+                        model=model,
+                    )
+                elif original_exception.status_code == 400:
+                    exception_mapping_worked = True
+                    raise BadRequestError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        llm_provider="predibase",
+                        model=model,
+                    )
+                elif original_exception.status_code == 404:
+                    exception_mapping_worked = True
+                    raise NotFoundError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        llm_provider="predibase",
+                        model=model,
+                    )
+                elif original_exception.status_code == 408:
+                    exception_mapping_worked = True
+                    raise Timeout(
+                        message=f"PredibaseException - {original_exception.message}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                        litellm_debug_info=extra_information,
+                    )
+                elif original_exception.status_code == 422:
+                    exception_mapping_worked = True
+                    raise BadRequestError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                        litellm_debug_info=extra_information,
+                    )
+                elif original_exception.status_code == 429:
+                    exception_mapping_worked = True
+                    raise RateLimitError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                        litellm_debug_info=extra_information,
+                    )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"PredibaseException - {original_exception.message}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                        litellm_debug_info=extra_information,
+                    )
+                elif original_exception.status_code == 504:  # gateway timeout error
+                    exception_mapping_worked = True
+                    raise Timeout(
+                        message=f"PredibaseException - {original_exception.message}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                        litellm_debug_info=extra_information,
+                    )
         elif custom_llm_provider == "bedrock":
             if (
                 "too many tokens" in error_str
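For quick reference, the new predibase branch above maps HTTP status codes to litellm's typed exceptions roughly as follows. This is only a summary sketch of the hunk; the dict name is illustrative and no such constant exists in the code:

import litellm

PREDIBASE_STATUS_TO_EXCEPTION = {
    400: litellm.BadRequestError,
    401: litellm.AuthenticationError,
    404: litellm.NotFoundError,
    408: litellm.Timeout,
    422: litellm.BadRequestError,
    429: litellm.RateLimitError,
    500: litellm.InternalServerError,
    503: litellm.ServiceUnavailableError,
    504: litellm.Timeout,  # gateway timeout
}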
@@ -1,3 +1,3 @@
-ignore = ["F405"]
+ignore = ["F405", "E402"]
 extend-select = ["E501"]
 line-length = 120