LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) (#5731)

* LiteLLM Minor Fixes & Improvements (09/16/2024)  (#5723)

* coverage (#5713)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Move (#5714)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix(litellm_logging.py): fix logging client re-init (#5710)

Fixes https://github.com/BerriAI/litellm/issues/5695

* fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config

Fixes https://github.com/BerriAI/litellm/issues/5682

* feat(o1_handler.py): fake streaming for openai o1 models (see the sketch below)

Fixes https://github.com/BerriAI/litellm/issues/5694

* docs: deprecated traceloop integration in favor of native otel (#5249)

* fix: fix linting errors

* fix: fix linting errors

* fix(main.py): fix o1 import

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
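
The o1 item above relies on "fake streaming": the o1 models did not accept `stream=True`, so the handler can fetch the complete response and then hand the caller an iterator over it. The snippet below is a minimal sketch of that idea, not the actual `o1_handler.py` code; `fake_stream_o1` and its single-chunk behavior are illustrative assumptions.

```python
# Hedged sketch of fake streaming: make a normal (non-streaming) call, then
# expose the finished response through an async generator so callers that
# asked for stream=True still get something they can iterate over.
# fake_stream_o1 is an illustrative name, not LiteLLM's API.
from typing import AsyncIterator

import litellm


async def fake_stream_o1(model: str, messages: list) -> AsyncIterator[str]:
    # o1 rejects stream=True, so request the full completion up front.
    response = await litellm.acompletion(model=model, messages=messages, stream=False)
    content = response.choices[0].message.content or ""
    # Yield the complete text as a single chunk; a fuller version could slice
    # it into smaller pieces to better mimic token-by-token streaming.
    yield content
```

Callers would consume this with `async for chunk in fake_stream_o1(...)`, exactly as they would a real stream.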

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating the materialized view (#5730)

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating the materialized view

Allows the `MonthlyGlobalSpend` view to be a materialized view, and exposes an endpoint to refresh it on demand (see the endpoint sketch below).

* fix(custom_logger.py): reset calltype

* fix: fix linting errors

* fix: fix linting error

* fix: fix import

* test(test_databricks.py): fix databricks tests

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
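
The `/global/spend/refresh` feature described above assumes `MonthlyGlobalSpend` is a Postgres materialized view, which only reflects new spend rows after an explicit refresh; the endpoint triggers that refresh on demand. Below is a minimal sketch under assumptions (a standalone FastAPI route and the Prisma Python client with raw SQL access); the route path and view name come from the commit, while the handler wiring is illustrative rather than the actual `spend_management_endpoints.py` code.

```python
# Hedged sketch of a /global/spend/refresh route: run REFRESH MATERIALIZED VIEW
# against the proxy database so the MonthlyGlobalSpend view picks up new rows.
# The FastAPI app and Prisma wiring here are assumptions for illustration.
from fastapi import FastAPI, HTTPException
from prisma import Prisma  # prisma-client-py, the client the LiteLLM proxy uses

app = FastAPI()
db = Prisma()


@app.post("/global/spend/refresh")
async def global_spend_refresh():
    if not db.is_connected():
        await db.connect()
    try:
        # Postgres recomputes the view's contents in place. REFRESH ... CONCURRENTLY
        # would avoid blocking readers, but needs a unique index on the view.
        await db.execute_raw('REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";')
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success"}
```

Once deployed, an admin could trigger it with something like `curl -X POST <proxy-url>/global/spend/refresh` (plus the proxy's usual auth header), for example after backfilling spend logs.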
Krish Dholakia 2024-09-17 08:05:52 -07:00 committed by GitHub
parent 1e59395280
commit 234185ec13
34 changed files with 1387 additions and 502 deletions

litellm_logging.py

@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
try:
from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
GenericAPILogger,
)
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
_in_memory_loggers: List[Any] = []
### GLOBAL VARIABLES ###
@ -145,7 +152,41 @@ class ServiceTraceIDCache:
return None
import hashlib
class DynamicLoggingCache:
"""
Prevent memory leaks caused by initializing new logging clients on each request.
Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
"""
def __init__(self) -> None:
self.cache = InMemoryCache()
def get_cache_key(self, args: dict) -> str:
args_str = json.dumps(args, sort_keys=True)
cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
return cache_key
def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
key_name = self.get_cache_key(
args={**credentials, "service_name": service_name}
)
response = self.cache.get_cache(key=key_name)
return response
def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
key_name = self.get_cache_key(
args={**credentials, "service_name": service_name}
)
self.cache.set_cache(key=key_name, value=logging_obj)
return None
in_memory_trace_id_cache = ServiceTraceIDCache()
in_memory_dynamic_logger_cache = DynamicLoggingCache()
class Logging:
@ -324,10 +365,10 @@ class Logging:
print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
# log raw request to provider (like LangFuse) -- if opted in.
if log_raw_request_response is True:
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
if (
turn_off_message_logging is not None
and turn_off_message_logging is True
@ -362,7 +403,7 @@ class Logging:
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
try:
if callback == "supabase":
if callback == "supabase" and supabaseClient is not None:
verbose_logger.debug("reaches supabase for logging!")
model = self.model_call_details["model"]
messages = self.model_call_details["input"]
@ -396,7 +437,9 @@ class Logging:
messages=self.messages,
kwargs=self.model_call_details,
)
elif callable(callback): # custom logger functions
elif (
callable(callback) and customLogger is not None
): # custom logger functions
customLogger.log_input_event(
model=self.model,
messages=self.messages,
@ -615,7 +658,7 @@ class Logging:
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
] = result._hidden_params
] = getattr(result, "_hidden_params", {})
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
@ -645,6 +688,7 @@ class Logging:
litellm.max_budget
and self.stream is False
and result is not None
and isinstance(result, dict)
and "content" in result
):
time_diff = (end_time - start_time).total_seconds()
@ -652,7 +696,7 @@ class Logging:
litellm._current_cost += litellm.completion_cost(
model=self.model,
prompt="",
completion=result["content"],
completion=getattr(result, "content", ""),
total_time=float_diff,
)
@ -758,7 +802,7 @@ class Logging:
):
print_verbose("no-log request, skipping logging")
continue
if callback == "lite_debugger":
if callback == "lite_debugger" and liteDebuggerClient is not None:
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
print_verbose(
@ -774,7 +818,7 @@ class Logging:
call_type=self.call_type,
stream=self.stream,
)
if callback == "promptlayer":
if callback == "promptlayer" and promptLayerLogger is not None:
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
kwargs=self.model_call_details,
@ -783,7 +827,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if callback == "supabase":
if callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
kwargs = self.model_call_details
@ -811,7 +855,7 @@ class Logging:
),
print_verbose=print_verbose,
)
if callback == "wandb":
if callback == "wandb" and weightsBiasesLogger is not None:
print_verbose("reaches wandb for logging!")
weightsBiasesLogger.log_event(
kwargs=self.model_call_details,
@ -820,8 +864,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if callback == "logfire":
global logfireLogger
if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@ -844,10 +887,10 @@ class Logging:
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
level=LogfireLevel.INFO.value,
level=LogfireLevel.INFO.value, # type: ignore
)
if callback == "lunary":
if callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging!")
model = self.model
kwargs = self.model_call_details
@ -882,7 +925,7 @@ class Logging:
run_id=self.litellm_call_id,
print_verbose=print_verbose,
)
if callback == "helicone":
if callback == "helicone" and heliconeLogger is not None:
print_verbose("reaches helicone for logging!")
model = self.model
messages = self.model_call_details["input"]
@ -924,6 +967,7 @@ class Logging:
else:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
temp_langfuse_logger = langFuseLogger
if langFuseLogger is None or (
(
@ -941,27 +985,45 @@ class Logging:
and self.langfuse_host != langFuseLogger.langfuse_host
)
):
temp_langfuse_logger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
_response = temp_langfuse_logger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
if _trace_id is not None:
in_memory_trace_id_cache.set_cache(
litellm_call_id=self.litellm_call_id,
service_name="langfuse",
trace_id=_trace_id,
credentials = {
"langfuse_public_key": self.langfuse_public_key,
"langfuse_secret": self.langfuse_secret,
"langfuse_host": self.langfuse_host,
}
temp_langfuse_logger = (
in_memory_dynamic_logger_cache.get_cache(
credentials=credentials, service_name="langfuse"
)
)
if temp_langfuse_logger is None:
temp_langfuse_logger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
in_memory_dynamic_logger_cache.set_cache(
credentials=credentials,
service_name="langfuse",
logging_obj=temp_langfuse_logger,
)
if temp_langfuse_logger is not None:
_response = temp_langfuse_logger.log_event(
kwargs=kwargs,
response_obj=result,
start_time=start_time,
end_time=end_time,
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if _response is not None and isinstance(_response, dict):
_trace_id = _response.get("trace_id", None)
if _trace_id is not None:
in_memory_trace_id_cache.set_cache(
litellm_call_id=self.litellm_call_id,
service_name="langfuse",
trace_id=_trace_id,
)
if callback == "generic":
global genericAPILogger
verbose_logger.debug("reaches langfuse for success logging!")
@ -982,7 +1044,7 @@ class Logging:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
if genericAPILogger is None:
genericAPILogger = GenericAPILogger()
genericAPILogger = GenericAPILogger() # type: ignore
genericAPILogger.log_event(
kwargs=kwargs,
response_obj=result,
@ -1022,7 +1084,7 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "greenscale":
if callback == "greenscale" and greenscaleLogger is not None:
kwargs = {}
for k, v in self.model_call_details.items():
if (
@ -1066,7 +1128,7 @@ class Logging:
result = kwargs["complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if callback == "athina":
if callback == "athina" and athinaLogger is not None:
deep_copy = {}
for k, v in self.model_call_details.items():
deep_copy[k] = v
@ -1224,6 +1286,7 @@ class Logging:
"atranscription", False
)
is not True
and customLogger is not None
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
@ -1423,9 +1486,8 @@ class Logging:
await litellm.cache.async_add_cache(result, **kwargs)
else:
litellm.cache.add_cache(result, **kwargs)
if callback == "openmeter":
global openMeterLogger
if self.stream == True:
if callback == "openmeter" and openMeterLogger is not None:
if self.stream is True:
if (
"async_complete_streaming_response"
in self.model_call_details
@ -1645,33 +1707,9 @@ class Logging:
)
for callback in callbacks:
try:
if callback == "lite_debugger":
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
result = {
"model": self.model,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens": prompt_token_calculator(
self.model, messages=self.messages
),
"completion_tokens": 0,
},
}
liteDebuggerClient.log_event(
model=self.model,
messages=self.messages,
end_user=self.model_call_details.get("user", "default"),
response_obj=result,
start_time=start_time,
end_time=end_time,
litellm_call_id=self.litellm_call_id,
print_verbose=print_verbose,
call_type=self.call_type,
stream=self.stream,
)
if callback == "lunary":
if callback == "lite_debugger" and liteDebuggerClient is not None:
pass
elif callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging error!")
model = self.model
@ -1685,6 +1723,7 @@ class Logging:
)
lunaryLogger.log_event(
kwargs=self.model_call_details,
type=_type,
event="error",
user_id=self.model_call_details.get("user", "default"),
@ -1704,22 +1743,11 @@ class Logging:
print_verbose(
f"capture exception not initialized: {capture_exception}"
)
elif callback == "supabase":
elif callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
print_verbose(f"supabaseClient: {supabaseClient}")
result = {
"model": model,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens": prompt_token_calculator(
model, messages=self.messages
),
"completion_tokens": 0,
},
}
supabaseClient.log_event(
model=self.model,
model=self.model if hasattr(self, "model") else "",
messages=self.messages,
end_user=self.model_call_details.get("user", "default"),
response_obj=result,
@ -1728,7 +1756,9 @@ class Logging:
litellm_call_id=self.model_call_details["litellm_call_id"],
print_verbose=print_verbose,
)
if callable(callback): # custom logger functions
if (
callable(callback) and customLogger is not None
): # custom logger functions
customLogger.log_event(
kwargs=self.model_call_details,
response_obj=result,
@ -1809,13 +1839,13 @@ class Logging:
start_time=start_time,
end_time=end_time,
response_obj=None,
user_id=kwargs.get("user", None),
user_id=self.model_call_details.get("user", None),
print_verbose=print_verbose,
status_message=str(exception),
level="ERROR",
kwargs=self.model_call_details,
)
if callback == "logfire":
if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for failure logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@ -1830,7 +1860,7 @@ class Logging:
response_obj=result,
start_time=start_time,
end_time=end_time,
level=LogfireLevel.ERROR.value,
level=LogfireLevel.ERROR.value, # type: ignore
print_verbose=print_verbose,
)
@ -1873,7 +1903,9 @@ class Logging:
start_time=start_time,
end_time=end_time,
) # type: ignore
if callable(callback): # custom logger functions
if (
callable(callback) and customLogger is not None
): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
)
sentry_sdk_instance.init(
dsn=os.environ.get("SENTRY_DSN"),
traces_sample_rate=float(sentry_trace_rate),
traces_sample_rate=float(sentry_trace_rate), # type: ignore
)
capture_exception = sentry_sdk_instance.capture_exception
add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(
saved_cache_cost: Optional[float] = None
if cache_hit is True:
import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False
result=init_response_obj, cache_hit=False # type: ignore
)
## Get model cost information ##
@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
model_id=_model_id,
requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"),
response=(
response=( # type: ignore
response_obj if len(response_obj.keys()) > 0 else init_response_obj
),
model_parameters=kwargs.get("optional_params", None),