feat(lago.py): adding support for usage-based billing with lago

Closes https://github.com/BerriAI/litellm/issues/3639
This commit is contained in:
Krrish Dholakia 2024-05-16 10:54:18 -07:00
parent 00b9f1290e
commit e273e66618
6 changed files with 198 additions and 12 deletions

View file

@ -27,8 +27,8 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = [] success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = [] _custom_logger_compatible_callbacks_literal = Literal["lago"]
_custom_logger_compatible_callbacks: list = ["openmeter"] callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[ _langfuse_default_tags: Optional[
List[ List[
Literal[ Literal[

View file

@ -0,0 +1,153 @@
# What is this?
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
import dotenv, os, json
import litellm
import traceback, httpx
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
import uuid
from typing import Optional
def get_utc_datetime():
    """Return the current time in UTC.

    On runtimes that expose ``datetime.UTC`` (Python 3.11+) the result is
    timezone-aware; older runtimes fall back to the naive ``utcnow()``.
    """
    import datetime as dt
    from datetime import datetime

    utc_zone = getattr(dt, "UTC", None)
    if utc_zone is not None:
        return datetime.now(utc_zone)  # type: ignore
    return datetime.utcnow()  # type: ignore
class LagoLogger(CustomLogger):
    """Log successful LLM calls as usage events to Lago for usage-based billing.

    On every success event, POSTs a Lago event (``/api/v1/events``) carrying the
    model, the litellm-computed response cost, and token usage, attributed to the
    end-user id from the proxy request body.

    Requires ``LAGO_API_BASE``, ``LAGO_API_KEY`` and ``LAGO_API_EVENT_CODE`` in
    the environment. See https://github.com/BerriAI/litellm/issues/3639.
    """

    def __init__(self) -> None:
        super().__init__()
        self.validate_environment()
        self.async_http_handler = AsyncHTTPHandler()
        self.sync_http_handler = HTTPHandler()

    def validate_environment(self):
        """
        Expects
        LAGO_API_BASE,
        LAGO_API_KEY,
        LAGO_API_EVENT_CODE,

        in the environment

        Raises:
            Exception: listing every missing variable, so misconfiguration is
                caught at construction time rather than on the first request.
        """
        missing_keys = []
        if os.getenv("LAGO_API_KEY", None) is None:
            missing_keys.append("LAGO_API_KEY")

        if os.getenv("LAGO_API_BASE", None) is None:
            missing_keys.append("LAGO_API_BASE")

        if os.getenv("LAGO_API_EVENT_CODE", None) is None:
            missing_keys.append("LAGO_API_EVENT_CODE")

        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def _resolve_event_url(self) -> str:
        """Return the full Lago events endpoint derived from LAGO_API_BASE.

        Shared by the sync and async success handlers so the URL logic cannot
        drift between them.
        """
        _url = os.getenv("LAGO_API_BASE")
        assert _url is not None and isinstance(
            _url, str
        ), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
        if _url.endswith("/"):
            _url += "api/v1/events"
        else:
            _url += "/api/v1/events"
        return _url

    def _common_logic(self, kwargs: dict, response_obj) -> dict:
        """Build the Lago event payload shared by the sync and async loggers.

        Args:
            kwargs: litellm logging kwargs (``model``, ``response_cost``,
                ``litellm_params``, ...).
            response_obj: the completed response; token usage is copied off it
                when it is a ModelResponse/EmbeddingResponse with ``usage``.

        Returns:
            dict shaped as the body of Lago's ``POST /api/v1/events``.

        Raises:
            Exception: when the proxy request body carried no ``user`` —
                Lago needs an external_customer_id to attribute the charge.
        """
        cost = kwargs.get("response_cost", None)
        model = kwargs.get("model")
        usage = {}

        if (
            isinstance(response_obj, litellm.ModelResponse)
            or isinstance(response_obj, litellm.EmbeddingResponse)
        ) and hasattr(response_obj, "usage"):
            usage = {
                "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
                "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
                "total_tokens": response_obj["usage"].get("total_tokens"),
            }

        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request") or {}
        end_user_id = proxy_server_request.get("body", {}).get("user", None)
        # NOTE(review): read but not yet used — presumably reserved for charging
        # by key/team/org instead of end-user; confirm before removing.
        user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
        team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
        org_id = litellm_params["metadata"].get("user_api_key_org_id", None)

        if end_user_id is None:
            raise Exception("LAGO: user is required")

        return {
            "event": {
                "transaction_id": str(uuid.uuid4()),
                "external_customer_id": end_user_id,
                "code": os.getenv("LAGO_API_EVENT_CODE"),
                # Bug fix: was hard-coded to 10000, which billed every request
                # the same amount — report the actual litellm-computed cost.
                "properties": {"model": model, "response_cost": cost, **usage},
            }
        }

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Synchronous success hook: POST the usage event to Lago.

        Re-raises any build/transport/HTTP error after best-effort logging of
        the response body.
        """
        _url = self._resolve_event_url()

        api_key = os.getenv("LAGO_API_KEY")

        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
        _headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }

        # Bug fix: pre-bind response so the except-branch cannot raise a
        # NameError (masking the real error) when the POST itself fails
        # before assignment — mirrors the async handler below.
        response: Optional[httpx.Response] = None
        try:
            response = self.sync_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )

            response.raise_for_status()
        except Exception as e:
            if response is not None and hasattr(response, "text"):
                litellm.print_verbose(f"\nError Message: {response.text}")
            raise e

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async success hook: POST the usage event to Lago.

        Payload/URL construction errors and HTTP errors are both re-raised;
        the HTTP error path logs the response body first when available.
        """
        try:
            _url = self._resolve_event_url()

            api_key = os.getenv("LAGO_API_KEY")

            _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
            _headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer {}".format(api_key),
            }
        except Exception as e:
            raise e

        response: Optional[httpx.Response] = None
        try:
            response = await self.async_http_handler.post(
                url=_url,
                data=json.dumps(_data),
                headers=_headers,
            )

            response.raise_for_status()
        except Exception as e:
            if response is not None and hasattr(response, "text"):
                litellm.print_verbose(f"\nError Message: {response.text}")
            raise e

View file

@ -30,6 +30,7 @@ router_settings:
litellm_settings: litellm_settings:
fallbacks: [{"gpt-3.5-turbo-012": ["azure-gpt-3.5-turbo"]}] fallbacks: [{"gpt-3.5-turbo-012": ["azure-gpt-3.5-turbo"]}]
callbacks: ["lago"]
# service_callback: ["prometheus_system"] # service_callback: ["prometheus_system"]
# success_callback: ["prometheus"] # success_callback: ["prometheus"]
# failure_callback: ["prometheus"] # failure_callback: ["prometheus"]

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast import threading, ast
import shutil, random, traceback, requests import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional, List, Callable from typing import Optional, List, Callable, get_args
import secrets, subprocess import secrets, subprocess
import hashlib, uuid import hashlib, uuid
import warnings import warnings
@ -2207,8 +2207,18 @@ class ProxyConfig:
elif key == "callbacks": elif key == "callbacks":
if isinstance(value, list): if isinstance(value, list):
imported_list: List[Any] = [] imported_list: List[Any] = []
known_compatible_callbacks = list(
get_args(
litellm._custom_logger_compatible_callbacks_literal
)
)
for callback in value: # ["presidio", <my-custom-callback>] for callback in value: # ["presidio", <my-custom-callback>]
if isinstance(callback, str) and callback == "presidio": if (
isinstance(callback, str)
and callback in known_compatible_callbacks
):
imported_list.append(callback)
elif isinstance(callback, str) and callback == "presidio":
from litellm.proxy.hooks.presidio_pii_masking import ( from litellm.proxy.hooks.presidio_pii_masking import (
_OPTIONAL_PresidioPIIMasking, _OPTIONAL_PresidioPIIMasking,
) )

View file

@ -140,6 +140,8 @@ class ProxyLogging:
self.slack_alerting_instance.response_taking_too_long_callback self.slack_alerting_instance.response_taking_too_long_callback
) )
for callback in litellm.callbacks: for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.utils._init_custom_logger_compatible_class(callback)
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
litellm.input_callback.append(callback) litellm.input_callback.append(callback)
if callback not in litellm.success_callback: if callback not in litellm.success_callback:

View file

@ -76,6 +76,7 @@ from .integrations.weights_biases import WeightsBiasesLogger
from .integrations.custom_logger import CustomLogger from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger from .integrations.langfuse import LangFuseLogger
from .integrations.openmeter import OpenMeterLogger from .integrations.openmeter import OpenMeterLogger
from .integrations.lago import LagoLogger
from .integrations.datadog import DataDogLogger from .integrations.datadog import DataDogLogger
from .integrations.prometheus import PrometheusLogger from .integrations.prometheus import PrometheusLogger
from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.prometheus_services import PrometheusServicesLogger
@ -123,6 +124,7 @@ from typing import (
BinaryIO, BinaryIO,
Iterable, Iterable,
Tuple, Tuple,
Callable,
) )
from .caching import Cache from .caching import Cache
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -147,6 +149,7 @@ weightsBiasesLogger = None
customLogger = None customLogger = None
langFuseLogger = None langFuseLogger = None
openMeterLogger = None openMeterLogger = None
lagoLogger = None
dataDogLogger = None dataDogLogger = None
prometheusLogger = None prometheusLogger = None
dynamoLogger = None dynamoLogger = None
@ -2111,7 +2114,7 @@ class Logging:
""" """
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
""" """
print_verbose(f"Logging Details LiteLLM-Async Success Call: {cache_hit}") print_verbose(f"Logging Details LiteLLM-Async Success Call")
start_time, end_time, result = self._success_handler_helper_fn( start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
) )
@ -2333,8 +2336,8 @@ class Logging:
end_time=end_time, end_time=end_time,
print_verbose=print_verbose, print_verbose=print_verbose,
) )
except: except Exception as e:
print_verbose( verbose_logger.error(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
) )
pass pass
@ -2692,6 +2695,15 @@ class Rules:
return True return True
def _init_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Callable:
if logging_integration == "lago":
return LagoLogger() # type: ignore
elif logging_integration == "openmeter":
return OpenMeterLogger() # type: ignore
####### CLIENT ################### ####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def function_setup( def function_setup(
@ -2702,16 +2714,24 @@ def function_setup(
function_id = kwargs["id"] if "id" in kwargs else None function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0: if len(litellm.callbacks) > 0:
for callback in litellm.callbacks: for callback in litellm.callbacks:
# check if callback is a string - e.g. "lago", "openmeter"
if isinstance(callback, str):
callback = _init_custom_logger_compatible_class(callback)
if any(
isinstance(cb, type(callback))
for cb in litellm._async_success_callback
): # don't double add a callback
continue
if callback not in litellm.input_callback: if callback not in litellm.input_callback:
litellm.input_callback.append(callback) litellm.input_callback.append(callback) # type: ignore
if callback not in litellm.success_callback: if callback not in litellm.success_callback:
litellm.success_callback.append(callback) litellm.success_callback.append(callback) # type: ignore
if callback not in litellm.failure_callback: if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback) litellm.failure_callback.append(callback) # type: ignore
if callback not in litellm._async_success_callback: if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback) litellm._async_success_callback.append(callback) # type: ignore
if callback not in litellm._async_failure_callback: if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback) litellm._async_failure_callback.append(callback) # type: ignore
print_verbose( print_verbose(
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
) )