fix(proxy_server): improve error handling

Krrish Dholakia 2023-10-16 19:42:53 -07:00
parent d5c33657d2
commit 541a8b7bc8
5 changed files with 166 additions and 55 deletions

litellm/proxy/llm.py (new file, +141 lines)

@@ -0,0 +1,141 @@
from typing import Dict, Optional
from collections import defaultdict
import threading
import os, subprocess, traceback, json
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
import backoff
import openai.error

import litellm
import litellm.exceptions

cost_dict: Dict[str, Dict[str, float]] = defaultdict(dict)
cost_dict_lock = threading.Lock()

debug = False

##### HELPER FUNCTIONS #####
def print_verbose(print_statement):
    global debug
    if debug:
        print(print_statement)


# for streaming
def data_generator(response):
    print_verbose("inside generator")
    for chunk in response:
        print_verbose(f"returned chunk: {chunk}")
        yield f"data: {json.dumps(chunk)}\n\n"


def run_ollama_serve():
    command = ['ollama', 'serve']

    with open(os.devnull, 'w') as devnull:
        process = subprocess.Popen(command, stdout=devnull, stderr=devnull)


##### ERROR HANDLING #####
class RetryConstantError(Exception):
    pass


class RetryExpoError(Exception):
    pass


class UnknownLLMError(Exception):
    pass


def handle_llm_exception(e: Exception, user_api_base: Optional[str] = None):
    print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
    if isinstance(e, openai.error.ServiceUnavailableError) and e.llm_provider == "ollama":
        run_ollama_serve()
    if isinstance(e, openai.error.InvalidRequestError) and e.llm_provider == "ollama":
        completion_call_details = {}
        completion_call_details["model"] = e.model
        if user_api_base:
            completion_call_details["api_base"] = user_api_base
        else:
            completion_call_details["api_base"] = None
        print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{e.model}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
        if completion_call_details["api_base"] == "http://localhost:11434":
            print()
            print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
            print()
    if isinstance(
        e,
        (
            openai.error.APIError,
            openai.error.TryAgain,
            openai.error.Timeout,
            openai.error.ServiceUnavailableError,
        ),
    ):
        raise RetryConstantError from e
    elif isinstance(e, openai.error.RateLimitError):
        raise RetryExpoError from e
    elif isinstance(
        e,
        (
            openai.error.APIConnectionError,
            openai.error.InvalidRequestError,
            openai.error.AuthenticationError,
            openai.error.PermissionError,
            openai.error.InvalidAPIType,
            openai.error.SignatureVerificationError,
        ),
    ):
        raise e
    else:
        raise UnknownLLMError from e


@backoff.on_exception(
    wait_gen=backoff.constant,
    exception=RetryConstantError,
    max_tries=3,
    interval=3,
)
@backoff.on_exception(
    wait_gen=backoff.expo,
    exception=RetryExpoError,
    jitter=backoff.full_jitter,
    max_value=100,
    factor=1.5,
)
def litellm_completion(data: Dict,
                       type: str,
                       user_model: Optional[str],
                       user_temperature: Optional[str],
                       user_max_tokens: Optional[int],
                       user_api_base: Optional[str],
                       user_headers: Optional[dict],
                       user_debug: bool) -> litellm.ModelResponse:
    try:
        global debug
        debug = user_debug
        if user_model:
            data["model"] = user_model
        # override with user settings
        if user_temperature:
            data["temperature"] = user_temperature
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base
        if user_headers:
            data["headers"] = user_headers
        if type == "completion":
            response = litellm.text_completion(**data)
        elif type == "chat_completion":
            response = litellm.completion(**data)
        if 'stream' in data and data['stream'] == True:  # use generate_responses to stream responses
            return StreamingResponse(data_generator(response), media_type='text/event-stream')
        print_verbose(f"response: {response}")
        return response
    except Exception as e:
        print(e)
        handle_llm_exception(e=e, user_api_base=user_api_base)
        return {"message": "An error occurred"}, 500

litellm/proxy/proxy_server.py

@@ -23,6 +23,10 @@ except ImportError:
    import appdirs
    import tomli_w

+try:
+    from .llm import litellm_completion
+except ImportError as e:
+    from llm import litellm_completion

import random

list_of_messages = [
@@ -305,14 +309,6 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
    return url

-# for streaming
-def data_generator(response):
-    print_verbose("inside generator")
-    for chunk in response:
-        print_verbose(f"returned chunk: {chunk}")
-        yield f"data: {json.dumps(chunk)}\n\n"

def track_cost_callback(
    kwargs,                 # kwargs to completion
    completion_response,    # response from completion
@@ -433,49 +429,6 @@ litellm.input_callback = [logger]
litellm.success_callback = [logger]
litellm.failure_callback = [logger]

-def litellm_completion(data, type):
-    try:
-        if user_model:
-            data["model"] = user_model
-        # override with user settings
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-        if user_headers:
-            data["headers"] = user_headers
-        if type == "completion":
-            response = litellm.text_completion(**data)
-        elif type == "chat_completion":
-            response = litellm.completion(**data)
-        if 'stream' in data and data['stream'] == True:  # use generate_responses to stream responses
-            return StreamingResponse(data_generator(response), media_type='text/event-stream')
-        print_verbose(f"response: {response}")
-        return response
-    except Exception as e:
-        traceback.print_exc()
-        if "Invalid response object from API" in str(e):
-            completion_call_details = {}
-            if user_model:
-                completion_call_details["model"] = user_model
-            else:
-                completion_call_details["model"] = data['model']
-            if user_api_base:
-                completion_call_details["api_base"] = user_api_base
-            else:
-                completion_call_details["api_base"] = None
-            print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{completion_call_details['model']}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
-            if completion_call_details["api_base"] == "http://localhost:11434":
-                print()
-                print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
-                print()
-        else:
-            print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
-        return {"message": "An error occurred"}, 500

#### API ENDPOINTS ####
@router.get("/models")  # if project requires model list
def model_list():
@@ -494,12 +447,12 @@ def model_list():
@router.post("/completions")
async def completion(request: Request):
    data = await request.json()
-    return litellm_completion(data=data, type="completion")
+    return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)

@router.post("/chat/completions")
async def chat_completion(request: Request):
    data = await request.json()
-    response = litellm_completion(data, type="chat_completion")
+    response = litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
    return response
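
For context, a hypothetical client call against these routes (not part of the commit): the payload, base URL, and port below are assumptions about a locally running proxy started with something like `litellm --model ollama/llama2 --api_base http://localhost:11434`.

# Hypothetical client sketch: exercising /chat/completions on a local proxy.
# The URL, port, and request body are assumptions for illustration only.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",  # assumed local proxy address
    json={
        "model": "ollama/llama2",
        "messages": [{"role": "user", "content": "what is litellm?"}],
        "max_tokens": 256,
    },
    timeout=60,
)
print(resp.status_code)
print(resp.json())  # provider failures now surface as the error payload
                    # produced by litellm_completion's except branch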

litellm/utils.py

@@ -3092,14 +3092,31 @@ def exception_type(
                raise original_exception
            raise original_exception
        elif custom_llm_provider == "ollama":
-            error_str = original_exception.get("error", "")
+            if isinstance(original_exception, dict):
+                error_str = original_exception.get("error", "")
+            else:
+                error_str = str(original_exception)
            if "no such file or directory" in error_str:
                exception_mapping_worked = True
                raise InvalidRequestError(
-                    message=f"Ollama Exception Invalid Model/Model not loaded - {original_exception}",
+                    message=f"OllamaException: Invalid Model/Model not loaded - {original_exception}",
                    model=model,
                    llm_provider="ollama"
                )
+            elif "Failed to establish a new connection" in error_str:
+                exception_mapping_worked = True
+                raise ServiceUnavailableError(
+                    message=f"OllamaException: {original_exception}",
+                    llm_provider="ollama",
+                    model=model
+                )
+            elif "Invalid response object from API" in error_str:
+                exception_mapping_worked = True
+                raise InvalidRequestError(
+                    message=f"OllamaException: {original_exception}",
+                    llm_provider="ollama",
+                    model=model
+                )
        elif custom_llm_provider == "vllm":
            if hasattr(original_exception, "status_code"):
                if original_exception.status_code == 0:
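
To make the new Ollama branch easier to follow, here is a standalone sketch of the same string-matching logic. It is an illustration, not litellm's actual exception_type; the exception classes below are local stand-ins for litellm's.

# Simplified, self-contained rendering of the Ollama error mapping above.
class InvalidRequestError(Exception):      # stand-in for litellm's exception
    pass

class ServiceUnavailableError(Exception):  # stand-in for litellm's exception
    pass

def map_ollama_error(original_exception) -> Exception:
    # The str() fallback is what this commit adds: previously the code assumed
    # original_exception was a dict and called .get() on it unconditionally.
    if isinstance(original_exception, dict):
        error_str = original_exception.get("error", "")
    else:
        error_str = str(original_exception)

    if "no such file or directory" in error_str:
        return InvalidRequestError(f"OllamaException: Invalid Model/Model not loaded - {original_exception}")
    elif "Failed to establish a new connection" in error_str:
        return ServiceUnavailableError(f"OllamaException: {original_exception}")
    elif "Invalid response object from API" in error_str:
        return InvalidRequestError(f"OllamaException: {original_exception}")
    return Exception(str(original_exception))

# A connection failure maps to ServiceUnavailableError, which the proxy's
# handle_llm_exception treats as the cue to run `ollama serve` and retry.
print(type(map_ollama_error(ConnectionError("Failed to establish a new connection"))).__name__)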