refactor(proxy_server.py): print statement showing how to add debug for logs

Krrish Dholakia 2023-11-03 17:41:02 -07:00
parent a2b9ffdd61
commit fa24a61976
8 changed files with 76 additions and 207 deletions

View file

@@ -2,7 +2,7 @@ import Image from '@theme/IdealImage';
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
-# 💥 Evaluate LLMs - OpenAI Compatible Server
+# 💥 Evaluate LLMs - OpenAI Proxy Server
 A simple, fast, and lightweight **OpenAI-compatible server** to call 100+ LLM APIs.
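The docs page above advertises an OpenAI-compatible interface, so any OpenAI client can be pointed at the proxy. A minimal sketch, assuming the proxy is running locally on port 8000 and the pre-1.0 `openai` Python SDK that was current at the time of this commit; the model name matches a `model_name` entry from the config file added later in this commit:

import openai  # pre-1.0 SDK (openai.ChatCompletion interface)

openai.api_base = "http://0.0.0.0:8000"  # assumed address of the local proxy
openai.api_key = "anything"              # placeholder; the proxy does not validate it here

response = openai.ChatCompletion.create(
    model="zephyr-alpha",  # a model_name the proxy is configured to serve
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response)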

View file

@@ -47,6 +47,7 @@ client_session: Optional[requests.Session] = None
 model_fallbacks: Optional[List] = None
 model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 num_retries: Optional[int] = None
+suppress_debug_info = False
 #############################################
 def get_model_cost_map(url: str):
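`suppress_debug_info` is the module-level flag that proxy_server.py flips further down in this commit and that `exception_type()` checks before printing its feedback banner. A minimal sketch of setting it directly from library code, assuming nothing beyond what this commit adds:

import litellm

litellm.suppress_debug_info = True  # silence the "Give Feedback / Get Help" banner

try:
    # any failing call will do; this one fails if no OPENAI_API_KEY is configured
    litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}])
except Exception as e:
    # only the exception surfaces; the banner printed by exception_type() is skipped
    print(f"caught: {e}")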

View file

@@ -255,7 +255,7 @@ def completion(
 ## Some servers might return streaming responses even though stream was not set to true. (e.g. Baseten)
 is_streamed = False
-if response.__dict__['headers']["Content-Type"] == "text/event-stream":
+if response.__dict__['headers'].get("Content-Type", "") == "text/event-stream":
     is_streamed = True
 # iterate over the complete streamed response, and return the final answer
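The switch from direct indexing to `dict.get()` means a provider response without a `Content-Type` header no longer raises `KeyError`; the check simply evaluates to False. A standalone illustration, with a plain dict standing in for the response headers:

headers = {}  # e.g. a provider response that sends no Content-Type header

# old form: headers["Content-Type"] would raise KeyError here
is_streamed = headers.get("Content-Type", "") == "text/event-stream"
print(is_streamed)  # False, so the non-streaming path is taken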

View file

@@ -0,0 +1,9 @@
+model_list:
+  - model_name: zephyr-alpha
+    litellm_params: # params for litellm.completion() - https://docs.litellm.ai/docs/completion/input#input---request-body
+      model: huggingface/HuggingFaceH4/zephyr-7b-alpha
+      api_base: http://0.0.0.0:8001
+  - model_name: zephyr-beta
+    litellm_params:
+      model: huggingface/HuggingFaceH4/zephyr-7b-beta
+      api_base: https://<my-hosted-endpoint>
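The new YAML file uses the `model_list` / `litellm_params` shape that `litellm.Router` (referenced as `model_router` elsewhere in this commit) accepts. A hedged sketch of loading it from Python; the file path and the PyYAML dependency are assumptions:

import yaml      # assumes PyYAML is installed
import litellm

with open("config.yaml") as f:           # hypothetical path to the file added above
    config = yaml.safe_load(f)

# Router picks between the zephyr-alpha / zephyr-beta deployments by model_name
model_router = litellm.Router(model_list=config["model_list"])

response = model_router.completion(
    model="zephyr-beta",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response)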

View file

@@ -1,153 +0,0 @@
-from typing import Dict, Optional
-from collections import defaultdict
-import threading
-import os, subprocess, traceback, json
-from fastapi import HTTPException
-from fastapi.responses import StreamingResponse
-import backoff
-import openai.error
-import litellm
-from litellm.utils import trim_messages
-from litellm.exceptions import ServiceUnavailableError, InvalidRequestError
-
-cost_dict: Dict[str, Dict[str, float]] = defaultdict(dict)
-cost_dict_lock = threading.Lock()
-debug = False
-
-##### HELPER FUNCTIONS #####
-def print_verbose(print_statement):
-    global debug
-    if debug:
-        print(print_statement)
-
-# for streaming
-def data_generator(response):
-    print_verbose("inside generator")
-    for chunk in response:
-        print_verbose(f"returned chunk: {chunk}")
-        yield f"data: {json.dumps(chunk)}\n\n"
-
-def run_ollama_serve():
-    command = ['ollama', 'serve']
-    with open(os.devnull, 'w') as devnull:
-        process = subprocess.Popen(command, stdout=devnull, stderr=devnull)
-
-##### ERROR HANDLING #####
-class RetryConstantError(Exception):
-    pass
-
-class RetryExpoError(Exception):
-    pass
-
-class UnknownLLMError(Exception):
-    pass
-
-def handle_llm_exception(e: Exception, user_api_base: Optional[str]=None):
-    print(f"\033[1;31mLiteLLM.Exception: {str(e)}\033[0m")
-    if isinstance(e, ServiceUnavailableError) and e.llm_provider == "ollama": # type: ignore
-        run_ollama_serve()
-    if isinstance(e, InvalidRequestError) and e.llm_provider == "ollama": # type: ignore
-        completion_call_details = {}
-        completion_call_details["model"] = e.model # type: ignore
-        if user_api_base:
-            completion_call_details["api_base"] = user_api_base
-        else:
-            completion_call_details["api_base"] = None
-        print(f"\033[1;31mLiteLLM.Exception: Invalid API Call. Call details: Model: \033[1;37m{e.model}\033[1;31m; LLM Provider: \033[1;37m{e.llm_provider}\033[1;31m; Custom API Base - \033[1;37m{completion_call_details['api_base']}\033[1;31m\033[0m") # type: ignore
-        if completion_call_details["api_base"] == "http://localhost:11434":
-            print()
-            print("Trying to call ollama? Try `litellm --model ollama/llama2 --api_base http://localhost:11434`")
-            print()
-    if isinstance(
-        e,
-        (
-            openai.error.APIError,
-            openai.error.TryAgain,
-            openai.error.Timeout,
-            openai.error.ServiceUnavailableError,
-        ),
-    ):
-        raise RetryConstantError from e
-    elif isinstance(e, openai.error.RateLimitError):
-        raise RetryExpoError from e
-    elif isinstance(
-        e,
-        (
-            openai.error.APIConnectionError,
-            openai.error.InvalidRequestError,
-            openai.error.AuthenticationError,
-            openai.error.PermissionError,
-            openai.error.InvalidAPIType,
-            openai.error.SignatureVerificationError,
-        ),
-    ):
-        raise e
-    else:
-        raise UnknownLLMError from e
-
-@backoff.on_exception(
-    wait_gen=backoff.constant,
-    exception=RetryConstantError,
-    max_tries=3,
-    interval=3,
-)
-@backoff.on_exception(
-    wait_gen=backoff.expo,
-    exception=RetryExpoError,
-    jitter=backoff.full_jitter,
-    max_value=100,
-    factor=1.5,
-)
-def litellm_completion(data: Dict,
-                       type: str,
-                       user_model: Optional[str],
-                       user_temperature: Optional[str],
-                       user_max_tokens: Optional[int],
-                       user_request_timeout: Optional[int],
-                       user_api_base: Optional[str],
-                       user_headers: Optional[dict],
-                       user_debug: bool,
-                       model_router: Optional[litellm.Router]):
-    try:
-        litellm.set_verbose=False
-        global debug
-        debug = user_debug
-        if user_model:
-            data["model"] = user_model
-        # override with user settings
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-        if user_headers:
-            data["headers"] = user_headers
-        if type == "completion":
-            if model_router and data["model"] in model_router.get_model_names():
-                model_router.text_completion(**data)
-            else:
-                response = litellm.text_completion(**data)
-        elif type == "chat_completion":
-            if model_router and data["model"] in model_router.get_model_names():
-                model_router.completion(**data)
-            else:
-                response = litellm.completion(**data)
-        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return StreamingResponse(data_generator(response), media_type='text/event-stream')
-        print_verbose(f"response: {response}")
-        return response
-    except Exception as e:
-        print(e)
-        handle_llm_exception(e=e, user_api_base=user_api_base)
-        return {"message": "An error occurred"}, 500

View file

@@ -87,6 +87,7 @@ print("\033[1;34mDocs: https://docs.litellm.ai/docs/proxy_server\033[0m")
 print()
 import litellm
+litellm.suppress_debug_info = True
 from fastapi import FastAPI, Request
 from fastapi.routing import APIRouter
 from fastapi.encoders import jsonable_encoder
@@ -576,24 +577,29 @@ def model_list():
 @router.post("/completions")
 @router.post("/engines/{model:path}/completions")
 async def completion(request: Request, model: Optional[str] = None):
-    body = await request.body()
-    body_str = body.decode()
-    try:
-        data = ast.literal_eval(body_str)
-    except:
-        data = json.loads(body_str)
-    data["model"] = (
-        server_settings.get("completion_model", None) # server default
-        or user_model # model name passed via cli args
-        or model # for azure deployments
-        or data["model"] # default passed in http request
-    )
-    if user_model:
-        data["model"] = user_model
-    data["call_type"] = "text_completion"
-    return litellm_completion(
-        **data
-    )
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            server_settings.get("completion_model", None) # server default
+            or user_model # model name passed via cli args
+            or model # for azure deployments
+            or data["model"] # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+        data["call_type"] = "text_completion"
+        return litellm_completion(
+            **data
+        )
+    except Exception as e:
+        error_traceback = traceback.format_exc()
+        error_msg = f"{str(e)}\n\n{error_traceback}"
+        return {"error": error_msg}
 @router.post("/v1/chat/completions")
@@ -601,22 +607,28 @@ async def completion(request: Request, model: Optional[str] = None):
 @router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
     global server_settings
-    body = await request.body()
-    body_str = body.decode()
-    try:
-        data = ast.literal_eval(body_str)
-    except:
-        data = json.loads(body_str)
-    data["model"] = (
-        server_settings.get("completion_model", None) # server default
-        or user_model # model name passed via cli args
-        or model # for azure deployments
-        or data["model"] # default passed in http request
-    )
-    data["call_type"] = "chat_completion"
-    return litellm_completion(
-        **data
-    )
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        data["model"] = (
+            server_settings.get("completion_model", None) # server default
+            or user_model # model name passed via cli args
+            or model # for azure deployments
+            or data["model"] # default passed in http request
+        )
+        data["call_type"] = "chat_completion"
+        return litellm_completion(
+            **data
+        )
+    except Exception as e:
+        print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
+        error_traceback = traceback.format_exc()
+        error_msg = f"{str(e)}\n\n{error_traceback}"
+        return {"error": error_msg}
 def print_cost_logs():
     with open("costs.json", "r") as f:

View file

@@ -219,20 +219,19 @@ def test_get_hf_task_for_model():
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-# def hf_test_completion_tgi():
-#     litellm.huggingface_config(return_full_text=True)
-#     litellm.set_verbose=True
-#     try:
-#         response = litellm.completion(
-#             model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
-#             messages=[{ "content": "Hello, how are you?","role": "user"}],
-#             api_base="https://n9ox93a8sv5ihsow.us-east-1.aws.endpoints.huggingface.cloud",
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# hf_test_completion_tgi()
+def hf_test_completion_tgi():
+    litellm.set_verbose=True
+    try:
+        response = litellm.completion(
+            model="huggingface/mistralai/Mistral-7B-Instruct-v0.1",
+            messages=[{ "content": "Hello, how are you?","role": "user"}],
+            api_base="https://3kk3h56912qga4-80.proxy.runpod.net",
+        )
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+hf_test_completion_tgi()
 # def hf_test_completion_tgi_stream():
 #     try:

View file

@@ -2757,10 +2757,11 @@ def exception_type(
 ):
     global user_logger_fn, liteDebuggerClient
     exception_mapping_worked = False
-    print()
-    print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
-    print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.")
-    print()
+    if litellm.suppress_debug_info is False:
+        print()
+        print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
+        print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'.")
+        print()
     try:
         if isinstance(original_exception, OriginalError):
             # Handle the OpenAIError