fix(proxy_server.py): add call hooks pre+post completion and embedding calls

This commit is contained in:
Krrish Dholakia 2023-12-07 20:35:32 -08:00
parent dfba305508
commit f5afc429b3
2 changed files with 59 additions and 4 deletions

View file

@@ -1,4 +1,4 @@
from typing import Optional, List, Any
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib
### DB CONNECTOR ###
@@ -131,4 +131,33 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
# Re-raise the exception with a user-friendly message
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
except Exception as e:
raise e
raise e
### CALL HOOKS ###
class CallHooks:
    """
    Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.

    Covers:
    1. /chat/completions
    2. /embeddings

    The default implementations pass data through unchanged.
    NOTE(review): presumably meant to be subclassed, with the hooks
    overridden by users -- confirm against the proxy server call sites.
    """

    def __init__(self, *args, **kwargs):
        """Create the hooks object with an empty ``call_details`` store.

        Positional and keyword arguments are accepted but not used.
        """
        self.call_details = {}

    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list) -> None:
        """Record the given configuration objects in ``call_details``.

        Args:
            litellm_settings: stored under the ``"litellm_settings"`` key.
            general_settings: stored under the ``"general_settings"`` key.
            model_list: stored under the ``"model_list"`` key.

        The objects are stored by reference, not copied.
        """
        self.call_details["litellm_settings"] = litellm_settings
        self.call_details["general_settings"] = general_settings
        self.call_details["model_list"] = model_list

    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]) -> dict:
        """Hook invoked with the request data before the call is made.

        Args:
            data: the request payload; remembered under ``call_details["data"]``.
            call_type: which kind of call this is; unused by this default
                implementation.

        Returns:
            The same ``data`` object, unmodified.
        """
        # Each call overwrites the previously stored payload.
        self.call_details["data"] = data
        return data

    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None) -> Optional[Any]:
        """Hook invoked after a successful call.

        Args:
            response: the call's response object, if any.
            call_type: which kind of call this was; unused here.
            chunk: unused by this default implementation
                (presumably a streamed piece of the response -- TODO confirm).

        Returns:
            ``response`` unchanged (``None`` when no response was passed).
        """
        return response

    def post_call_failure(self, *args, **kwargs) -> None:
        """Hook invoked after a failed call; the default does nothing."""
        pass