Merge pull request #4924 from BerriAI/litellm_log_writing_spend_to_db_otel

[Feat] - log writing BatchSpendUpdate events on OTEL
Ishaan Jaff 2024-07-27 16:07:56 -07:00 committed by GitHub
commit 2e9fb5ca1f
9 changed files with 87 additions and 32 deletions

View file

@@ -56,6 +56,7 @@ class ServiceLogging(CustomLogger):
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
"""
- For counting if the redis, postgres call is successful
@@ -84,6 +85,7 @@ class ServiceLogging(CustomLogger):
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
async def async_service_failure_hook(
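
The event_metadata parameter added here is forwarded straight to the OTEL service logger (next file). A minimal usage sketch, assuming ServiceLogging is importable from litellm._service_logger (a module path not shown in this diff) and that service, call_type, and duration are the hook's remaining required arguments:

```python
# Hedged usage sketch: attach event_metadata to a service-success event so
# it can be emitted as OTEL span attributes.
# Assumption: ServiceLogging lives in litellm._service_logger (not shown here).
import asyncio

from litellm._service_logger import ServiceLogging  # assumed module path
from litellm.types.services import ServiceTypes

async def emit_batch_write_event() -> None:
    service_logger = ServiceLogging()
    await service_logger.async_service_success_hook(
        service=ServiceTypes.BATCH_WRITE_TO_DB,
        call_type="update_spend",
        duration=0.0,
        event_metadata={"user_api_key": "hashed-key", "user_api_key_team_id": "team-1"},
    )

asyncio.run(emit_batch_write_event())
```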

View file

@@ -21,6 +21,7 @@ from openai._models import BaseModel as OpenAIObject
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@@ -33,16 +34,6 @@ def print_verbose(print_statement):
pass
def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None):
try:
if kwargs is None:
return None
_metadata = kwargs.get("metadata") or {}
return _metadata.get("litellm_parent_otel_span")
except:
return None
class BaseCache:
def set_cache(self, key, value, **kwargs):
raise NotImplementedError

View file

@@ -119,6 +119,7 @@ class OpenTelemetry(CustomLogger):
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
from datetime import datetime
@@ -149,6 +150,10 @@ class OpenTelemetry(CustomLogger):
service_logging_span.set_attribute(
key="service", value=payload.service.value
)
if event_metadata:
for key, value in event_metadata.items():
service_logging_span.set_attribute(key, value)
service_logging_span.set_status(Status(StatusCode.OK))
service_logging_span.end(end_time=_end_time_ns)
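
The loop added above copies every event_metadata entry onto the service-logging span. A standalone sketch of the same attribute-copy pattern using the OpenTelemetry SDK directly (the tracer setup below is illustrative, not litellm's); OTEL attribute values must be primitives, so anything else should be stringified first:

```python
# Standalone illustration of copying a metadata dict onto an OTEL span.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)

event_metadata = {"user_api_key": "hashed-key", "user_api_key_team_id": "team-1"}

with tracer.start_as_current_span("batch_write_to_db") as span:
    span.set_attribute("service", "batch_write_to_db")
    for key, value in event_metadata.items():
        # OTEL attributes must be str/bool/int/float (or sequences of those).
        span.set_attribute(key, value if isinstance(value, (str, bool, int, float)) else str(value))
```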

View file

@@ -1,5 +1,6 @@
# What is this?
## Helper utilities
from typing import List, Literal, Optional, Tuple
def map_finish_reason(
@@ -54,3 +55,31 @@ def remove_index_from_tool_calls(messages, tool_calls):
tool_call.pop("index")
return
def get_litellm_metadata_from_kwargs(kwargs: dict):
"""
Helper to get litellm metadata from all litellm request kwargs
"""
return kwargs.get("litellm_params", {}).get("metadata", {})
# Helper functions used for OTEL logging
def _get_parent_otel_span_from_kwargs(kwargs: Optional[dict] = None):
try:
if kwargs is None:
return None
litellm_params = kwargs.get("litellm_params")
_metadata = kwargs.get("metadata") or {}
if "litellm_parent_otel_span" in _metadata:
return _metadata["litellm_parent_otel_span"]
elif (
litellm_params is not None
and litellm_params.get("metadata") is not None
and "litellm_parent_otel_span" in litellm_params.get("metadata", {})
):
return litellm_params["metadata"]["litellm_parent_otel_span"]
elif "litellm_parent_otel_span" in kwargs:
return kwargs["litellm_parent_otel_span"]
except:
return None
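
A usage sketch for the two helpers, using the litellm.litellm_core_utils.core_helpers import path shown elsewhere in this commit; the kwargs shape mirrors what litellm passes to callbacks, with metadata nested under litellm_params (placeholder values):

```python
# Usage sketch for the new helpers; import path matches this commit.
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)

# Shape of the kwargs litellm hands to callbacks: metadata nested under
# litellm_params (values here are placeholders).
kwargs = {
    "litellm_params": {
        "metadata": {
            "user_api_key": "hashed-key",
            "litellm_parent_otel_span": None,  # a real OTEL Span at runtime
        }
    }
}

metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
parent_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
print(metadata)     # the nested metadata dict
print(parent_span)  # None here; the span object when one was attached upstream
```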

View file

@@ -27,12 +27,6 @@ model_list:
mode: audio_speech
general_settings:
master_key: sk-1234
alerting: ["slack"]
alerting_threshold: 0.0001
alert_to_webhook_url: {
"llm_too_slow": "https://hooks.slack.com/services/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",
"llm_requests_hanging": "https://hooks.slack.com/services/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",
}
litellm_settings:
success_callback: ["langfuse"]
callbacks: ["otel"]

View file

@@ -108,6 +108,7 @@ from litellm._logging import verbose_proxy_logger, verbose_router_logger
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.slack_alerting import SlackAlerting, SlackAlertingArgs
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
@@ -203,6 +204,7 @@ from litellm.proxy.utils import (
get_error_message_str,
get_instance_fn,
hash_token,
log_to_opentelemetry,
reset_budget,
send_email,
update_spend,
@@ -649,6 +651,7 @@ async def _PROXY_failure_handler(
pass
@log_to_opentelemetry
async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
@@ -670,18 +673,15 @@ async def _PROXY_track_cost_callback(
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
team_id = kwargs["litellm_params"]["metadata"].get("user_api_key_team_id", None)
org_id = kwargs["litellm_params"]["metadata"].get("user_api_key_org_id", None)
key_alias = kwargs["litellm_params"]["metadata"].get("user_api_key_alias", None)
end_user_max_budget = kwargs["litellm_params"]["metadata"].get(
"user_api_end_user_max_budget", None
)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = metadata.get("user_api_key_user_id", None)
team_id = metadata.get("user_api_key_team_id", None)
org_id = metadata.get("user_api_key_org_id", None)
key_alias = metadata.get("user_api_key_alias", None)
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = kwargs["litellm_params"]["metadata"].get(
"user_api_key", None
)
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) == True:
response_cost = 0.0
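
Switching from chained indexing into kwargs["litellm_params"]["metadata"] to get_litellm_metadata_from_kwargs also changes failure behavior when metadata is absent, sketched below (illustrative, not part of the diff):

```python
# Illustration: the old chained lookup raises when metadata is missing,
# while the new helper degrades to an empty dict.
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs

kwargs_without_metadata = {"litellm_params": {}}

try:
    # old pattern: kwargs["litellm_params"]["metadata"].get("user_api_key_user_id", None)
    kwargs_without_metadata["litellm_params"]["metadata"].get("user_api_key_user_id", None)
except KeyError as err:
    print("old pattern raised KeyError:", err)

user_id = get_litellm_metadata_from_kwargs(kwargs=kwargs_without_metadata).get(
    "user_api_key_user_id", None
)
print("new pattern returned:", user_id)  # None
```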

View file

@@ -32,6 +32,10 @@ from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.slack_alerting import SlackAlerting
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy._types import (
@@ -125,6 +129,29 @@ def log_to_opentelemetry(func):
start_time=start_time,
end_time=end_time,
)
elif (
# in litellm custom callbacks kwargs is passed as arg[0]
# https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
args is not None
and len(args) > 0
):
passed_kwargs = args[0]
parent_otel_span = _get_parent_otel_span_from_kwargs(
kwargs=passed_kwargs
)
if parent_otel_span is not None:
from litellm.proxy.proxy_server import proxy_logging_obj
metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
await proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.BATCH_WRITE_TO_DB,
call_type=func.__name__,
parent_otel_span=parent_otel_span,
duration=0.0,
start_time=start_time,
end_time=end_time,
event_metadata=metadata,
)
# end of logging to otel
return result
except Exception as e:
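
The new branch covers callbacks that receive the completion kwargs positionally as args[0] rather than as keyword arguments. A simplified, self-contained sketch of that dual lookup (an illustration of the pattern, not the actual log_to_opentelemetry implementation):

```python
# Simplified sketch: run the wrapped coroutine, then look for a parent OTEL
# span either in the keyword arguments or, for litellm custom callbacks, in
# the completion kwargs passed positionally as args[0].
import functools
from typing import Any, Callable, Optional

def _parent_span_from(candidate: Any) -> Optional[Any]:
    # stand-in for _get_parent_otel_span_from_kwargs
    if not isinstance(candidate, dict):
        return None
    metadata = (candidate.get("litellm_params") or {}).get("metadata") or {}
    return metadata.get("litellm_parent_otel_span")

def log_span_if_present(func: Callable) -> Callable:
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        parent_span = kwargs.get("parent_otel_span")
        if parent_span is None and args:
            parent_span = _parent_span_from(args[0])
        if parent_span is not None:
            print(f"would emit a batch_write_to_db span for {func.__name__}")
        return result
    return wrapper
```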

View file

@@ -1,7 +1,9 @@
import uuid, enum
from pydantic import BaseModel, Field
import enum
import uuid
from typing import Optional
from pydantic import BaseModel, Field
class ServiceTypes(str, enum.Enum):
"""
@@ -10,6 +12,7 @@ class ServiceTypes(str, enum.Enum):
REDIS = "redis"
DB = "postgres"
BATCH_WRITE_TO_DB = "batch_write_to_db"
LITELLM = "self"
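
The new member's value, batch_write_to_db, is the string the integration test below expects to see among the parent-trace span names. A quick sanity check using the import path shown in this commit:

```python
from litellm.types.services import ServiceTypes

assert ServiceTypes.BATCH_WRITE_TO_DB.value == "batch_write_to_db"
print(ServiceTypes.BATCH_WRITE_TO_DB)
```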

View file

@@ -5,6 +5,7 @@ import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union
import uuid
async def generate_key(
@@ -46,7 +47,7 @@ async def chat_completion(session, key, model: Union[str, List] = "gpt-4"):
data = {
"model": model,
"messages": [
{"role": "user", "content": "Hello!"},
{"role": "user", "content": f"Hello! {str(uuid.uuid4())}"},
],
}
@@ -96,6 +97,8 @@ async def test_chat_completion_check_otel_spans():
key = key_gen["key"]
await chat_completion(session=session, key=key, model="fake-openai-endpoint")
await asyncio.sleep(3)
otel_spans = await get_otel_spans(session=session, key=key)
print("otel_spans: ", otel_spans)
@@ -107,11 +110,12 @@ async def test_chat_completion_check_otel_spans():
print("Parent trace spans: ", parent_trace_spans)
# either 4 or 5 traces depending on how many redis calls were made
assert len(parent_trace_spans) == 5 or len(parent_trace_spans) == 4
# either 5 or 6 traces depending on how many redis calls were made
assert len(parent_trace_spans) == 6 or len(parent_trace_spans) == 5
# 'postgres', 'redis', 'raw_gen_ai_request', 'litellm_request', 'Received Proxy Server Request' in the span
assert "postgres" in parent_trace_spans
assert "redis" in parent_trace_spans
assert "raw_gen_ai_request" in parent_trace_spans
assert "litellm_request" in parent_trace_spans
assert "batch_write_to_db" in parent_trace_spans