Merge pull request #4924 from BerriAI/litellm_log_writing_spend_to_db_otel

[Feat] - log writing BatchSpendUpdate events on OTEL
Ishaan Jaff 2024-07-27 16:07:56 -07:00 committed by GitHub
commit 2e9fb5ca1f
9 changed files with 87 additions and 32 deletions
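
For context, a minimal self-contained sketch of the pattern this PR wires up (hypothetical names, not the PR's own code; assumes opentelemetry-sdk is installed): when a request carries a parent OTEL span, the proxy emits a child "batch_write_to_db" service span under it, tagged with request metadata.

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
from opentelemetry.trace import Status, StatusCode

# console exporter so the emitted spans are visible when run locally
trace.set_tracer_provider(TracerProvider())
trace.get_tracer_provider().add_span_processor(
    SimpleSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer("litellm.sketch")  # hypothetical instrumentation name

with tracer.start_as_current_span("litellm_request") as parent:
    # ... request handling and the batched spend write happen here ...
    ctx = trace.set_span_in_context(parent)  # parent the service span explicitly
    child = tracer.start_span("batch_write_to_db", context=ctx)
    child.set_attribute("user_api_key_user_id", "u-123")  # illustrative metadata
    child.set_status(Status(StatusCode.OK))
    child.end()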

View file

@@ -56,6 +56,7 @@ class ServiceLogging(CustomLogger):
         parent_otel_span: Optional[Span] = None,
         start_time: Optional[Union[datetime, float]] = None,
         end_time: Optional[Union[datetime, float]] = None,
+        event_metadata: Optional[dict] = None,
     ):
         """
         - For counting if the redis, postgres call is successful
@@ -84,6 +85,7 @@ class ServiceLogging(CustomLogger):
             parent_otel_span=parent_otel_span,
             start_time=start_time,
             end_time=end_time,
+            event_metadata=event_metadata,
         )

     async def async_service_failure_hook(

View file

@@ -21,6 +21,7 @@ from openai._models import BaseModel as OpenAIObject
 import litellm
 from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@@ -33,16 +34,6 @@ def print_verbose(print_statement):
     pass


-def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None):
-    try:
-        if kwargs is None:
-            return None
-        _metadata = kwargs.get("metadata") or {}
-        return _metadata.get("litellm_parent_otel_span")
-    except:
-        return None
-
-
 class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError

View file

@@ -119,6 +119,7 @@ class OpenTelemetry(CustomLogger):
         parent_otel_span: Optional[Span] = None,
         start_time: Optional[Union[datetime, float]] = None,
         end_time: Optional[Union[datetime, float]] = None,
+        event_metadata: Optional[dict] = None,
     ):
         from datetime import datetime
@@ -149,6 +150,10 @@ class OpenTelemetry(CustomLogger):
             service_logging_span.set_attribute(
                 key="service", value=payload.service.value
             )
+
+            if event_metadata:
+                for key, value in event_metadata.items():
+                    service_logging_span.set_attribute(key, value)
             service_logging_span.set_status(Status(StatusCode.OK))
             service_logging_span.end(end_time=_end_time_ns)
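
The integration accepts start_time/end_time as datetime or float, while Span.end() takes epoch nanoseconds; below is a hedged sketch of the normalization implied by _end_time_ns (the helper name and the seconds-since-epoch assumption for floats are mine, not the integration's code).

from datetime import datetime
from typing import Union


def _to_ns(value: Union[datetime, float]) -> int:
    # datetimes carry their own epoch conversion; floats are assumed to be
    # seconds since the epoch (e.g. time.time())
    if isinstance(value, datetime):
        return int(value.timestamp() * 1e9)
    return int(value * 1e9)


print(_to_ns(datetime(2024, 7, 27, 16, 7, 56)))  # -> nanoseconds since epoch
print(_to_ns(1722121676.0))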

View file

@@ -1,5 +1,6 @@
 # What is this?
 ## Helper utilities
+from typing import List, Literal, Optional, Tuple


 def map_finish_reason(
@@ -54,3 +55,31 @@ def remove_index_from_tool_calls(messages, tool_calls):
                     tool_call.pop("index")
     return
+
+
+def get_litellm_metadata_from_kwargs(kwargs: dict):
+    """
+    Helper to get litellm metadata from all litellm request kwargs
+    """
+    return kwargs.get("litellm_params", {}).get("metadata", {})
+
+
+# Helper functions used for OTEL logging
+def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None):
+    try:
+        if kwargs is None:
+            return None
+        litellm_params = kwargs.get("litellm_params")
+        _metadata = kwargs.get("metadata") or {}
+        if "litellm_parent_otel_span" in _metadata:
+            return _metadata["litellm_parent_otel_span"]
+        elif (
+            litellm_params is not None
+            and litellm_params.get("metadata") is not None
+            and "litellm_parent_otel_span" in litellm_params.get("metadata", {})
+        ):
+            return litellm_params["metadata"]["litellm_parent_otel_span"]
+        elif "litellm_parent_otel_span" in kwargs:
+            return kwargs["litellm_parent_otel_span"]
+    except:
+        return None
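
Assuming the rewritten helper above is importable, a quick illustration of the lookup order it implements: top-level metadata wins, then litellm_params["metadata"], then the bare kwargs key (the span object here is just a stand-in).

from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs

span = object()  # stand-in for an OTEL Span

assert (
    _get_parent_otel_span_from_kwargs({"metadata": {"litellm_parent_otel_span": span}})
    is span
)
assert (
    _get_parent_otel_span_from_kwargs(
        {"litellm_params": {"metadata": {"litellm_parent_otel_span": span}}}
    )
    is span
)
assert _get_parent_otel_span_from_kwargs({"litellm_parent_otel_span": span}) is span
assert _get_parent_otel_span_from_kwargs(None) is None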

View file

@@ -27,12 +27,6 @@ model_list:
       mode: audio_speech

 general_settings:
   master_key: sk-1234
-  alerting: ["slack"]
-  alerting_threshold: 0.0001
-  alert_to_webhook_url: {
-    "llm_too_slow": "https://hooks.slack.com/services/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",
-    "llm_requests_hanging": "https://hooks.slack.com/services/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",
-  }

 litellm_settings:
-  success_callback: ["langfuse"]
+  callbacks: ["otel"]

View file

@@ -108,6 +108,7 @@ from litellm._logging import verbose_proxy_logger, verbose_router_logger
 from litellm.caching import DualCache, RedisCache
 from litellm.exceptions import RejectedRequestError
 from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
+from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
@@ -203,6 +204,7 @@ from litellm.proxy.utils import (
     get_error_message_str,
     get_instance_fn,
     hash_token,
+    log_to_opentelemetry,
     reset_budget,
     send_email,
     update_spend,
@@ -649,6 +651,7 @@ async def _PROXY_failure_handler(
     pass


+@log_to_opentelemetry
 async def _PROXY_track_cost_callback(
     kwargs,  # kwargs to completion
     completion_response: litellm.ModelResponse,  # response from completion
@@ -670,18 +673,15 @@ async def _PROXY_track_cost_callback(
         litellm_params = kwargs.get("litellm_params", {}) or {}
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
-        team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
-        org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
-        key_alias = kwargs["litellm_params"]["metadata"].get("user_api_key_alias", None)
-        end_user_max_budget = kwargs["litellm_params"]["metadata"].get(
-            "user_api_end_user_max_budget", None
-        )
+        metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
+        user_id = metadata.get("user_api_key_user_id", None)
+        team_id = metadata.get("user_api_key_team_id", None)
+        org_id = metadata.get("user_api_key_org_id", None)
+        key_alias = metadata.get("user_api_key_alias", None)
+        end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
         if kwargs.get("response_cost", None) is not None:
             response_cost = kwargs["response_cost"]
-            user_api_key = kwargs["litellm_params"]["metadata"].get(
-                "user_api_key", None
-            )
+            user_api_key = metadata.get("user_api_key", None)
             if kwargs.get("cache_hit", False) == True:
                 response_cost = 0.0
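
The switch to get_litellm_metadata_from_kwargs also changes failure behavior: the old kwargs["litellm_params"]["metadata"] chain raises KeyError when either key is absent, while the helper degrades to an empty dict. A small illustration with made-up kwargs:

from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs

kwargs = {
    "litellm_params": {
        "metadata": {"user_api_key_user_id": "u-123", "user_api_key_team_id": "t-456"}
    }
}
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
assert metadata.get("user_api_key_user_id") == "u-123"
assert metadata.get("user_api_key_org_id") is None  # missing key -> None, not KeyError
assert get_litellm_metadata_from_kwargs(kwargs={}) == {}  # tolerates bare kwargs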

View file

@@ -32,6 +32,10 @@ from litellm.caching import DualCache, RedisCache
 from litellm.exceptions import RejectedRequestError
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.slack_alerting import SlackAlerting
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
 from litellm.litellm_core_utils.litellm_logging import Logging
 from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
 from litellm.proxy._types import (
@@ -125,6 +129,29 @@ def log_to_opentelemetry(func):
                     start_time=start_time,
                     end_time=end_time,
                 )
+            elif (
+                # in litellm custom callbacks kwargs is passed as arg[0]
+                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
+                args is not None
+                and len(args) > 0
+            ):
+                passed_kwargs = args[0]
+                parent_otel_span = _get_parent_otel_span_from_kwargs(
+                    kwargs=passed_kwargs
+                )
+                if parent_otel_span is not None:
+                    from litellm.proxy.proxy_server import proxy_logging_obj
+
+                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
+                    await proxy_logging_obj.service_logging_obj.async_service_success_hook(
+                        service=ServiceTypes.BATCH_WRITE_TO_DB,
+                        call_type=func.__name__,
+                        parent_otel_span=parent_otel_span,
+                        duration=0.0,
+                        start_time=start_time,
+                        end_time=end_time,
+                        event_metadata=metadata,
+                    )
             # end of logging to otel
             return result
         except Exception as e:
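
Why the new branch reads args[0]: litellm-style success callbacks receive the completion kwargs as their first positional argument (see the docs URL in the comment above), so a decorator wrapping such a callback must look there rather than in its own **kwargs. A runnable toy version follows (hypothetical names, OTEL plumbing stripped out).

import asyncio
from datetime import datetime


def capture_parent_span(func):
    # toy stand-in for log_to_opentelemetry: run the callback, then inspect
    # the kwargs it was invoked with for a parent span
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        passed_kwargs = args[0] if args else {}
        metadata = passed_kwargs.get("litellm_params", {}).get("metadata", {})
        print("parent span:", metadata.get("litellm_parent_otel_span"))
        return result

    return wrapper


@capture_parent_span
async def track_cost_callback(kwargs, completion_response, start_time, end_time):
    return None  # cost tracking elided


asyncio.run(
    track_cost_callback(
        {"litellm_params": {"metadata": {"litellm_parent_otel_span": "span-obj"}}},
        None,
        datetime.now(),
        datetime.now(),
    )
)
# prints: parent span: span-obj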

View file

@@ -1,7 +1,9 @@
-import uuid, enum
-from pydantic import BaseModel, Field
+import enum
+import uuid
 from typing import Optional

+from pydantic import BaseModel, Field
+

 class ServiceTypes(str, enum.Enum):
     """
@@ -10,6 +12,7 @@ class ServiceTypes(str, enum.Enum):
     REDIS = "redis"
     DB = "postgres"
+    BATCH_WRITE_TO_DB = "batch_write_to_db"
     LITELLM = "self"
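
Because ServiceTypes mixes in str, the new member compares equal to its value, which is exactly the span name asserted in the test file below:

from litellm.types.services import ServiceTypes

assert ServiceTypes.BATCH_WRITE_TO_DB.value == "batch_write_to_db"
assert ServiceTypes.BATCH_WRITE_TO_DB == "batch_write_to_db"  # str-enum comparison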

View file

@@ -5,6 +5,7 @@ import asyncio
 import aiohttp, openai
 from openai import OpenAI, AsyncOpenAI
 from typing import Optional, List, Union
+import uuid


 async def generate_key(
@@ -46,7 +47,7 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
     data = {
         "model": model,
         "messages": [
-            {"role": "user", "content": "Hello!"},
+            {"role": "user", "content": f"Hello! {str(uuid.uuid4())}"},
         ],
     }
@@ -96,6 +97,8 @@ async def test_chat_completion_check_otel_spans():
     key = key_gen["key"]
     await chat_completion(session=session, key=key, model="fake-openai-endpoint")

+    await asyncio.sleep(3)
+
     otel_spans = await get_otel_spans(session=session, key=key)
     print("otel_spans: ", otel_spans)
@@ -107,11 +110,12 @@ async def test_chat_completion_check_otel_spans():
     print("Parent trace spans: ", parent_trace_spans)

-    # either 4 or 5 traces depending on how many redis calls were made
-    assert len(parent_trace_spans) == 5 or len(parent_trace_spans) == 4
+    # either 5 or 6 traces depending on how many redis calls were made
+    assert len(parent_trace_spans) == 6 or len(parent_trace_spans) == 5

     # 'postgres', 'redis', 'raw_gen_ai_request', 'litellm_request', 'Received Proxy Server Request' in the span
     assert "postgres" in parent_trace_spans
     assert "redis" in parent_trace_spans
     assert "raw_gen_ai_request" in parent_trace_spans
     assert "litellm_request" in parent_trace_spans
+    assert "batch_write_to_db" in parent_trace_spans