From 2a9651b3ca3f4e65aee5d27e418cd04eb8a7454a Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 1 May 2024 17:23:48 -0700
Subject: [PATCH] feat(openmeter.py): add support for user billing

OpenMeter supports user-based billing.

Closes https://github.com/BerriAI/litellm/issues/1268

---
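Usage sketch (illustrative, not part of the diff below): with this change applied, the integration should be enableable directly from Python. The endpoint and key values are placeholders; the env var names are the ones validate_environment() checks, and the `user` param is what _common_logic() maps to the CloudEvents `subject`.

    import os
    import litellm

    # Placeholder values; validate_environment() only checks that these exist.
    os.environ["OPENMETER_API_ENDPOINT"] = "https://openmeter.example.com"
    os.environ["OPENMETER_API_KEY"] = "om_placeholder_key"

    # Register the callback by name, as the proxy config below does via
    # litellm_settings.success_callback.
    litellm.success_callback = ["openmeter"]

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello!"}],
        user="user-123",  # becomes the event's 'subject', i.e. who gets billed
    )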
litellm.get_secret("OPENMETER_API_ENDPOINT") + if _url.endswith("/"): + _url += "api/v1/events" + else: + _url += "/api/v1/events" + + api_key = litellm.get_secret("OPENMETER_API_KEY") + + _data = self._common_logic(kwargs=kwargs, response_obj=response_obj) + self.sync_http_handler.post( + url=_url, + data=_data, + headers={ + "Content-Type": "application/cloudevents+json", + "Authorization": "Bearer {}".format(api_key), + }, + ) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + _url = litellm.get_secret("OPENMETER_API_ENDPOINT") + if _url.endswith("/"): + _url += "api/v1/events" + else: + _url += "/api/v1/events" + + api_key = litellm.get_secret("OPENMETER_API_KEY") + + _data = self._common_logic(kwargs=kwargs, response_obj=response_obj) + _headers = { + "Content-Type": "application/cloudevents+json", + "Authorization": "Bearer {}".format(api_key), + } + + try: + response = await self.async_http_handler.post( + url=_url, + data=json.dumps(_data), + headers=_headers, + ) + + response.raise_for_status() + except Exception as e: + print(f"\nAn Exception Occurred - {str(e)}") + if hasattr(response, "text"): + print(f"\nError Message: {response.text}") + raise e diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index 9db128d0e..9f2f6ec17 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -1,19 +1,15 @@ model_list: - litellm_params: - api_base: http://0.0.0.0:8080 + api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ api_key: my-fake-key model: openai/my-fake-model - rpm: 100 - model_name: fake-openai-endpoint -- litellm_params: - api_base: http://0.0.0.0:8081 - api_key: my-fake-key - model: openai/my-fake-model-2 - rpm: 100 model_name: fake-openai-endpoint router_settings: num_retries: 0 enable_pre_call_checks: true redis_host: os.environ/REDIS_HOST redis_password: os.environ/REDIS_PASSWORD - redis_port: os.environ/REDIS_PORT \ No newline at end of file + redis_port: os.environ/REDIS_PORT + +litellm_settings: + success_callback: ["openmeter"] \ No newline at end of file diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 97f679fd7..b5db81b31 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1777,7 +1777,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time): usage = response_obj["usage"] if type(usage) == litellm.Usage: usage = dict(usage) - id = response_obj.get("id", str(uuid.uuid4())) + id = response_obj.get("id", kwargs.get("litellm_call_id")) api_key = metadata.get("user_api_key", "") if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"): # hash the api_key diff --git a/litellm/utils.py b/litellm/utils.py index 6243195ef..81e6cc660 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -70,6 +70,7 @@ from .integrations.langsmith import LangsmithLogger from .integrations.weights_biases import WeightsBiasesLogger from .integrations.custom_logger import CustomLogger from .integrations.langfuse import LangFuseLogger +from .integrations.openmeter import OpenMeterLogger from .integrations.datadog import DataDogLogger from .integrations.prometheus import PrometheusLogger from .integrations.prometheus_services import PrometheusServicesLogger @@ -130,6 +131,7 @@ langsmithLogger = None weightsBiasesLogger = None customLogger = None langFuseLogger = None +openMeterLogger = None dataDogLogger = None prometheusLogger = None dynamoLogger = None @@ -1922,6 +1924,51 @@ 
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 9db128d0e..9f2f6ec17 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -1,19 +1,15 @@
 model_list:
 - litellm_params:
-    api_base: http://0.0.0.0:8080
+    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
     api_key: my-fake-key
     model: openai/my-fake-model
-    rpm: 100
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: http://0.0.0.0:8081
-    api_key: my-fake-key
-    model: openai/my-fake-model-2
-    rpm: 100
   model_name: fake-openai-endpoint
 
 router_settings:
   num_retries: 0
   enable_pre_call_checks: true
   redis_host: os.environ/REDIS_HOST
   redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
\ No newline at end of file
+  redis_port: os.environ/REDIS_PORT
+
+litellm_settings:
+  success_callback: ["openmeter"]
\ No newline at end of file
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 97f679fd7..b5db81b31 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1777,7 +1777,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
     usage = response_obj["usage"]
     if type(usage) == litellm.Usage:
         usage = dict(usage)
-    id = response_obj.get("id", str(uuid.uuid4()))
+    id = response_obj.get("id", kwargs.get("litellm_call_id"))
     api_key = metadata.get("user_api_key", "")
     if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
         # hash the api_key
diff --git a/litellm/utils.py b/litellm/utils.py
index 6243195ef..81e6cc660 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -70,6 +70,7 @@ from .integrations.langsmith import LangsmithLogger
 from .integrations.weights_biases import WeightsBiasesLogger
 from .integrations.custom_logger import CustomLogger
 from .integrations.langfuse import LangFuseLogger
+from .integrations.openmeter import OpenMeterLogger
 from .integrations.datadog import DataDogLogger
 from .integrations.prometheus import PrometheusLogger
 from .integrations.prometheus_services import PrometheusServicesLogger
@@ -130,6 +131,7 @@ langsmithLogger = None
 weightsBiasesLogger = None
 customLogger = None
 langFuseLogger = None
+openMeterLogger = None
 dataDogLogger = None
 prometheusLogger = None
 dynamoLogger = None
@@ -1922,6 +1924,51 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if (
+                    callback == "openmeter"
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "acompletion", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "aembedding", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "aimage_generation", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "atranscription", False
+                    )
+                    == False
+                ):
+                    global openMeterLogger
+                    if openMeterLogger is None:
+                        print_verbose("Instantiates openmeter client")
+                        openMeterLogger = OpenMeterLogger()
+                    if self.stream and complete_streaming_response is None:
+                        openMeterLogger.log_stream_event(
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
+                    else:
+                        if self.stream and complete_streaming_response:
+                            self.model_call_details["complete_response"] = (
+                                self.model_call_details.get(
+                                    "complete_streaming_response", {}
+                                )
+                            )
+                            result = self.model_call_details["complete_response"]
+                        openMeterLogger.log_success_event(
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
+
                 if (
                     isinstance(callback, CustomLogger)
                     and self.model_call_details.get("litellm_params", {}).get(
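A note on the gate above: the four `== False` comparisons ensure the sync handler only fires for synchronous calls, leaving async calls to the async success handler so the same request is not billed twice. A more compact, behavior-preserving formulation (a sketch only, not what this patch ships):

    # Sketch: the same gate expressed as a helper. 'litellm_params' is the
    # dict the handler reads from self.model_call_details.
    ASYNC_CALL_FLAGS = ("acompletion", "aembedding", "aimage_generation", "atranscription")

    def is_async_call(litellm_params: dict) -> bool:
        """True if the call entered through one of the async entrypoints."""
        return any(litellm_params.get(flag, False) for flag in ASYNC_CALL_FLAGS)

    # The condition above would then read:
    #   if callback == "openmeter" and not is_async_call(
    #       self.model_call_details.get("litellm_params", {})
    #   ):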
@@ -2121,6 +2168,35 @@ class Logging:
                     await litellm.cache.async_add_cache(result, **kwargs)
                 else:
                     litellm.cache.add_cache(result, **kwargs)
+            if callback == "openmeter":
+                global openMeterLogger
+                if self.stream == True:
+                    if (
+                        "async_complete_streaming_response"
+                        in self.model_call_details
+                    ):
+                        await openMeterLogger.async_log_success_event(
+                            kwargs=self.model_call_details,
+                            response_obj=self.model_call_details[
+                                "async_complete_streaming_response"
+                            ],
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
+                    else:
+                        await openMeterLogger.async_log_stream_event(  # [TODO]: move this to being an async log stream event function
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                        )
+                else:
+                    await openMeterLogger.async_log_success_event(
+                        kwargs=self.model_call_details,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                    )
             if isinstance(callback, CustomLogger):  # custom logger class
                 if self.stream == True:
                     if (
@@ -2594,7 +2670,7 @@ def function_setup(
                 if inspect.iscoroutinefunction(callback):
                     litellm._async_success_callback.append(callback)
                     removed_async_items.append(index)
-                elif callback == "dynamodb":
+                elif callback == "dynamodb" or callback == "openmeter":
                     # dynamo is an async callback, it's used for the proxy and needs to be async
                     # we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
                     litellm._async_success_callback.append(callback)
@@ -6777,11 +6853,11 @@ def validate_environment(model: Optional[str] = None) -> dict:
 
 
 def set_callbacks(callback_list, function_id=None):
-    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger
+    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
 
     try:
         for callback in callback_list:
-            print_verbose(f"callback: {callback}")
+            print_verbose(f"init callback list: {callback}")
             if callback == "sentry":
                 try:
                     import sentry_sdk
@@ -6844,6 +6920,8 @@ def set_callbacks(callback_list, function_id=None):
                 promptLayerLogger = PromptLayerLogger()
             elif callback == "langfuse":
                 langFuseLogger = LangFuseLogger()
+            elif callback == "openmeter":
+                openMeterLogger = OpenMeterLogger()
             elif callback == "datadog":
                 dataDogLogger = DataDogLogger()
             elif callback == "prometheus":
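End-to-end sketch against the proxy (illustrative; the port and key below are assumptions, not values from this patch): with the _super_secret_config.yaml above loaded, any OpenAI-style request that carries a `user` field should produce one OpenMeter event per completed call.

    import requests

    # Assumes a litellm proxy running locally on its default port with a
    # hypothetical master key; both values are placeholders.
    resp = requests.post(
        "http://0.0.0.0:4000/chat/completions",
        headers={"Authorization": "Bearer sk-1234"},
        json={
            "model": "fake-openai-endpoint",  # from the config above
            "messages": [{"role": "user", "content": "hi"}],
            "user": "user-123",  # forwarded into the event's 'subject'
        },
    )
    print(resp.status_code, resp.json())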