mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(openmeter.py): add support for user billing
open-meter supports user based billing. Closes https://github.com/BerriAI/litellm/issues/1268
This commit is contained in:
parent
37eb7910d2
commit
2a9651b3ca
5 changed files with 210 additions and 13 deletions
|
@ -22,6 +22,7 @@ success_callback: List[Union[str, Callable]] = []
|
||||||
failure_callback: List[Union[str, Callable]] = []
|
failure_callback: List[Union[str, Callable]] = []
|
||||||
service_callback: List[Union[str, Callable]] = []
|
service_callback: List[Union[str, Callable]] = []
|
||||||
callbacks: List[Callable] = []
|
callbacks: List[Callable] = []
|
||||||
|
_custom_logger_compatible_callbacks: list = ["openmeter"]
|
||||||
_langfuse_default_tags: Optional[
|
_langfuse_default_tags: Optional[
|
||||||
List[
|
List[
|
||||||
Literal[
|
Literal[
|
||||||
|
|
122
litellm/integrations/openmeter.py
Normal file
122
litellm/integrations/openmeter.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
# What is this?
|
||||||
|
## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
|
||||||
|
|
||||||
|
import dotenv, os, json
|
||||||
|
import requests
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
|
import traceback
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
def get_utc_datetime():
|
||||||
|
import datetime as dt
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
if hasattr(dt, "UTC"):
|
||||||
|
return datetime.now(dt.UTC) # type: ignore
|
||||||
|
else:
|
||||||
|
return datetime.utcnow() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class OpenMeterLogger(CustomLogger):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.validate_environment()
|
||||||
|
self.async_http_handler = AsyncHTTPHandler()
|
||||||
|
self.sync_http_handler = HTTPHandler()
|
||||||
|
|
||||||
|
def validate_environment(self):
|
||||||
|
"""
|
||||||
|
Expects
|
||||||
|
OPENMETER_API_ENDPOINT,
|
||||||
|
OPENMETER_API_KEY,
|
||||||
|
|
||||||
|
in the environment
|
||||||
|
"""
|
||||||
|
missing_keys = []
|
||||||
|
if litellm.get_secret("OPENMETER_API_ENDPOINT", None) is None:
|
||||||
|
missing_keys.append("OPENMETER_API_ENDPOINT")
|
||||||
|
|
||||||
|
if litellm.get_secret("OPENMETER_API_KEY", None) is None:
|
||||||
|
missing_keys.append("OPENMETER_API_KEY")
|
||||||
|
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||||
|
|
||||||
|
def _common_logic(self, kwargs: dict, response_obj):
|
||||||
|
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||||
|
dt = get_utc_datetime().isoformat()
|
||||||
|
cost = kwargs.get("response_cost", None)
|
||||||
|
model = kwargs.get("model")
|
||||||
|
usage = {}
|
||||||
|
if (
|
||||||
|
isinstance(response_obj, litellm.ModelResponse)
|
||||||
|
or isinstance(response_obj, litellm.EmbeddingResponse)
|
||||||
|
) and hasattr(response_obj, "usage"):
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": response_obj["usage"].get("prompt_tokens", 0),
|
||||||
|
"completion_tokens": response_obj["usage"].get("completion_tokens", 0),
|
||||||
|
"total_tokens": response_obj["usage"].get("total_tokens"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"specversion": "1.0",
|
||||||
|
"type": os.getenv("OPENMETER_EVENT_TYPE", "litellm_tokens"),
|
||||||
|
"id": call_id,
|
||||||
|
"time": dt,
|
||||||
|
"subject": kwargs.get("user", ""), # end-user passed in via 'user' param
|
||||||
|
"source": "litellm-proxy",
|
||||||
|
"data": {"model": model, "cost": cost, **usage},
|
||||||
|
}
|
||||||
|
|
||||||
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
_url = litellm.get_secret("OPENMETER_API_ENDPOINT")
|
||||||
|
if _url.endswith("/"):
|
||||||
|
_url += "api/v1/events"
|
||||||
|
else:
|
||||||
|
_url += "/api/v1/events"
|
||||||
|
|
||||||
|
api_key = litellm.get_secret("OPENMETER_API_KEY")
|
||||||
|
|
||||||
|
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||||
|
self.sync_http_handler.post(
|
||||||
|
url=_url,
|
||||||
|
data=_data,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/cloudevents+json",
|
||||||
|
"Authorization": "Bearer {}".format(api_key),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
_url = litellm.get_secret("OPENMETER_API_ENDPOINT")
|
||||||
|
if _url.endswith("/"):
|
||||||
|
_url += "api/v1/events"
|
||||||
|
else:
|
||||||
|
_url += "/api/v1/events"
|
||||||
|
|
||||||
|
api_key = litellm.get_secret("OPENMETER_API_KEY")
|
||||||
|
|
||||||
|
_data = self._common_logic(kwargs=kwargs, response_obj=response_obj)
|
||||||
|
_headers = {
|
||||||
|
"Content-Type": "application/cloudevents+json",
|
||||||
|
"Authorization": "Bearer {}".format(api_key),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.async_http_handler.post(
|
||||||
|
url=_url,
|
||||||
|
data=json.dumps(_data),
|
||||||
|
headers=_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nAn Exception Occurred - {str(e)}")
|
||||||
|
if hasattr(response, "text"):
|
||||||
|
print(f"\nError Message: {response.text}")
|
||||||
|
raise e
|
|
@ -1,15 +1,8 @@
|
||||||
model_list:
|
model_list:
|
||||||
- litellm_params:
|
- litellm_params:
|
||||||
api_base: http://0.0.0.0:8080
|
api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
|
||||||
api_key: my-fake-key
|
api_key: my-fake-key
|
||||||
model: openai/my-fake-model
|
model: openai/my-fake-model
|
||||||
rpm: 100
|
|
||||||
model_name: fake-openai-endpoint
|
|
||||||
- litellm_params:
|
|
||||||
api_base: http://0.0.0.0:8081
|
|
||||||
api_key: my-fake-key
|
|
||||||
model: openai/my-fake-model-2
|
|
||||||
rpm: 100
|
|
||||||
model_name: fake-openai-endpoint
|
model_name: fake-openai-endpoint
|
||||||
router_settings:
|
router_settings:
|
||||||
num_retries: 0
|
num_retries: 0
|
||||||
|
@ -17,3 +10,6 @@ router_settings:
|
||||||
redis_host: os.environ/REDIS_HOST
|
redis_host: os.environ/REDIS_HOST
|
||||||
redis_password: os.environ/REDIS_PASSWORD
|
redis_password: os.environ/REDIS_PASSWORD
|
||||||
redis_port: os.environ/REDIS_PORT
|
redis_port: os.environ/REDIS_PORT
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
success_callback: ["openmeter"]
|
|
@ -1777,7 +1777,7 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
|
||||||
usage = response_obj["usage"]
|
usage = response_obj["usage"]
|
||||||
if type(usage) == litellm.Usage:
|
if type(usage) == litellm.Usage:
|
||||||
usage = dict(usage)
|
usage = dict(usage)
|
||||||
id = response_obj.get("id", str(uuid.uuid4()))
|
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||||
api_key = metadata.get("user_api_key", "")
|
api_key = metadata.get("user_api_key", "")
|
||||||
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
|
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
|
||||||
# hash the api_key
|
# hash the api_key
|
||||||
|
|
|
@ -70,6 +70,7 @@ from .integrations.langsmith import LangsmithLogger
|
||||||
from .integrations.weights_biases import WeightsBiasesLogger
|
from .integrations.weights_biases import WeightsBiasesLogger
|
||||||
from .integrations.custom_logger import CustomLogger
|
from .integrations.custom_logger import CustomLogger
|
||||||
from .integrations.langfuse import LangFuseLogger
|
from .integrations.langfuse import LangFuseLogger
|
||||||
|
from .integrations.openmeter import OpenMeterLogger
|
||||||
from .integrations.datadog import DataDogLogger
|
from .integrations.datadog import DataDogLogger
|
||||||
from .integrations.prometheus import PrometheusLogger
|
from .integrations.prometheus import PrometheusLogger
|
||||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||||
|
@ -130,6 +131,7 @@ langsmithLogger = None
|
||||||
weightsBiasesLogger = None
|
weightsBiasesLogger = None
|
||||||
customLogger = None
|
customLogger = None
|
||||||
langFuseLogger = None
|
langFuseLogger = None
|
||||||
|
openMeterLogger = None
|
||||||
dataDogLogger = None
|
dataDogLogger = None
|
||||||
prometheusLogger = None
|
prometheusLogger = None
|
||||||
dynamoLogger = None
|
dynamoLogger = None
|
||||||
|
@ -1922,6 +1924,51 @@ class Logging:
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
callback == "openmeter"
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"acompletion", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"aembedding", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"aimage_generation", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"atranscription", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
):
|
||||||
|
global openMeterLogger
|
||||||
|
if openMeterLogger is None:
|
||||||
|
print_verbose("Instantiates openmeter client")
|
||||||
|
openMeterLogger = OpenMeterLogger()
|
||||||
|
if self.stream and complete_streaming_response is None:
|
||||||
|
openMeterLogger.log_stream_event(
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
response_obj=result,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.stream and complete_streaming_response:
|
||||||
|
self.model_call_details["complete_response"] = (
|
||||||
|
self.model_call_details.get(
|
||||||
|
"complete_streaming_response", {}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = self.model_call_details["complete_response"]
|
||||||
|
openMeterLogger.log_success_event(
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
response_obj=result,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(callback, CustomLogger)
|
isinstance(callback, CustomLogger)
|
||||||
and self.model_call_details.get("litellm_params", {}).get(
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
@ -2121,6 +2168,35 @@ class Logging:
|
||||||
await litellm.cache.async_add_cache(result, **kwargs)
|
await litellm.cache.async_add_cache(result, **kwargs)
|
||||||
else:
|
else:
|
||||||
litellm.cache.add_cache(result, **kwargs)
|
litellm.cache.add_cache(result, **kwargs)
|
||||||
|
if callback == "openmeter":
|
||||||
|
global openMeterLogger
|
||||||
|
if self.stream == True:
|
||||||
|
if (
|
||||||
|
"async_complete_streaming_response"
|
||||||
|
in self.model_call_details
|
||||||
|
):
|
||||||
|
await openMeterLogger.async_log_success_event(
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
response_obj=self.model_call_details[
|
||||||
|
"async_complete_streaming_response"
|
||||||
|
],
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await openMeterLogger.async_log_stream_event( # [TODO]: move this to being an async log stream event function
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
response_obj=result,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await openMeterLogger.async_log_success_event(
|
||||||
|
kwargs=self.model_call_details,
|
||||||
|
response_obj=result,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
)
|
||||||
if isinstance(callback, CustomLogger): # custom logger class
|
if isinstance(callback, CustomLogger): # custom logger class
|
||||||
if self.stream == True:
|
if self.stream == True:
|
||||||
if (
|
if (
|
||||||
|
@ -2594,7 +2670,7 @@ def function_setup(
|
||||||
if inspect.iscoroutinefunction(callback):
|
if inspect.iscoroutinefunction(callback):
|
||||||
litellm._async_success_callback.append(callback)
|
litellm._async_success_callback.append(callback)
|
||||||
removed_async_items.append(index)
|
removed_async_items.append(index)
|
||||||
elif callback == "dynamodb":
|
elif callback == "dynamodb" or callback == "openmeter":
|
||||||
# dynamo is an async callback, it's used for the proxy and needs to be async
|
# dynamo is an async callback, it's used for the proxy and needs to be async
|
||||||
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
|
# we only support async dynamo db logging for acompletion/aembedding since that's used on proxy
|
||||||
litellm._async_success_callback.append(callback)
|
litellm._async_success_callback.append(callback)
|
||||||
|
@ -6777,11 +6853,11 @@ def validate_environment(model: Optional[str] = None) -> dict:
|
||||||
|
|
||||||
def set_callbacks(callback_list, function_id=None):
|
def set_callbacks(callback_list, function_id=None):
|
||||||
|
|
||||||
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger
|
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for callback in callback_list:
|
for callback in callback_list:
|
||||||
print_verbose(f"callback: {callback}")
|
print_verbose(f"init callback list: {callback}")
|
||||||
if callback == "sentry":
|
if callback == "sentry":
|
||||||
try:
|
try:
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
@ -6844,6 +6920,8 @@ def set_callbacks(callback_list, function_id=None):
|
||||||
promptLayerLogger = PromptLayerLogger()
|
promptLayerLogger = PromptLayerLogger()
|
||||||
elif callback == "langfuse":
|
elif callback == "langfuse":
|
||||||
langFuseLogger = LangFuseLogger()
|
langFuseLogger = LangFuseLogger()
|
||||||
|
elif callback == "openmeter":
|
||||||
|
openMeterLogger = OpenMeterLogger()
|
||||||
elif callback == "datadog":
|
elif callback == "datadog":
|
||||||
dataDogLogger = DataDogLogger()
|
dataDogLogger = DataDogLogger()
|
||||||
elif callback == "prometheus":
|
elif callback == "prometheus":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue