feat(lago.py): adding support for usage-based billing with lago

Closes https://github.com/BerriAI/litellm/issues/3639
This commit is contained in:
Krrish Dholakia 2024-05-16 10:54:18 -07:00
parent 00b9f1290e
commit e273e66618
6 changed files with 198 additions and 12 deletions

View file

@ -27,8 +27,8 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
callbacks: List[Callable] = []
_custom_logger_compatible_callbacks: list = ["openmeter"]
_custom_logger_compatible_callbacks_literal = Literal["lago"]
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[
List[
Literal[

View file

@ -0,0 +1,153 @@
# What is this?
## On Success events log cost to Lago - https://github.com/BerriAI/litellm/issues/3639
import dotenv, os, json
import litellm
import traceback, httpx
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
import uuid
from typing import Optional
def get_utc_datetime():
import datetime as dt
from datetime import datetime
if hasattr(dt, "UTC"):
return datetime.now(dt.UTC) # type: ignore
else:
return datetime.utcnow() # type: ignore
class LagoLogger(CustomLogger):
def __init__(self) -> None:
super().__init__()
self.validate_environment()
self.async_http_handler = AsyncHTTPHandler()
self.sync_http_handler = HTTPHandler()
def validate_environment(self):
"""
Expects
LAGO_API_BASE,
LAGO_API_KEY,
in the environment
"""
missing_keys = []
if os.getenv("LAGO_API_KEY", None) is None:
missing_keys.append("LAGO_API_KEY")
if os.getenv("LAGO_API_BASE", None) is None:
missing_keys.append("LAGO_API_BASE")
if os.getenv("LAGO_API_EVENT_CODE", None) is None:
missing_keys.append("LAGO_API_EVENT_CODE")
if len(missing_keys) > 0:
raise Exception("Missing keys={} in environment.".format(missing_keys))
def _common_logic(self, kwargs: dict, response_obj) -> dict:
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
dt = get_utc_datetime().isoformat()
cost = kwargs.get("response_cost", None)
model = kwargs.get("model")
usage = {}
if (
isinstance(response_obj, litellm.ModelResponse)
or isinstance(response_obj, litellm.EmbeddingResponse)
) and hasattr(response_obj, "usage"):
usage = {
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
"total_tokens": response_obj["usage"].get("total_tokens"),
}
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
if end_user_id is None:
raise Exception("LAGO: user is required")
return {
"event": {
"transaction_id": str(uuid.uuid4()),
"external_customer_id": end_user_id,
"code": os.getenv("LAGO_API_EVENT_CODE"),
"properties": {"model": model, "response_cost": 10000, **usage},
}
}
def log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(
_url, str
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(_url)
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("LAGO_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
try:
response = self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
_url = os.getenv("LAGO_API_BASE")
assert _url is not None and isinstance(
_url, str
), "LAGO_API_BASE missing or not set correctly. LAGO_API_BASE={}".format(
_url
)
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = os.getenv("LAGO_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
except Exception as e:
raise e
response: Optional[httpx.Response] = None
try:
response = await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
if response is not None and hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e

View file

@ -30,6 +30,7 @@ router_settings:
litellm_settings:
fallbacks: [{"gpt-3.5-turbo-012": ["azure-gpt-3.5-turbo"]}]
callbacks: ["lago"]
# service_callback: ["prometheus_system"]
# success_callback: ["prometheus"]
# failure_callback: ["prometheus"]

View file

@ -2,7 +2,7 @@ import sys, os, platform, time, copy, re, asyncio, inspect
import threading, ast
import shutil, random, traceback, requests
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Callable
from typing import Optional, List, Callable, get_args
import secrets, subprocess
import hashlib, uuid
import warnings
@ -2207,8 +2207,18 @@ class ProxyConfig:
elif key == "callbacks":
if isinstance(value, list):
imported_list: List[Any] = []
known_compatible_callbacks = list(
get_args(
litellm._custom_logger_compatible_callbacks_literal
)
)
for callback in value: # ["presidio", <my-custom-callback>]
if isinstance(callback, str) and callback == "presidio":
if (
isinstance(callback, str)
and callback in known_compatible_callbacks
):
imported_list.append(callback)
elif isinstance(callback, str) and callback == "presidio":
from litellm.proxy.hooks.presidio_pii_masking import (
_OPTIONAL_PresidioPIIMasking,
)

View file

@ -140,6 +140,8 @@ class ProxyLogging:
self.slack_alerting_instance.response_taking_too_long_callback
)
for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.utils._init_custom_logger_compatible_class(callback)
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
if callback not in litellm.success_callback:

View file

@ -76,6 +76,7 @@ from .integrations.weights_biases import WeightsBiasesLogger
from .integrations.custom_logger import CustomLogger
from .integrations.langfuse import LangFuseLogger
from .integrations.openmeter import OpenMeterLogger
from .integrations.lago import LagoLogger
from .integrations.datadog import DataDogLogger
from .integrations.prometheus import PrometheusLogger
from .integrations.prometheus_services import PrometheusServicesLogger
@ -123,6 +124,7 @@ from typing import (
BinaryIO,
Iterable,
Tuple,
Callable,
)
from .caching import Cache
from concurrent.futures import ThreadPoolExecutor
@ -147,6 +149,7 @@ weightsBiasesLogger = None
customLogger = None
langFuseLogger = None
openMeterLogger = None
lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
@ -2111,7 +2114,7 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Logging Details LiteLLM-Async Success Call: {cache_hit}")
print_verbose(f"Logging Details LiteLLM-Async Success Call")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
@ -2333,8 +2336,8 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
except:
print_verbose(
except Exception as e:
verbose_logger.error(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
pass
@ -2692,6 +2695,15 @@ class Rules:
return True
def _init_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Callable:
if logging_integration == "lago":
return LagoLogger() # type: ignore
elif logging_integration == "openmeter":
return OpenMeterLogger() # type: ignore
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def function_setup(
@ -2702,16 +2714,24 @@ def function_setup(
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
# check if callback is a string - e.g. "lago", "openmeter"
if isinstance(callback, str):
callback = _init_custom_logger_compatible_class(callback)
if any(
isinstance(cb, type(callback))
for cb in litellm._async_success_callback
): # don't double add a callback
continue
if callback not in litellm.input_callback:
litellm.input_callback.append(callback)
litellm.input_callback.append(callback) # type: ignore
if callback not in litellm.success_callback:
litellm.success_callback.append(callback)
litellm.success_callback.append(callback) # type: ignore
if callback not in litellm.failure_callback:
litellm.failure_callback.append(callback)
litellm.failure_callback.append(callback) # type: ignore
if callback not in litellm._async_success_callback:
litellm._async_success_callback.append(callback)
litellm._async_success_callback.append(callback) # type: ignore
if callback not in litellm._async_failure_callback:
litellm._async_failure_callback.append(callback)
litellm._async_failure_callback.append(callback) # type: ignore
print_verbose(
f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
)