Merge branch 'main' into litellm_dynamic_tpm_limits

This commit is contained in:
Krish Dholakia 2024-06-22 19:14:59 -07:00 committed by GitHub
commit 961e7ac95d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 689 additions and 1186 deletions

View file

@ -10,7 +10,7 @@ import sys
import time
import traceback
import uuid
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Literal, Optional
import litellm
from litellm import (
@ -19,7 +19,8 @@ from litellm import (
turn_off_message_logging,
verbose_logger,
)
from litellm.caching import DualCache, S3Cache
from litellm.caching import InMemoryCache, S3Cache, DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
@ -111,7 +112,25 @@ additional_details: Optional[Dict[str, str]] = {}
local_cache: Optional[Dict[str, str]] = {}
last_fetched_at = None
last_fetched_at_keys = None
####
class ServiceTraceIDCache:
    """In-memory store of the trace id a logging service emitted for a call.

    Entries are keyed by ``"<service_name>:<litellm_call_id>"`` so the same
    call id can hold a separate trace id per service.
    """

    def __init__(self) -> None:
        self.cache = InMemoryCache()

    @staticmethod
    def _build_key(service_name: str, litellm_call_id: str) -> str:
        """Return the cache key for a (service, call id) pair."""
        return "{}:{}".format(service_name, litellm_call_id)

    def get_cache(self, litellm_call_id: str, service_name: str) -> Optional[str]:
        """Return what the underlying cache holds for this service/call pair."""
        return self.cache.get_cache(
            key=self._build_key(service_name, litellm_call_id)
        )

    def set_cache(self, litellm_call_id: str, service_name: str, trace_id: str) -> None:
        """Store ``trace_id`` under the key for this service/call pair."""
        self.cache.set_cache(
            key=self._build_key(service_name, litellm_call_id),
            value=trace_id,
        )
# Module-level instance: the Logging methods in this file store emitted trace ids in it and read them back from it.
in_memory_trace_id_cache = ServiceTraceIDCache()
class Logging:
@ -155,7 +174,7 @@ class Logging:
new_messages.append({"role": "user", "content": m})
messages = new_messages
self.model = model
self.messages = messages
self.messages = copy.deepcopy(messages)
self.stream = stream
self.start_time = start_time # log the call start time
self.call_type = call_type
@ -245,10 +264,17 @@ class Logging:
if headers is None:
headers = {}
data = additional_args.get("complete_input_dict", {})
api_base = additional_args.get("api_base", "")
self.model_call_details["litellm_params"]["api_base"] = str(
api_base
) # used for alerting
api_base = str(additional_args.get("api_base", ""))
if "key=" in api_base:
# Find the position of "key=" in the string
key_index = api_base.find("key=") + 4
# Mask the first 5 characters after "key="
masked_api_base = (
api_base[:key_index] + "*" * 5 + api_base[key_index + 5 :]
)
else:
masked_api_base = api_base
self.model_call_details["litellm_params"]["api_base"] = masked_api_base
masked_headers = {
k: (
(v[:-44] + "*" * 44)
@ -821,7 +847,7 @@ class Logging:
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
langFuseLogger.log_event(
_response = langFuseLogger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
@ -829,6 +855,14 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
if _trace_id is not None:
in_memory_trace_id_cache.set_cache(
litellm_call_id=self.litellm_call_id,
service_name="langfuse",
trace_id=_trace_id,
)
if callback == "datadog":
global dataDogLogger
verbose_logger.debug("reaches datadog for success logging!")
@ -1607,7 +1641,7 @@ class Logging:
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
langFuseLogger.log_event(
_response = langFuseLogger.log_event(
start_time=start_time,
end_time=end_time,
response_obj=None,
@ -1617,6 +1651,14 @@ class Logging:
level="ERROR",
kwargs=self.model_call_details,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
if _trace_id is not None:
in_memory_trace_id_cache.set_cache(
litellm_call_id=self.litellm_call_id,
service_name="langfuse",
trace_id=_trace_id,
)
if callback == "traceloop":
traceloopLogger.log_event(
start_time=start_time,
@ -1721,6 +1763,24 @@ class Logging:
)
)
def _get_trace_id(self, service_name: Literal["langfuse"]) -> Optional[str]:
    """
    Look up the trace_id that was actually logged for this call by the given
    service (e.g. langfuse).

    Used for constructing the url in slack alerting.

    Returns:
    - str: The logged trace id
    - None: If trace id not yet emitted, or the service is not "langfuse".
    """
    # Only langfuse trace ids are looked up; any other service yields None.
    if service_name != "langfuse":
        return None
    return in_memory_trace_id_cache.get_cache(
        litellm_call_id=self.litellm_call_id, service_name=service_name
    )
def set_callbacks(callback_list, function_id=None):
"""