From 65e00b438ea348560356fee47f6f0b3edcaa1667 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Mon, 4 Dec 2023 09:09:54 -0800 Subject: [PATCH] (feat) proxy-read litellm custom callback class --- litellm/proxy/custom_logger.py | 15 ++++++++++++++- litellm/proxy/proxy_server.py | 19 +++++++++++++++++++ pyproject.toml | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/custom_logger.py b/litellm/proxy/custom_logger.py index 8a1a824ac..d30722bd9 100644 --- a/litellm/proxy/custom_logger.py +++ b/litellm/proxy/custom_logger.py @@ -1,4 +1,5 @@ from litellm.integrations.custom_logger import CustomLogger +import litellm class MyCustomHandler(CustomLogger): def log_pre_api_call(self, model, messages, kwargs): print(f"Pre-API Call") @@ -6,6 +7,16 @@ class MyCustomHandler(CustomLogger): def log_post_api_call(self, kwargs, response_obj, start_time, end_time): # log: key, user, model, prompt, response, tokens, cost print(f"Post-API Call") + print("\n kwargs\n") + print(kwargs) + model = kwargs["model"] + messages = kwargs["messages"] + cost = litellm.completion_cost(completion_response=response_obj) + + # tokens used in response + usage = response_obj.usage + print(usage) + def log_stream_event(self, kwargs, response_obj, start_time, end_time): print(f"On Stream") @@ -16,4 +27,6 @@ class MyCustomHandler(CustomLogger): def log_failure_event(self, kwargs, response_obj, start_time, end_time): print(f"On Failure") -customHandler = MyCustomHandler() +proxy_handler_instance = MyCustomHandler() + +# need to set litellm.callbacks = [customHandler] # on the proxy diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a2605dd55..a0e9250ac 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6,6 +6,7 @@ from typing import Optional, List import secrets, subprocess import hashlib, uuid import warnings +import importlib messages: list = [] sys.path.insert( 0, os.path.abspath("../..") @@ -556,6 +557,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): port=cache_port, password=cache_password ) + elif key == "callbacks": + print(f"{blue_color_code}\nSetting custom callbacks on Proxy") + print() + passed_module, instance_name = value.split(".") + + # Dynamically import the module + module = importlib.import_module(passed_module) + # Get the instance from the module + instance = getattr(module, instance_name) + + methods = [method for method in dir(instance) if callable(getattr(instance, method))] + # Print the methods + print("Methods in the instance:") + for method in methods: + print(method) + + litellm.callbacks = [instance] + else: setattr(litellm, key, value) diff --git a/pyproject.toml b/pyproject.toml index 37f87fa45..2befc1383 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ proxy = [ "backoff", "rq", "orjson", + "importlib", ] extra_proxy = [