# --- Reconstructed from a whitespace-collapsed `git format-patch` email. ---
# This span of the patch also carries two small non-Python pieces that cannot
# be expressed in a Python body:
#   * litellm/proxy/proxy_server.py (one-line change): the `callbacks` config
#     key now calls get_instance_fn(value=value, config_file_path=config_file_path)
#     so the callback module is resolved relative to the config file's location.
#   * litellm/tests/test_configs/test_custom_logger.yaml (new file): defines
#     model_list entry "litellm-test-model" -> "gpt-3.5-turbo" and
#     litellm_settings with drop_params: True, set_verbose: True, and
#     callbacks: custom_callbacks.my_custom_logger
# Below is the new file litellm/tests/test_configs/custom_callbacks.py.
from litellm.integrations.custom_logger import CustomLogger
import inspect
import litellm


class MyCustomHandler(CustomLogger):
    """Callback handler used by the proxy custom-logger tests.

    Records which LiteLLM hooks fired (boolean flags) and captures the
    kwargs / response objects passed to the async hooks, so the test can
    assert on them after sending a request through the proxy.
    """

    def __init__(self):
        # Flags flipped by the sync/async hooks below.
        self.success: bool = False
        self.failure: bool = False
        self.async_success: bool = False
        self.async_success_embedding: bool = False
        self.async_failure: bool = False
        self.async_failure_embedding: bool = False

        # Captured call arguments / responses for test assertions.
        self.async_completion_kwargs = None
        self.async_embedding_kwargs = None
        self.async_embedding_response = None
        self.async_completion_kwargs_fail = None
        self.async_embedding_kwargs_fail = None

        blue_color_code = "\033[94m"
        reset_color_code = "\033[0m"
        print(f"{blue_color_code}Initialized LiteLLM custom logger")
        try:
            print(f"Logger Initialized with following methods:")
            methods = [
                method
                for method in dir(self)
                if inspect.ismethod(getattr(self, method))
            ]
            # Pretty print the methods
            for method in methods:
                print(f" - {method}")
            print(f"{reset_color_code}")
        except Exception:
            # FIX: was a bare `except:`. Debug output is best effort only;
            # never fail handler construction because of it.
            pass

    def log_pre_api_call(self, model, messages, kwargs):
        print(f"Pre-API Call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print(f"Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Success")
        self.success = True

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")
        self.failure = True

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async success")
        self.async_success = True
        print("Value of async success: ", self.async_success)
        print("\n kwargs: ", kwargs)
        if kwargs.get("model") == "text-embedding-ada-002":
            self.async_success_embedding = True
            self.async_embedding_kwargs = kwargs
            self.async_embedding_response = response_obj
        # NOTE(review): the collapsed patch makes the indentation of this
        # assignment ambiguous; kept unconditional so completion kwargs are
        # always captured -- confirm against the upstream file.
        self.async_completion_kwargs = kwargs

        model = kwargs.get("model", None)
        messages = kwargs.get("messages", None)
        user = kwargs.get("user", None)

        # Access litellm_params passed to litellm.completion(), example access
        # `metadata` -- headers passed to the LiteLLM proxy can be found here.
        litellm_params = kwargs.get("litellm_params", {})
        metadata = litellm_params.get("metadata", {})

        # Calculate cost using litellm.completion_cost()
        cost = litellm.completion_cost(completion_response=response_obj)
        response = response_obj
        # tokens used in response
        usage = response_obj["usage"]

        print(
            f"""
            Model: {model},
            Messages: {messages},
            User: {user},
            Usage: {usage},
            Cost: {cost},
            Response: {response}
            Proxy Metadata: {metadata}
            """
        )
        return

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Async Failure")
        self.async_failure = True
        print("Value of async failure: ", self.async_failure)
        print("\n kwargs: ", kwargs)
        if kwargs.get("model") == "text-embedding-ada-002":
            self.async_failure_embedding = True
            self.async_embedding_kwargs_fail = kwargs
        self.async_completion_kwargs_fail = kwargs


# Module-level instance referenced from test_custom_logger.yaml
# (callbacks: custom_callbacks.my_custom_logger).
my_custom_logger = MyCustomHandler()
# --- Reconstructed from a whitespace-collapsed `git format-patch` email. ---
# New file litellm/tests/test_proxy_custom_logger.py: spins up the proxy as a
# FastAPI TestClient using test_custom_logger.yaml and sends a
# /chat/completions request so the registered custom callback fires.
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os, io

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
# FIX: `importlib.util` is a submodule and is not guaranteed to be reachable
# from a bare `import importlib`; import it explicitly.
import importlib.util
import inspect

# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined

filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
python_file_path = f"{filepath}/test_configs/custom_callbacks.py"

# Persist worker config so the proxy can pick it up before the app starts.
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)

app = FastAPI()
app.include_router(router)  # Include your router in the test app


@app.on_event("startup")
async def wrapper_startup_event():
    # Re-initialize the proxy (debug on) when TestClient starts the app.
    initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)


# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True)
def client():
    # Context manager form so FastAPI startup/shutdown events actually run.
    with TestClient(app) as client:
        yield client


def test_chat_completion(client):
    try:
        # Your test data
        print("initialized proxy")
        # FIX: the original bound the *ModuleSpec* returned by
        # spec_from_file_location to `my_custom_logger` and never loaded the
        # module, so all inspection below ran against a ModuleSpec object.
        # Load the module and grab the handler instance it defines.
        spec = importlib.util.spec_from_file_location("custom_callbacks", python_file_path)
        custom_callbacks = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(custom_callbacks)
        # NOTE(review): this is a freshly executed copy of the module; whether
        # it is the same instance the proxy registered depends on
        # get_instance_fn's import behavior -- confirm against proxy_server.
        my_custom_logger = custom_callbacks.my_custom_logger
        print("my_custom_logger", my_custom_logger)

        blue_color_code = "\033[94m"
        reset_color_code = "\033[0m"
        print(f"{blue_color_code}Initialized LiteLLM custom logger")
        try:
            print(f"Logger Initialized with following methods:")
            methods = [
                method
                for method in dir(my_custom_logger)
                if inspect.ismethod(getattr(my_custom_logger, method))
            ]
            # Pretty print the methods
            for method in methods:
                print(f" - {method}")
            print(f"{reset_color_code}")
        except Exception:
            # FIX: was a bare `except:`; debug output only, never fail here.
            pass

        for attribute in dir(my_custom_logger):
            print(f"{attribute}: {getattr(my_custom_logger, attribute)}")

        test_data = {
            "model": "litellm-test-model",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)
        print("made request", response.status_code, response.text)
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        # FIX: pytest.fail(reason, pytrace) -- the second positional argument
        # is a bool; passing the exception object was an API misuse that also
        # dropped the exception text from the failure message.
        pytest.fail(f"LiteLLM Proxy test failed. Exception: {e}")