fix(proxy_server): return better error messages for invalid api errors

Krrish Dholakia 2023-10-09 14:56:05 -07:00
parent 262f874621
commit 42e0d7cf68
3 changed files with 67 additions and 72 deletions
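
In short: /completions and /chat/completions previously duplicated roughly seventy lines of per-request setup; this commit folds that into a single litellm_completion helper, which also catches "Invalid response object from API" errors and logs the offending model, LLM provider, and custom api_base instead of failing opaquely. A minimal sketch of the new failure path (the model and api_base values below are illustrative, not part of the commit):

    # Assumes the proxy was started with a custom api_base pointing at a
    # local ollama server that is not actually running.
    result = litellm_completion(
        {"model": "ollama/llama2", "messages": [{"role": "user", "content": "hi"}]},
        type="chat_completion",
    )
    # Rather than surfacing a raw exception, the helper prints the call
    # details (model, provider, api_base), adds an ollama-specific hint:
    #   Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`
    # and returns ({"message": "An error occurred"}, 500) to the endpoint.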


@@ -118,6 +118,66 @@ def data_generator(response):
         print_verbose(f"returned chunk: {chunk}")
         yield f"data: {json.dumps(chunk)}\n\n"
+
+def litellm_completion(data, type):
+    try:
+        if user_model:
+            data["model"] = user_model
+        # override with user settings
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        ## CUSTOM PROMPT TEMPLATE ## - run `litellm --config` to set this
+        litellm.register_prompt_template(
+            model=user_model,
+            roles={
+                "system": {
+                    "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
+                },
+                "assistant": {
+                    "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
+                },
+                "user": {
+                    "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
+                    "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
+                }
+            },
+            initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
+            final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
+        )
+        if type == "completion":
+            response = litellm.text_completion(**data)
+        elif type == "chat_completion":
+            response = litellm.completion(**data)
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(data_generator(response), media_type='text/event-stream')
+        print_verbose(f"response: {response}")
+        return response
+    except Exception as e:
+        if "Invalid response object from API" in str(e):
+            completion_call_details = {}
+            if user_model:
+                completion_call_details["model"] = user_model
+            else:
+                completion_call_details["model"] = data['model']
+            if user_api_base:
+                completion_call_details["api_base"] = user_api_base
+            else:
+                completion_call_details["api_base"] = None
+            print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{completion_call_details['model']}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m")
+            if completion_call_details["api_base"] == "http://localhost:11434":
+                print()
+                print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
+                print()
+        else:
+            print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
+        return {"message": "An error occurred"}, 500
 #### API ENDPOINTS ####
 @router.get("/models") # if project requires model list
 def model_list():
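
Note that the new helper reads the custom prompt template entirely from environment variables on every call. A hedged sketch of how those variables might be set for a Llama-2-style chat template (the token strings are illustrative; litellm only requires that they match whatever your model expects):

    import os

    # Illustrative template tokens; substitute your model's own markers.
    os.environ["MODEL_SYSTEM_MESSAGE_START_TOKEN"] = "<<SYS>>\n"
    os.environ["MODEL_SYSTEM_MESSAGE_END_TOKEN"] = "\n<</SYS>>\n\n"
    os.environ["MODEL_USER_MESSAGE_START_TOKEN"] = "[INST] "
    os.environ["MODEL_USER_MESSAGE_END_TOKEN"] = " [/INST]"
    os.environ["MODEL_PRE_PROMPT"] = "<s>"
    os.environ["MODEL_POST_PROMPT"] = ""
    # litellm_completion() picks these up via os.getenv(...), so they must
    # be set in the proxy's environment before requests arrive.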
@@ -136,82 +196,15 @@ def model_list():
 @router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
-    print_verbose(f"data passed in: {data}")
-    if user_model:
-        data["model"] = user_model
-    if user_api_base:
-        data["api_base"] = user_api_base
-    # override with user settings
-    if user_temperature:
-        data["temperature"] = user_temperature
-    if user_max_tokens:
-        data["max_tokens"] = user_max_tokens
-    ## check for custom prompt template ##
-    litellm.register_prompt_template(
-        model=user_model,
-        roles={
-            "system": {
-                "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-            },
-            "assistant": {
-                "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
-            },
-            "user": {
-                "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
-            }
-        },
-        initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
-        final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
-    )
-    response = litellm.text_completion(**data)
-    if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
+    return litellm_completion(data=data, type="completion")
@router.post("/chat/completions") @router.post("/chat/completions")
async def chat_completion(request: Request): async def chat_completion(request: Request):
data = await request.json() data = await request.json()
print_verbose(f"data passed in: {data}") print_verbose(f"data passed in: {data}")
if user_model: response = litellm_completion(data, type="chat_completion")
data["model"] = user_model
# override with user settings
if user_temperature:
data["temperature"] = user_temperature
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
## check for custom prompt template ##
litellm.register_prompt_template(
model=user_model,
roles={
"system": {
"pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
"post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
},
"assistant": {
"pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
"post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
},
"user": {
"pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
"post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
}
},
initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
)
response = litellm.completion(**data)
     # track cost of this response, using litellm.completion_cost
-    await track_cost(response)
-    if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    print_verbose(f"response: {response}")
+    track_cost(response)
     return response

 async def track_cost(response):
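
Both endpoints now delegate to the shared helper, including the stream=True path, which returns a StreamingResponse fed by data_generator's SSE frames. A hedged client-side sketch (host and port are assumed proxy defaults, not specified by this commit):

    import json
    import requests

    with requests.post(
        "http://0.0.0.0:8000/chat/completions",  # assumed address of the running proxy
        json={"model": "gpt-3.5-turbo",  # overridden if the proxy pins user_model
              "messages": [{"role": "user", "content": "hello"}],
              "stream": True},
        stream=True,
    ) as resp:
        for line in resp.iter_lines():
            # data_generator yields frames of the form "data: {json chunk}\n\n"
            if line.startswith(b"data: "):
                print(json.loads(line[len(b"data: "):]))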


@@ -54,17 +54,19 @@ def test_completion_invalid_param_cohere():
     else:
         pytest.fail(f'An error occurred {e}')
-test_completion_invalid_param_cohere()
+# test_completion_invalid_param_cohere()s
 def test_completion_function_call_cohere():
     try:
-        response = completion(model="command-nightly", messages=messages, function_call="TEST-FUNCTION")
+        response = completion(model="command-nightly", messages=messages, functions=["TEST-FUNCTION"])
     except Exception as e:
         if "Function calling is not supported by this provider" in str(e):
             pass
         else:
             pytest.fail(f'An error occurred {e}')
+test_completion_function_call_cohere()
 def test_completion_function_call_openai():
     try:
         messages = [{"role": "user", "content": "What is the weather like in Boston?"}]