From b3f54020178aa8ef288eb9814aec4ae81d7d71fb Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Thu, 15 Feb 2024 13:44:07 -0800 Subject: [PATCH] (feat) custom API callbacks --- enterprise/callbacks/api_callback.py | 15 +++++++++--- enterprise/callbacks/example_logging_api.py | 27 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 enterprise/callbacks/example_logging_api.py diff --git a/enterprise/callbacks/api_callback.py b/enterprise/callbacks/api_callback.py index 4a1c27d41..8c32ee86c 100644 --- a/enterprise/callbacks/api_callback.py +++ b/enterprise/callbacks/api_callback.py @@ -29,9 +29,11 @@ from litellm._logging import print_verbose, verbose_logger class GenericAPILogger: # Class variables or attributes - def __init__(self, endpoint=None): + def __init__(self, endpoint=None, headers=None): try: verbose_logger.debug(f"in init GenericAPILogger, endpoint {endpoint}") + self.endpoint = endpoint + self.headers = headers pass @@ -44,7 +46,7 @@ class GenericAPILogger: def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): try: verbose_logger.debug( - f"s3 Logging - Enters logging function for model {kwargs}" + f"GenericAPILogger Logging - Enters logging function for model {kwargs}" ) # construct payload to send custom logger @@ -54,6 +56,7 @@ class GenericAPILogger: litellm_params.get("metadata", {}) or {} ) # if litellm_params['metadata'] == None messages = kwargs.get("messages") + cost = kwargs.get("response_cost", 0.0) optional_params = kwargs.get("optional_params", {}) call_type = kwargs.get("call_type", "litellm.completion") cache_hit = kwargs.get("cache_hit", False) @@ -74,6 +77,7 @@ class GenericAPILogger: "response": response_obj, "usage": usage, "metadata": metadata, + "cost": cost, } # Ensure everything in the payload is converted to str @@ -87,11 +91,14 @@ class GenericAPILogger: import json payload = json.dumps(payload) + data = { + "data": payload, + } - print_verbose(f"\nGeneric Logger - Logging payload = {payload}") + print_verbose(f"\nGeneric Logger - Logging payload = {data}") # make request to endpoint with payload - response = requests.post(self.endpoint, data=payload, headers=self.headers) + response = requests.post(self.endpoint, data=data, headers=self.headers) response_status = response.status_code response_text = response.text diff --git a/enterprise/callbacks/example_logging_api.py b/enterprise/callbacks/example_logging_api.py new file mode 100644 index 000000000..f3c16299a --- /dev/null +++ b/enterprise/callbacks/example_logging_api.py @@ -0,0 +1,27 @@ +# this is an example endpoint to receive data from litellm +from fastapi import FastAPI, HTTPException, Request + +app = FastAPI() + + +@app.post("/log-event") +async def log_event(request: Request): + try: + # Assuming the incoming request has JSON data + data = await request.json() + print("Received request data:") + print(data) + + # Your additional logic can go here + # For now, just printing the received data + + return {"message": "Request received successfully"} + except Exception as e: + print(f"Error processing request: {str(e)}") + raise HTTPException(status_code=500, detail="Internal Server Error") + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="127.0.0.1", port=8000)