From 665939bc487a36df21e4ba858892ac6239b78bab Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 7 Dec 2023 20:35:32 -0800
Subject: [PATCH] fix(proxy_server.py): add call hooks pre+post completion and
 embedding calls

---
 litellm/proxy/proxy_server.py | 30 ++++++++++++++++++++++++++++--
 litellm/proxy/utils.py        | 33 +++++++++++++++++++++++++++++++--
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index bdd415f28d..5e50487772 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -93,7 +93,8 @@ def generate_feedback_box():
 import litellm
 from litellm.proxy.utils import (
     PrismaClient,
-    get_instance_fn
+    get_instance_fn,
+    CallHooks
 )
 import pydantic
 from litellm.proxy._types import *
@@ -196,6 +197,7 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
+call_hooks = CallHooks()
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -570,6 +572,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             litellm_model_name = model["litellm_params"]["model"]
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
+
+    call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
     return router, model_list, general_settings

 async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
@@ -740,9 +744,13 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"

 async def async_data_generator(response):
+    global call_hooks
+
     print_verbose("inside generator")
     async for chunk in response:
         print_verbose(f"returned chunk: {chunk}")
+        ### CALL HOOKS ### - modify outgoing response
+        chunk = call_hooks.post_call_success(chunk=chunk, call_type="completion")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
         except:
@@ -941,7 +949,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global general_settings, user_debug
+    global general_settings, user_debug, call_hooks
     try:
         data = {}
         data = await request.json() # type: ignore
@@ -977,6 +985,11 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["max_tokens"] = user_max_tokens
         if user_api_base:
             data["api_base"] = user_api_base
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = call_hooks.pre_call(data=data, call_type="completion")
+
+        ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
             response = await llm_router.acompletion(**data)
@@ -986,8 +999,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             response = await llm_router.acompletion(**data)
         else: # router is not set
             response = await litellm.acompletion(**data)
+
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
+
+        ### CALL HOOKS ### - modify outgoing response
+        response = call_hooks.post_call_success(response=response, call_type="completion")
+
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
@@ -1022,6 +1040,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
     try:
+        global call_hooks
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
         data = orjson.loads(body)
@@ -1056,6 +1075,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                             input_list.append(litellm.decode(model="gpt-3.5-turbo", tokens=i))
                         data["input"] = input_list
                         break
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = call_hooks.pre_call(data=data, call_type="embeddings")

         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1065,6 +1087,10 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
         else:
             response = await litellm.aembedding(**data)
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+
+        ### CALL HOOKS ### - modify outgoing response
+        response = call_hooks.post_call_success(response=response, call_type="embeddings")
+
         return response
     except Exception as e:
         traceback.print_exc()
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 9e4ec5900b..41fffb1137 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Any
+from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib

 ### DB CONNECTOR ###
@@ -131,4 +131,33 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
         # Re-raise the exception with a user-friendly message
         raise ImportError(f"Could not import {instance_name} from {module_name}") from e
     except Exception as e:
-        raise e
\ No newline at end of file
+        raise e
+
+### CALL HOOKS ###
+
+class CallHooks:
+    """
+    Allows users to modify the proxy's incoming request / outgoing response, without having to parse the request body themselves.
+
+    Covers:
+    1. /chat/completions
+    2. /embeddings
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.call_details = {}
+
+    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
+        self.call_details["litellm_settings"] = litellm_settings
+        self.call_details["general_settings"] = general_settings
+        self.call_details["model_list"] = model_list
+
+    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
+        self.call_details["data"] = data
+        return data
+
+    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
+        return response if response is not None else chunk
+
+    def post_call_failure(self, *args, **kwargs):
+        pass
\ No newline at end of file
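
Usage note (illustrative, not part of the patch): the hooks wired in above are pass-through stubs, pre_call hands back the request data and post_call_success hands back the response (or stream chunk) unchanged. A minimal sketch of how they could be specialized, assuming the proxy's call_hooks global is pointed at a subclass; the MyHooks name and the injected "metadata" key are hypothetical:

    from litellm.proxy.utils import CallHooks

    class MyHooks(CallHooks):
        def pre_call(self, data: dict, call_type):
            # tag every request before it is routed to a model (hypothetical key)
            data.setdefault("metadata", {})["proxied"] = True
            return super().pre_call(data=data, call_type=call_type)

        def post_call_success(self, response=None, call_type=None, chunk=None):
            # inspect or rewrite the outgoing response / stream chunk here
            return super().post_call_success(response=response, call_type=call_type, chunk=chunk)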