Merge branch 'main' into litellm_webhook_support
commit 707cf24472
19 changed files with 832 additions and 90 deletions

@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
     get_actual_routes,
 )
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
+from litellm.exceptions import RejectedRequestError
 
 try:
     from litellm._version import version
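
Note on the new import: `RejectedRequestError` is what lets a pre-call hook veto a request; the hunks further down catch it in both `chat_completion` and `completion` and turn it into a mock response instead of a 500. A minimal sketch of a rejecting hook, assuming the usual `CustomLogger.async_pre_call_hook` signature and a `RejectedRequestError(message, model, llm_provider, request_data)` constructor (neither is shown in this diff, so treat the details as illustrative):

```python
# illustrative guardrail, not part of this commit
from litellm.integrations.custom_logger import CustomLogger
from litellm.exceptions import RejectedRequestError


class BannedWordGuardrail(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        # Inspect the incoming request before it reaches any model.
        last_message = str(data.get("messages", [{}])[-1].get("content", ""))
        if "forbidden" in last_message.lower():
            # The proxy reads e.message and e.request_data in the except blocks below.
            raise RejectedRequestError(
                message="Sorry, this request was blocked by policy.",
                model=data.get("model", ""),
                llm_provider="",
                request_data=data,
            )
        return data  # returning the (possibly modified) data lets the call proceed
```
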
@@ -3649,7 +3650,6 @@ async def chat_completion(
 ):
     global general_settings, user_debug, proxy_logging_obj, llm_model_list
     data = {}
-    check_request_disconnected = None
     try:
         body = await request.body()
         body_str = body.decode()
@@ -3767,8 +3767,8 @@ async def chat_completion(
 
         data["litellm_logging_obj"] = logging_obj
 
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
+        ### CALL HOOKS ### - modify/reject incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
 
@@ -3832,9 +3832,6 @@ async def chat_completion(
             *tasks
         )  # run the moderation check in parallel to the actual llm api call
 
-        check_request_disconnected = asyncio.create_task(
-            check_request_disconnection(request, llm_responses)
-        )
         responses = await llm_responses
 
         response = responses[1]
@@ -3886,6 +3883,40 @@ async def chat_completion(
         )
 
         return response
+    except RejectedRequestError as e:
+        _data = e.request_data
+        _data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=_data,
+        )
+        _chat_response = litellm.ModelResponse()
+        _chat_response.choices[0].message.content = e.message  # type: ignore
+
+        if data.get("stream", None) is not None and data["stream"] == True:
+            _iterator = litellm.utils.ModelResponseIterator(
+                model_response=_chat_response, convert_to_delta=True
+            )
+            _streaming_response = litellm.CustomStreamWrapper(
+                completion_stream=_iterator,
+                model=data.get("model", ""),
+                custom_llm_provider="cached_response",
+                logging_obj=data.get("litellm_logging_obj", None),
+            )
+            selected_data_generator = select_data_generator(
+                response=_streaming_response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=_data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+        _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        _chat_response.usage = _usage  # type: ignore
+        return _chat_response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         traceback.print_exc()
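
What a client sees when a hook rejects a `/chat/completions` call: the message from `RejectedRequestError` comes back as a normal assistant reply (streamed through the `cached_response` wrapper when `stream=True`), not as an HTTP error. A hedged sketch with the OpenAI SDK pointed at a running proxy; the base URL, key, and model name are placeholders:

```python
# illustrative client call -- endpoint, key, and model are placeholders
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "please do the forbidden thing"}],
)
# If a pre-call hook raised RejectedRequestError, choices[0].message.content is
# e.message and usage is zeroed, exactly as constructed in the hunk above.
print(resp.choices[0].message.content)
```
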
@@ -3916,9 +3947,6 @@ async def chat_completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
-    finally:
-        if check_request_disconnected is not None:
-            check_request_disconnected.cancel()
 
 
 @router.post(
@@ -3945,7 +3973,6 @@ async def completion(
 ):
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     data = {}
-    check_request_disconnected = None
     try:
         body = await request.body()
         body_str = body.decode()
@@ -4004,8 +4031,8 @@ async def completion(
             data["model"] = litellm.model_alias_map[data["model"]]
 
         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
 
         ### ROUTE THE REQUESTs ###
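
The only functional change here is the `call_type` string: hooks invoked from `/completions` now see `"text_completion"` instead of `"completion"`, so a hook can tell the two routes apart (chat requests carry `messages`, text requests carry `prompt`). A small illustrative helper, assuming those field names:

```python
# hypothetical helper for a hook that branches on call_type
def extract_prompt_text(data: dict, call_type: str) -> str:
    """Return the user-supplied text for the two call types used above."""
    if call_type == "completion":  # /chat/completions
        return str(data.get("messages", ""))
    if call_type == "text_completion":  # /completions
        return str(data.get("prompt", ""))
    return ""
```
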
@@ -4045,9 +4072,6 @@ async def completion(
                     + data.get("model", "")
                 },
             )
-        check_request_disconnected = asyncio.create_task(
-            check_request_disconnection(request, llm_response)
-        )
 
         # Await the llm_response task
         response = await llm_response
@@ -4091,6 +4115,46 @@ async def completion(
         )
 
         return response
+    except RejectedRequestError as e:
+        _data = e.request_data
+        _data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=_data,
+        )
+        if _data.get("stream", None) is not None and _data["stream"] == True:
+            _chat_response = litellm.ModelResponse()
+            _usage = litellm.Usage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            _chat_response.usage = _usage  # type: ignore
+            _chat_response.choices[0].message.content = e.message  # type: ignore
+            _iterator = litellm.utils.ModelResponseIterator(
+                model_response=_chat_response, convert_to_delta=True
+            )
+            _streaming_response = litellm.TextCompletionStreamWrapper(
+                completion_stream=_iterator,
+                model=_data.get("model", ""),
+            )
+
+            selected_data_generator = select_data_generator(
+                response=_streaming_response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+                headers={},
+            )
+        else:
+            _response = litellm.TextCompletionResponse()
+            _response.choices[0].text = e.message
+            return _response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
         await proxy_logging_obj.post_call_failure_hook(
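
For the non-streaming `/completions` path the rejection is packed into a `TextCompletionResponse`, with the hook's message in `choices[0].text` rather than `choices[0].message.content`. A standalone restatement of that shape (it mirrors the hunk above; the message text is made up):

```python
# standalone illustration of the non-streaming rejection shape
import litellm

_response = litellm.TextCompletionResponse()
_response.choices[0].text = "Sorry, this request was blocked by policy."
print(_response)
```
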
@@ -4112,9 +4176,6 @@ async def completion(
             param=getattr(e, "param", "None"),
             code=getattr(e, "status_code", 500),
         )
-    finally:
-        if check_request_disconnected is not None:
-            check_request_disconnected.cancel()
 
 
 @router.post(
@@ -7761,6 +7822,12 @@ async def team_info(
         team_info = await prisma_client.get_data(
             team_id=team_id, table_name="team", query_type="find_unique"
         )
+        if team_info is None:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail={"message": f"Team not found, passed team id: {team_id}."},
+            )
+
         ## GET ALL KEYS ##
         keys = await prisma_client.get_data(
             team_id=team_id,
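
With this change an unknown `team_id` now yields a 404 with a readable message instead of falling through to a later error. A hedged client-side sketch; the `/team/info` route, the auth header, and the FastAPI-style `{"detail": {...}}` body are assumptions, not shown in this diff:

```python
# illustrative client-side handling of the new 404
import httpx

resp = httpx.get(
    "http://localhost:4000/team/info",
    params={"team_id": "my-team-id"},
    headers={"Authorization": "Bearer sk-1234"},
)
if resp.status_code == 404:
    print(resp.json()["detail"]["message"])  # "Team not found, passed team id: ..."
else:
    resp.raise_for_status()
    team = resp.json()
```
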
@@ -8993,9 +9060,25 @@ async def google_login(request: Request):
     PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
     Example:
     """
+    global premium_user
     microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
     google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
     generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
+
+    ####### Check if user is a Enterprise / Premium User #######
+    if (
+        microsoft_client_id is not None
+        or google_client_id is not None
+        or generic_client_id is not None
+    ):
+        if premium_user != True:
+            raise ProxyException(
+                message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this",
+                type="auth_error",
+                param="premium_user",
+                code=status.HTTP_403_FORBIDDEN,
+            )
+
     # get url from request
     redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
     ui_username = os.getenv("UI_USERNAME")
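
The SSO gate above boils down to: if any of the three `*_CLIENT_ID` environment variables is set but the deployment is not a licensed (premium) one, SSO login is refused with a 403. A distilled, illustrative restatement; in the proxy itself `premium_user` is derived from the `LITELLM_LICENSE` check referenced in the error message:

```python
# distilled restatement of the gate added above (illustrative only)
import os


def sso_allowed(premium_user: bool) -> bool:
    """False when SSO client IDs are configured but the proxy is not licensed."""
    sso_configured = any(
        os.getenv(var) is not None
        for var in ("MICROSOFT_CLIENT_ID", "GOOGLE_CLIENT_ID", "GENERIC_CLIENT_ID")
    )
    return (not sso_configured) or premium_user
```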