fix(proxy_server.py): add call hooks pre+post completion and embedding calls
commit f5afc429b3
parent dfba305508

2 changed files with 59 additions and 4 deletions
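At a glance, the change brackets each model call with the new hooks: CallHooks.pre_call runs on the parsed request body before routing, and CallHooks.post_call_success runs on the result before it is returned to the client (once per chunk on the streaming path). A simplified sketch of the resulting /chat/completions flow, using only names introduced in the diff below (routing, auth, and error handling omitted; the direct litellm.acompletion call stands in for the router logic):

    import litellm
    from litellm.proxy.utils import CallHooks  # added by this commit

    call_hooks = CallHooks()

    async def chat_completion_flow(data: dict):
        # hook may rewrite the incoming request body before the model is called
        data = call_hooks.pre_call(data=data, call_type="completion")
        # stand-in for the router / litellm.acompletion call in proxy_server.py
        response = await litellm.acompletion(**data)
        # hook may rewrite the outgoing response before it is returned
        response = call_hooks.post_call_success(response=response, call_type="completion")
        return response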
litellm/proxy/proxy_server.py

@@ -93,7 +93,8 @@ def generate_feedback_box():
 import litellm
 from litellm.proxy.utils import (
     PrismaClient,
-    get_instance_fn
+    get_instance_fn,
+    CallHooks
 )
 import pydantic
 from litellm.proxy._types import *
@@ -196,6 +197,7 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
+call_hooks = CallHooks()
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -570,6 +572,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             litellm_model_name = model["litellm_params"]["model"]
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
+
+    call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
     return router, model_list, general_settings

 async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
@@ -740,9 +744,13 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"

 async def async_data_generator(response):
+    global call_hooks
+
     print_verbose("inside generator")
     async for chunk in response:
         print_verbose(f"returned chunk: {chunk}")
+        ### CALL HOOKS ### - modify outgoing response
+        response = call_hooks.post_call_success(chunk=chunk, call_type="completion")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
         except:
@@ -941,7 +949,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global general_settings, user_debug
+    global general_settings, user_debug, call_hooks
     try:
         data = {}
         data = await request.json() # type: ignore
@@ -977,6 +985,11 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["max_tokens"] = user_max_tokens
         if user_api_base:
             data["api_base"] = user_api_base
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = call_hooks.pre_call(data=data, call_type="completion")
+
+        ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
             response = await llm_router.acompletion(**data)
@@ -986,8 +999,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             response = await llm_router.acompletion(**data)
         else: # router is not set
             response = await litellm.acompletion(**data)
+
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
+
+        ### CALL HOOKS ### - modify outgoing response
+        response = call_hooks.post_call_success(response=response, call_type="completion")
+
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
@@ -1022,6 +1040,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
     try:
+        global call_hooks

         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -1056,6 +1075,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                     input_list.append(litellm.decode(model="gpt-3.5-turbo", tokens=i))
                 data["input"] = input_list
                 break

+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = call_hooks.pre_call(data=data, call_type="embeddings")
+
         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1065,6 +1087,10 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
         else:
             response = await litellm.aembedding(**data)
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+
+        ### CALL HOOKS ### - modify outgoing response
+        data = call_hooks.post_call_success(response=response, call_type="embeddings")
+
         return response
     except Exception as e:
         traceback.print_exc()
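One detail worth noting before the second file: the two code paths above invoke the success hook with different arguments. The streaming generator calls post_call_success(chunk=..., call_type="completion") once per streamed chunk, while the non-streaming completion and embeddings branches call it once with response=.... A hook implementation therefore has to check which argument is set; a minimal hypothetical override (not part of this commit) might look like:

    from typing import Any, Literal, Optional
    from litellm.proxy.utils import CallHooks

    class ChunkAwareHooks(CallHooks):
        def post_call_success(self, response: Optional[Any] = None,
                              call_type: Optional[Literal["completion", "embeddings"]] = None,
                              chunk: Optional[Any] = None):
            if chunk is not None:
                # streaming path: called from async_data_generator for every chunk
                print(f"{call_type} chunk received")
            if response is not None:
                # non-streaming path: called once with the full response object
                print(f"{call_type} response received")
            return response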
litellm/proxy/utils.py

@@ -1,4 +1,4 @@
-from typing import Optional, List, Any
+from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib

 ### DB CONNECTOR ###
@@ -131,4 +131,33 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
         # Re-raise the exception with a user-friendly message
         raise ImportError(f"Could not import {instance_name} from {module_name}") from e
     except Exception as e:
         raise e
+
+### CALL HOOKS ###
+
+class CallHooks:
+    """
+    Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.
+
+    Covers:
+    1. /chat/completions
+    2. /embeddings
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.call_details = {}
+
+    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
+        self.call_details["litellm_settings"] = litellm_settings
+        self.call_details["general_settings"] = general_settings
+        self.call_details["model_list"] = model_list
+
+    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
+        self.call_details["data"] = data
+        return data
+
+    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
+        return response
+
+    def post_call_failure(self, *args, **kwargs):
+        pass
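As written, CallHooks is a pass-through: pre_call records and returns the request body unchanged, post_call_success returns the response unchanged, and post_call_failure does nothing, so this commit adds the hook surface rather than any behavior. A hypothetical subclass (illustrative only; the commit instantiates the base CallHooks() directly in proxy_server.py, so plugging in a subclass would need extra wiring) that caps max_tokens on completion requests:

    from typing import Literal
    from litellm.proxy.utils import CallHooks

    class CappedCallHooks(CallHooks):
        def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
            data = super().pre_call(data=data, call_type=call_type)  # keep the base bookkeeping
            if call_type == "completion":
                # illustrative policy: never forward more than 1024 max_tokens
                data["max_tokens"] = min(data.get("max_tokens", 1024), 1024)
            return data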