feat(openmeter.py): add support for user billing

OpenMeter supports user-based billing. Closes https://github.com/BerriAI/litellm/issues/1268
Krrish Dholakia 2024-05-01 17:23:48 -07:00
parent 37eb7910d2
commit 2a9651b3ca
5 changed files with 210 additions and 13 deletions
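
For context, with this change per-user billing can be enabled from the SDK roughly as sketched below. The sketch is not part of the commit: the endpoint and key values are placeholders, and the model name is arbitrary. The `user` param is what OpenMeter meters against; it becomes the CloudEvent `subject` (see the new openmeter.py below).

import os
import litellm

# Placeholders - validate_environment() in the new logger requires both of these.
os.environ["OPENMETER_API_ENDPOINT"] = "https://openmeter.example.com"
os.environ["OPENMETER_API_KEY"] = "om_xxx"

litellm.success_callback = ["openmeter"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    user="customer-123",  # per-user billing: forwarded as the event subject
)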


@@ -22,6 +22,7 @@ success_callback: List[Union[str, Callable]] = []
 failure_callback: List[Union[str, Callable]] = []
 service_callback: List[Union[str, Callable]] = []
 callbacks: List[Callable] = []
+_custom_logger_compatible_callbacks: list = ["openmeter"]
 _langfuse_default_tags: Optional[
     List[
         Literal[

litellm/integrations/openmeter.py

@@ -0,0 +1,122 @@
# What is this?
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
import dotenv, os, json
import requests
import litellm

dotenv.load_dotenv()  # Loading env variables using dotenv
import traceback
import uuid

from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler


def get_utc_datetime():
    import datetime as dt
    from datetime import datetime
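
    # Python 3.11+ exposes datetime.UTC; older versions only have the naive utcnow()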
    if hasattr(dt, "UTC"):
        return datetime.now(dt.UTC)  # type: ignore
    else:
        return datetime.utcnow()  # type: ignore


class OpenMeterLogger(CustomLogger):
    def __init__(self) -> None:
        super().__init__()
        self.validate_environment()
        self.async_http_handler = AsyncHTTPHandler()
        self.sync_http_handler = HTTPHandler()

    def validate_environment(self):
        """
        Expects
        OPENMETER_API_ENDPOINT,
        OPENMETER_API_KEY,

        in the environment
        """
        missing_keys = []
        if litellm.get_secret("OPENMETER_API_ENDPOINT", None) is None:
            missing_keys.append("OPENMETER_API_ENDPOINT")
        if litellm.get_secret("OPENMETER_API_KEY", None) is None:
            missing_keys.append("OPENMETER_API_KEY")

        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def _common_logic(self, kwargs: dict, response_obj):
        call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
        dt = get_utc_datetime().isoformat()
        cost = kwargs.get("response_cost", None)
        model = kwargs.get("model")
        usage = {}
        if (
            isinstance(response_obj, litellm.ModelResponse)
            or isinstance(response_obj, litellm.EmbeddingResponse)
        ) and hasattr(response_obj, "usage"):
            usage = {
                "prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
                "completion_tokens": response_obj["usage"].get("completion_tokens", 0),
                "total_tokens": response_obj["usage"].get("total_tokens"),
            }

        return {
            "specversion": "1.0",
            "type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
            "id": call_id,
            "time": dt,
            "subject": kwargs.get("user", ""),  # end-user passed in via 'user' param
            "source": "litellm-proxy",
            "data": {"model": model, "cost": cost, **usage},
        }
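
    # For reference, an event built by _common_logic looks roughly like this
    # (illustrative values only; they are not taken from the commit):
    #
    # {
    #     "specversion": "1.0",
    #     "type": "litellm_tokens",             # or the OPENMETER_EVENT_TYPE env var
    #     "id": "chatcmpl-abc123",              # response id, else litellm_call_id
    #     "time": "2024-05-01T17:23:48+00:00",
    #     "subject": "customer-123",            # the `user` param -> billed subject
    #     "source": "litellm-proxy",
    #     "data": {"model": "gpt-3.5-turbo", "cost": 0.00042,
    #              "prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    # }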

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        _url = litellm.get_secret("OPENMETER_API_ENDPOINT")
        if _url.endswith("/"):
            _url += "api/v1/events"
        else:
            _url += "/api/v1/events"

        api_key = litellm.get_secret("OPENMETER_API_KEY")

        _data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
        self.sync_http_handler.post(
            url=_url,
            data=json.dumps(_data),  # serialize: the body is application/cloudevents+json
            headers={
                "Content-Type": "application/cloudevents+json",
                "Authorization": "Bearer {}".format(api_key),
            },
        )
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
_url = litellm.get_secret("OPENMETER_API_ENDPOINT")
if _url.endswith("/"):
_url += "api/v1/events"
else:
_url += "/api/v1/events"
api_key = litellm.get_secret("OPENMETER_API_KEY")
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
_headers = {
"Content-Type": "application/cloudevents+json",
"Authorization": "Bearer {}".format(api_key),
}
try:
response = await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except Exception as e:
print(f"\nAn Exception Occurred - {str(e)}")
if hasattr(response, "text"):
print(f"\nError Message: {response.text}")
raise e
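
A minimal sketch of exercising the new logger directly (not part of the commit; in normal operation `kwargs` is assembled by litellm's Logging class, so the hand-built dict below is a stand-in):

import asyncio
import litellm
from litellm.integrations.openmeter import OpenMeterLogger

async def main():
    logger = OpenMeterLogger()  # raises if the two OPENMETER_* env vars are unset
    resp = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        user="customer-123",
    )
    await logger.async_log_success_event(
        kwargs={
            "model": "gpt-3.5-turbo",
            "user": "customer-123",
            "litellm_call_id": "hand-built-id",
            "response_cost": litellm.completion_cost(completion_response=resp),
        },
        response_obj=resp,
        start_time=None,
        end_time=None,
    )

asyncio.run(main())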


@@ -1,15 +1,8 @@
 model_list:
 - litellm_params:
-    api_base: http://0.0.0.0:8080
+    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
     api_key: my-fake-key
     model: openai/my-fake-model
-    rpm: 100
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: http://0.0.0.0:8081
-    api_key: my-fake-key
-    model: openai/my-fake-model-2
-    rpm: 100
   model_name: fake-openai-endpoint
 router_settings:
   num_retries: 0
@@ -17,3 +10,6 @@ router_settings:
   redis_host: os.environ/REDIS_HOST
   redis_password: os.environ/REDIS_PASSWORD
   redis_port: os.environ/REDIS_PORT
+litellm_settings:
+  success_callback: ["openmeter"]
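
With success_callback: ["openmeter"] set, requests through the proxy are metered per end-user. A sketch of a client call (assumptions: the proxy is running locally on its default port 4000, and sk-1234 stands in for a real proxy key):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.chat.completions.create(
    model="fake-openai-endpoint",  # the model_name from the config above
    messages=[{"role": "user", "content": "hi"}],
    user="customer-123",  # surfaces in OpenMeter as the event subject
)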


@@ -1777,7 +1777,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
         usage = response_obj["usage"]
         if type(usage) == litellm.Usage:
             usage = dict(usage)
-        id = response_obj.get("id", str(uuid.uuid4()))
+        id = response_obj.get("id", kwargs.get("litellm_call_id"))
         api_key = metadata.get("user_api_key", "")
         if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
             # hash the api_key


@@ -70,6 +70,7 @@ from .integrations.langsmith import LangsmithLogger
 from .integrations.weights_biases import WeightsBiasesLogger
 from .integrations.custom_logger import CustomLogger
 from .integrations.langfuse import LangFuseLogger
+from .integrations.openmeter import OpenMeterLogger
 from .integrations.datadog import DataDogLogger
 from .integrations.prometheus import PrometheusLogger
 from .integrations.prometheus_services import PrometheusServicesLogger
@@ -130,6 +131,7 @@ langsmithLogger = None
 weightsBiasesLogger = None
 customLogger = None
 langFuseLogger = None
+openMeterLogger = None
 dataDogLogger = None
 prometheusLogger = None
 dynamoLogger = None
@@ -1922,6 +1924,51 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if (
+                    callback == "openmeter"
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "acompletion", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "aembedding", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "aimage_generation", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "atranscription", False
+                    )
+                    == False
+                ):
+                    global openMeterLogger
+                    if openMeterLogger is None:
+                        print_verbose("Instantiating openmeter client")
+                        openMeterLogger = OpenMeterLogger()
+                    if self.stream and complete_streaming_response is None:
+                        openMeterLogger.log_stream_event(
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
+                    else:
+                        if self.stream and complete_streaming_response:
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
+                                    "complete_streaming_response", {}
+                                )
+                            )
+                            result = self.model_call_details["complete_response"]
+                        openMeterLogger.log_success_event(
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
                 if (
                     isinstance(callback, CustomLogger)
                     and self.model_call_details.get("litellm_params", {}).get(
@@ -2121,6 +2168,35 @@ class Logging:
                         await litellm.cache.async_add_cache(result, **kwargs)
                     else:
                         litellm.cache.add_cache(result, **kwargs)
+                if callback == "openmeter":
+                    global openMeterLogger
+                    if self.stream == True:
+                        if (
+                            "async_complete_streaming_response"
+                            in self.model_call_details
+                        ):
+                            await openMeterLogger.async_log_success_event(
+                                kwargs=self.model_call_details,
+                                response_obj=self.model_call_details[
+                                    "async_complete_streaming_response"
+                                ],
+                                start_time=start_time,
+                                end_time=end_time,
+                            )
+                        else:
+                            await openMeterLogger.async_log_stream_event(  # [TODO]: move this to being an async log stream event function
+                                kwargs=self.model_call_details,
+                                response_obj=result,
+                                start_time=start_time,
+                                end_time=end_time,
+                            )
+                    else:
+                        await openMeterLogger.async_log_success_event(
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
                 if isinstance(callback, CustomLogger):  # custom logger class
                     if self.stream == True:
                         if (
@@ -2594,7 +2670,7 @@ def function_setup(
             if inspect.iscoroutinefunction(callback):
                 litellm._async_success_callback.append(callback)
                 removed_async_items.append(index)
-            elif callback == "dynamodb":
+            elif callback == "dynamodb" or callback == "openmeter":
                 # dynamo is an async callback, it's used for the proxy and needs to be async
                 # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
                 litellm._async_success_callback.append(callback)
@@ -6777,11 +6853,11 @@ def validate_environment(model: Optional[str] = None) -> dict:
 def set_callbacks(callback_list, function_id=None):
-    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger
+    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
     try:
         for callback in callback_list:
-            print_verbose(f"callback: {callback}")
+            print_verbose(f"init callback list: {callback}")
             if callback == "sentry":
                 try:
                     import sentry_sdk
@@ -6844,6 +6920,8 @@ def set_callbacks(callback_list, function_id=None):
                 promptLayerLogger = PromptLayerLogger()
             elif callback == "langfuse":
                 langFuseLogger = LangFuseLogger()
+            elif callback == "openmeter":
+                openMeterLogger = OpenMeterLogger()
             elif callback == "datadog":
                 dataDogLogger = DataDogLogger()
             elif callback == "prometheus":