fix(proxy_server): improve error handling

Krrish Dholakia 2023-10-16 19:42:53 -07:00
parent d5c33657d2
commit 541a8b7bc8
5 changed files with 166 additions and 55 deletions


@@ -23,6 +23,10 @@ except ImportError:
 import appdirs
 import tomli_w
+try:
+    from .llm import litellm_completion
+except ImportError as e:
+    from llm import litellm_completion
 import random
 list_of_messages = [
@@ -305,14 +309,6 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
     return url
-# for streaming
-def data_generator(response):
-    print_verbose("inside generator")
-    for chunk in response:
-        print_verbose(f"returned chunk: {chunk}")
-        yield f"data: {json.dumps(chunk)}\n\n"
 def track_cost_callback(
     kwargs,                # kwargs to completion
     completion_response,   # response from completion
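
The data_generator removed above wraps each streamed chunk as a server-sent-event line ("data: <json>\n\n"). For reference, a minimal client-side sketch for consuming that stream, assuming the proxy listens on http://localhost:8000 (hypothetical host/port) and the /chat/completions route shown further down in this diff:

    # Client-side sketch (not part of this commit): consume the SSE stream that
    # data_generator emits. Host/port and model name below are assumptions --
    # point them at whatever your proxy is actually serving.
    import json

    import requests

    resp = requests.post(
        "http://localhost:8000/chat/completions",  # route shown further down in this diff
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,  # this flag is what sends the proxy down the data_generator path
        },
        stream=True,  # keep the HTTP connection open and iterate over chunks
    )

    for line in resp.iter_lines(decode_unicode=True):
        # each event arrives as a line of the form: data: {...json-serialized chunk...}
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk)
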
@@ -433,49 +429,6 @@ litellm.input_callback = [logger]
 litellm.success_callback = [logger]
 litellm.failure_callback = [logger]
-def litellm_completion(data, type):
-    try:
-        if user_model:
-            data["model"] = user_model
-        # override with user settings
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-        if user_headers:
-            data["headers"] = user_headers
-        if type == "completion":
-            response = litellm.text_completion(**data)
-        elif type == "chat_completion":
-            response = litellm.completion(**data)
-        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return StreamingResponse(data_generator(response), media_type='text/event-stream')
-        print_verbose(f"response: {response}")
-        return response
-    except Exception as e:
-        traceback.print_exc()
-        if "Invalid response object from API" in str(e):
-            completion_call_details = {}
-            if user_model:
-                completion_call_details["model"] = user_model
-            else:
-                completion_call_details["model"] = data['model']
-            if user_api_base:
-                completion_call_details["api_base"] = user_api_base
-            else:
-                completion_call_details["api_base"] = None
-            print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{completion_call_details['model']}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
-            if completion_call_details["api_base"] == "http://localhost:11434":
-                print()
-                print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
-                print()
-        else:
-            print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
-        return {"message": "An error occurred"}, 500
 #### API ENDPOINTS ####
 @router.get("/models") # if project requires model list
 def model_list():
@@ -494,12 +447,12 @@ def model_list():
 @router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
-    return litellm_completion(data=data, type="completion")
+    return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
 @router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
-    response = litellm_completion(data, type="chat_completion")
+    response = litellm_completion(data, type="chat_completion", user_model=user_model, user_temperature=user_temperature, user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
     return response
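
The litellm_completion removed above is now imported from llm.py (see the try/except added at the top of this diff), and the two routes pass the user_* settings in explicitly instead of relying on module globals. llm.py's contents are not part of this diff; a rough sketch of what the relocated helper might look like, assuming it accepts the same keyword arguments the routes now pass and keeps the streaming branch from the removed code:

    # Hypothetical sketch only -- not the actual llm.py from this commit.
    # Assumes the overrides arrive as keyword arguments rather than via the
    # user_* globals the removed proxy_server.py version read.
    import json

    import litellm
    from fastapi.responses import StreamingResponse


    def data_generator(response):
        # re-emit each streamed chunk as a server-sent-event line
        for chunk in response:
            yield f"data: {json.dumps(chunk)}\n\n"


    def litellm_completion(data, type, user_model=None, user_temperature=None,
                           user_max_tokens=None, user_api_base=None,
                           user_headers=None, user_debug=False):
        # user_debug is accepted to mirror the call sites above; unused in this sketch.
        # Apply user-level overrides on top of the request payload.
        if user_model:
            data["model"] = user_model
        if user_temperature:
            data["temperature"] = user_temperature
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base
        if user_headers:
            data["headers"] = user_headers

        if type == "completion":
            response = litellm.text_completion(**data)
        else:  # "chat_completion"
            response = litellm.completion(**data)

        if data.get("stream") is True:
            # stream chunks back to the client as server-sent events
            return StreamingResponse(data_generator(response), media_type="text/event-stream")
        return response

Passing the overrides explicitly keeps llm.py decoupled from proxy_server's globals, and the improved error handling the commit title refers to presumably lives in that module now as well.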