Merge branch 'main' into litellm_webhook_support

Krish Dholakia 2024-05-20 18:41:58 -07:00 committed by GitHub
commit 707cf24472
19 changed files with 832 additions and 90 deletions


@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
get_actual_routes,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
try:
from litellm._version import version
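
For context on the new import: RejectedRequestError is what a pre-call hook raises when it wants to reject a request outright instead of just modifying it, and the handlers added further down catch it. A minimal sketch of such a hook, assuming the CustomLogger-style async_pre_call_hook signature and a RejectedRequestError(message, model, llm_provider, request_data) constructor (both assumptions, not shown in this diff):

```python
# Sketch only, not part of this commit. Assumes CustomLogger's
# async_pre_call_hook signature and the RejectedRequestError fields
# (message, model, llm_provider, request_data) used by the handlers below.
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_logger import CustomLogger


class BlockedWordGuardrail(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        last_message = str((data.get("messages") or [{}])[-1].get("content", ""))
        if "forbidden" in last_message.lower():
            # chat_completion/completion catch this and return e.message to the caller
            raise RejectedRequestError(
                message="Request blocked by guardrail.",
                model=data.get("model", ""),
                llm_provider="",
                request_data=data,
            )
        return data  # unmodified requests continue on to the LLM call
```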
@@ -3649,7 +3650,6 @@ async def chat_completion(
):
global general_settings, user_debug, proxy_logging_obj, llm_model_list
data = {}
check_request_disconnected = None
try:
body = await request.body()
body_str = body.decode()
@@ -3767,8 +3767,8 @@ async def chat_completion(
data["litellm_logging_obj"] = logging_obj
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
### CALL HOOKS ### - modify/reject incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
@@ -3832,9 +3832,6 @@ async def chat_completion(
*tasks
) # run the moderation check in parallel to the actual llm api call
check_request_disconnected = asyncio.create_task(
check_request_disconnection(request, llm_responses)
)
responses = await llm_responses
response = responses[1]
@@ -3886,6 +3883,40 @@ async def chat_completion(
)
return response
except RejectedRequestError as e:
_data = e.request_data
_data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
_chat_response = litellm.ModelResponse()
_chat_response.choices[0].message.content = e.message # type: ignore
if data.get("stream", None) is not None and data["stream"] == True:
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.CustomStreamWrapper(
completion_stream=_iterator,
model=data.get("model", ""),
custom_llm_provider="cached_response",
logging_obj=data.get("litellm_logging_obj", None),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=_data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
)
_usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
_chat_response.usage = _usage # type: ignore
return _chat_response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
traceback.print_exc()
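
The net effect of the new except RejectedRequestError branch in chat_completion: the caller receives an ordinary completion whose assistant content is the rejection message and whose usage is zeroed, or an SSE stream built via ModelResponseIterator when stream=True. A hedged sketch of the client side (proxy URL, key, and model name are placeholders; assumes a rejecting hook like the one sketched earlier is installed):

```python
# Sketch only: exercising the rejection path from a client.
# Assumes a LiteLLM proxy on localhost:4000 with a rejecting pre-call hook.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this contains forbidden content"}],
)
# Per the handler above, this is the RejectedRequestError message rather than
# model output, and resp.usage reports 0 prompt/completion/total tokens.
print(resp.choices[0].message.content)
```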
@@ -3916,9 +3947,6 @@ async def chat_completion(
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
finally:
if check_request_disconnected is not None:
check_request_disconnected.cancel()
@router.post(
@@ -3945,7 +3973,6 @@ async def completion(
):
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
data = {}
check_request_disconnected = None
try:
body = await request.body()
body_str = body.decode()
@@ -4004,8 +4031,8 @@ async def completion(
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
@@ -4045,9 +4072,6 @@ async def completion(
+ data.get("model", "")
},
)
check_request_disconnected = asyncio.create_task(
check_request_disconnection(request, llm_response)
)
# Await the llm_response task
response = await llm_response
@@ -4091,6 +4115,46 @@ async def completion(
)
return response
except RejectedRequestError as e:
_data = e.request_data
_data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
if _data.get("stream", None) is not None and _data["stream"] == True:
_chat_response = litellm.ModelResponse()
_usage = litellm.Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
_chat_response.usage = _usage # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.TextCompletionStreamWrapper(
completion_stream=_iterator,
model=_data.get("model", ""),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers={},
)
else:
_response = litellm.TextCompletionResponse()
_response.choices[0].text = e.message
return _response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
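
The completion (text completion) endpoint gets the same treatment, except the rejection message is written into choices[0].text of a TextCompletionResponse (or wrapped in a TextCompletionStreamWrapper when streaming). A sketch under the same placeholder assumptions as the chat example above:

```python
# Sketch only: same local-proxy assumptions as the chat_completion example.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

resp = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="this contains forbidden content",
)
# Non-streaming path: the handler fills choices[0].text with e.message.
print(resp.choices[0].text)
```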
@@ -4112,9 +4176,6 @@ async def completion(
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
finally:
if check_request_disconnected is not None:
check_request_disconnected.cancel()
@router.post(
@@ -7761,6 +7822,12 @@ async def team_info(
team_info = await prisma_client.get_data(
team_id=team_id, table_name="team", query_type="find_unique"
)
if team_info is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": f"Team not found, passed team id: {team_id}."},
)
## GET ALL KEYS ##
keys = await prisma_client.get_data(
team_id=team_id,
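
The /team/info change is small but useful: an unknown team_id now yields an explicit 404 instead of team_info silently coming back as None. A quick hedged check (base URL and key are placeholders, and the usual GET /team/info?team_id=... route is assumed):

```python
# Sketch only: confirming the new 404 behaviour of /team/info.
import httpx

r = httpx.get(
    "http://localhost:4000/team/info",
    params={"team_id": "does-not-exist"},
    headers={"Authorization": "Bearer sk-1234"},
)
print(r.status_code)  # 404 after this change; the detail carries the "Team not found" message
```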
@@ -8993,9 +9060,25 @@ async def google_login(request: Request):
PROXY_BASE_URL should be your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
Example:
"""
global premium_user
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
####### Check if user is a Enterprise / Premium User #######
if (
microsoft_client_id is not None
or google_client_id is not None
or generic_client_id is not None
):
if premium_user != True:
raise ProxyException(
message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
type="auth_error",
param="premium_user",
code=status.HTTP_403_FORBIDDEN,
)
# get url from request
redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
ui_username = os.getenv("UI_USERNAME")
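
Finally, the SSO gate: configuring any of MICROSOFT_CLIENT_ID, GOOGLE_CLIENT_ID, or GENERIC_CLIENT_ID now requires an enterprise license, otherwise the SSO login route raises a 403 ProxyException. In environment terms (values are placeholders):

```python
# Sketch only: the env combination this change now enforces.
import os

os.environ["GOOGLE_CLIENT_ID"] = "<google-oauth-client-id>"  # turns SSO on
os.environ["LITELLM_LICENSE"] = "<enterprise-license-key>"   # now required alongside SSO;
# without it (premium_user != True) the SSO login route returns HTTP 403.
```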