forked from phoenix/litellm-mirror
* LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723)

  * coverage (#5713)
    Signed-off-by: dbczumar <corey.zumar@databricks.com>
  * Move (#5714)
    Signed-off-by: dbczumar <corey.zumar@databricks.com>
  * fix(litellm_logging.py): fix logging client re-init (#5710)
    Fixes https://github.com/BerriAI/litellm/issues/5695
  * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config
    Fixes https://github.com/BerriAI/litellm/issues/5682
  * feat(o1_handler.py): fake streaming for openai o1 models
    Fixes https://github.com/BerriAI/litellm/issues/5694
  * docs: deprecated traceloop integration in favor of native otel (#5249)
  * fix: fix linting errors
  * fix: fix linting errors
  * fix(main.py): fix o1 import

  ---------
  Signed-off-by: dbczumar <corey.zumar@databricks.com>
  Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
  Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730)

  * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view
    Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it
  * fix(custom_logger.py): reset calltype
  * fix: fix linting errors
  * fix: fix linting error
  * fix: fix import
  * test(test_databricks.py): fix databricks tests

  ---------
  Signed-off-by: dbczumar <corey.zumar@databricks.com>
  Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
  Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
parent 1e59395280
commit 234185ec13
34 changed files with 1387 additions and 502 deletions
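Only the logging-module hunks of this commit are reproduced below; the o1 change lives in o1_handler.py and is not shown. As a rough, standalone sketch of what "fake streaming" typically means here -- my assumption, not code from the commit, being that an already-complete response is sliced into delta-style chunks so a caller that asked for a stream still gets one:

from typing import Dict, Iterator


def fake_stream(full_response: Dict, chunk_size: int = 20) -> Iterator[Dict]:
    # Slice the already-complete text into delta-shaped chunks (illustrative only).
    text = full_response["choices"][0]["message"]["content"]
    for i in range(0, len(text), chunk_size):
        yield {"choices": [{"delta": {"content": text[i : i + chunk_size]}}]}
    # Final chunk signals completion, mirroring how streaming APIs end a response.
    yield {"choices": [{"delta": {}, "finish_reason": "stop"}]}


# Replay a non-streaming response as a stream of chunks.
response = {"choices": [{"message": {"content": "o1 answered in one shot; we re-chunk it here."}}]}
for chunk in fake_stream(response):
    print(chunk)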
@@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
 from ..integrations.traceloop import TraceloopLogger
 from ..integrations.weights_biases import WeightsBiasesLogger

+try:
+    from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
+        GenericAPILogger,
+    )
+except Exception as e:
+    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
+
 _in_memory_loggers: List[Any] = []

 ### GLOBAL VARIABLES ###
@@ -145,7 +152,41 @@ class ServiceTraceIDCache:
         return None


+import hashlib
+
+
+class DynamicLoggingCache:
+    """
+    Prevent memory leaks caused by initializing new logging clients on each request.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
+    """
+
+    def __init__(self) -> None:
+        self.cache = InMemoryCache()
+
+    def get_cache_key(self, args: dict) -> str:
+        args_str = json.dumps(args, sort_keys=True)
+        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
+        return cache_key
+
+    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        response = self.cache.get_cache(key=key_name)
+        return response
+
+    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        self.cache.set_cache(key=key_name, value=logging_obj)
+        return None
+
+
 in_memory_trace_id_cache = ServiceTraceIDCache()
+in_memory_dynamic_logger_cache = DynamicLoggingCache()


 class Logging:
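The DynamicLoggingCache introduced above keys each logger object on a SHA-256 of the JSON-serialized credentials plus the service name, so repeated requests carrying the same Langfuse keys reuse one client instead of constructing a new one per request (the leak described in issue 5695). A minimal standalone sketch of that get-or-create pattern -- a plain dict stands in for InMemoryCache and object() stands in for LangFuseLogger; both substitutions are illustrative, not the proxy's actual wiring:

import hashlib
import json

_logger_cache: dict = {}  # stand-in for InMemoryCache


def _cache_key(credentials: dict, service_name: str) -> str:
    # Sort keys so identical credentials always hash to the same cache key.
    args = {**credentials, "service_name": service_name}
    return hashlib.sha256(json.dumps(args, sort_keys=True).encode("utf-8")).hexdigest()


def get_or_create_logger(credentials: dict, service_name: str):
    key = _cache_key(credentials, service_name)
    if key not in _logger_cache:
        _logger_cache[key] = object()  # stand-in for LangFuseLogger(**credentials)
    return _logger_cache[key]


creds = {
    "langfuse_public_key": "pk-example",
    "langfuse_secret": "sk-example",
    "langfuse_host": "https://cloud.langfuse.com",
}
first = get_or_create_logger(creds, "langfuse")
second = get_or_create_logger(creds, "langfuse")
assert first is second  # same credentials -> same client, no re-initialization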
@@ -324,10 +365,10 @@ class Logging:
         print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
         # log raw request to provider (like LangFuse) -- if opted in.
         if log_raw_request_response is True:
+            _litellm_params = self.model_call_details.get("litellm_params", {})
+            _metadata = _litellm_params.get("metadata", {}) or {}
             try:
                 # [Non-blocking Extra Debug Information in metadata]
-                _litellm_params = self.model_call_details.get("litellm_params", {})
-                _metadata = _litellm_params.get("metadata", {}) or {}
                 if (
                     turn_off_message_logging is not None
                     and turn_off_message_logging is True
@@ -362,7 +403,7 @@ class Logging:
         callbacks = litellm.input_callback + self.dynamic_input_callbacks
         for callback in callbacks:
             try:
-                if callback == "supabase":
+                if callback == "supabase" and supabaseClient is not None:
                     verbose_logger.debug("reaches supabase for logging!")
                     model = self.model_call_details["model"]
                     messages = self.model_call_details["input"]
@@ -396,7 +437,9 @@ class Logging:
                         messages=self.messages,
                         kwargs=self.model_call_details,
                     )
-                elif callable(callback):  # custom logger functions
+                elif (
+                    callable(callback) and customLogger is not None
+                ):  # custom logger functions
                     customLogger.log_input_event(
                         model=self.model,
                         messages=self.messages,
@@ -615,7 +658,7 @@ class Logging:

                 self.model_call_details["litellm_params"]["metadata"][
                     "hidden_params"
-                ] = result._hidden_params
+                ] = getattr(result, "_hidden_params", {})
                 ## STANDARDIZED LOGGING PAYLOAD

                 self.model_call_details["standard_logging_object"] = (
@@ -645,6 +688,7 @@ class Logging:
                 litellm.max_budget
                 and self.stream is False
                 and result is not None
+                and isinstance(result, dict)
                 and "content" in result
             ):
                 time_diff = (end_time - start_time).total_seconds()
@@ -652,7 +696,7 @@ class Logging:
                 litellm._current_cost += litellm.completion_cost(
                     model=self.model,
                     prompt="",
-                    completion=result["content"],
+                    completion=getattr(result, "content", ""),
                     total_time=float_diff,
                 )

@@ -758,7 +802,7 @@ class Logging:
                     ):
                         print_verbose("no-log request, skipping logging")
                         continue
-                    if callback == "lite_debugger":
+                    if callback == "lite_debugger" and liteDebuggerClient is not None:
                         print_verbose("reaches lite_debugger for logging!")
                         print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                         print_verbose(
@@ -774,7 +818,7 @@ class Logging:
                             call_type=self.call_type,
                             stream=self.stream,
                         )
-                    if callback == "promptlayer":
+                    if callback == "promptlayer" and promptLayerLogger is not None:
                         print_verbose("reaches promptlayer for logging!")
                         promptLayerLogger.log_event(
                             kwargs=self.model_call_details,
@@ -783,7 +827,7 @@ class Logging:
                             end_time=end_time,
                             print_verbose=print_verbose,
                         )
-                    if callback == "supabase":
+                    if callback == "supabase" and supabaseClient is not None:
                         print_verbose("reaches supabase for logging!")
                         kwargs = self.model_call_details

@@ -811,7 +855,7 @@ class Logging:
                             ),
                             print_verbose=print_verbose,
                         )
-                    if callback == "wandb":
+                    if callback == "wandb" and weightsBiasesLogger is not None:
                         print_verbose("reaches wandb for logging!")
                         weightsBiasesLogger.log_event(
                             kwargs=self.model_call_details,
@@ -820,8 +864,7 @@ class Logging:
                             end_time=end_time,
                             print_verbose=print_verbose,
                         )
-                    if callback == "logfire":
-                        global logfireLogger
+                    if callback == "logfire" and logfireLogger is not None:
                         verbose_logger.debug("reaches logfire for success logging!")
                         kwargs = {}
                         for k, v in self.model_call_details.items():
@@ -844,10 +887,10 @@ class Logging:
                             start_time=start_time,
                             end_time=end_time,
                             print_verbose=print_verbose,
-                            level=LogfireLevel.INFO.value,
+                            level=LogfireLevel.INFO.value,  # type: ignore
                         )

-                    if callback == "lunary":
+                    if callback == "lunary" and lunaryLogger is not None:
                         print_verbose("reaches lunary for logging!")
                         model = self.model
                         kwargs = self.model_call_details
@@ -882,7 +925,7 @@ class Logging:
                             run_id=self.litellm_call_id,
                             print_verbose=print_verbose,
                         )
-                    if callback == "helicone":
+                    if callback == "helicone" and heliconeLogger is not None:
                         print_verbose("reaches helicone for logging!")
                         model = self.model
                         messages = self.model_call_details["input"]
@@ -924,6 +967,7 @@ class Logging:
                            else:
                                print_verbose("reaches langfuse for streaming logging!")
                                result = kwargs["complete_streaming_response"]
+
                        temp_langfuse_logger = langFuseLogger
                        if langFuseLogger is None or (
                            (
@@ -941,27 +985,45 @@ class Logging:
                                and self.langfuse_host != langFuseLogger.langfuse_host
                            )
                        ):
-                            temp_langfuse_logger = LangFuseLogger(
-                                langfuse_public_key=self.langfuse_public_key,
-                                langfuse_secret=self.langfuse_secret,
-                                langfuse_host=self.langfuse_host,
-                            )
-                        _response = temp_langfuse_logger.log_event(
-                            kwargs=kwargs,
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            user_id=kwargs.get("user", None),
-                            print_verbose=print_verbose,
-                        )
-                        if _response is not None and isinstance(_response, dict):
-                            _trace_id = _response.get("trace_id", None)
-                            if _trace_id is not None:
-                                in_memory_trace_id_cache.set_cache(
-                                    litellm_call_id=self.litellm_call_id,
-                                    service_name="langfuse",
-                                    trace_id=_trace_id,
+                            credentials = {
+                                "langfuse_public_key": self.langfuse_public_key,
+                                "langfuse_secret": self.langfuse_secret,
+                                "langfuse_host": self.langfuse_host,
+                            }
+                            temp_langfuse_logger = (
+                                in_memory_dynamic_logger_cache.get_cache(
+                                    credentials=credentials, service_name="langfuse"
+                                )
+                            )
+                            if temp_langfuse_logger is None:
+                                temp_langfuse_logger = LangFuseLogger(
+                                    langfuse_public_key=self.langfuse_public_key,
+                                    langfuse_secret=self.langfuse_secret,
+                                    langfuse_host=self.langfuse_host,
+                                )
+                                in_memory_dynamic_logger_cache.set_cache(
+                                    credentials=credentials,
+                                    service_name="langfuse",
+                                    logging_obj=temp_langfuse_logger,
+                                )
+
+                        if temp_langfuse_logger is not None:
+                            _response = temp_langfuse_logger.log_event(
+                                kwargs=kwargs,
+                                response_obj=result,
+                                start_time=start_time,
+                                end_time=end_time,
+                                user_id=kwargs.get("user", None),
+                                print_verbose=print_verbose,
+                            )
+                            if _response is not None and isinstance(_response, dict):
+                                _trace_id = _response.get("trace_id", None)
+                                if _trace_id is not None:
+                                    in_memory_trace_id_cache.set_cache(
+                                        litellm_call_id=self.litellm_call_id,
+                                        service_name="langfuse",
+                                        trace_id=_trace_id,
                                    )
                    if callback == "generic":
                        global genericAPILogger
                        verbose_logger.debug("reaches langfuse for success logging!")
@@ -982,7 +1044,7 @@ class Logging:
                                print_verbose("reaches langfuse for streaming logging!")
                                result = kwargs["complete_streaming_response"]
                        if genericAPILogger is None:
-                            genericAPILogger = GenericAPILogger()
+                            genericAPILogger = GenericAPILogger()  # type: ignore
                        genericAPILogger.log_event(
                            kwargs=kwargs,
                            response_obj=result,
@@ -1022,7 +1084,7 @@ class Logging:
                            user_id=kwargs.get("user", None),
                            print_verbose=print_verbose,
                        )
-                    if callback == "greenscale":
+                    if callback == "greenscale" and greenscaleLogger is not None:
                        kwargs = {}
                        for k, v in self.model_call_details.items():
                            if (
@@ -1066,7 +1128,7 @@ class Logging:
                            result = kwargs["complete_streaming_response"]
                            # only add to cache once we have a complete streaming response
                            litellm.cache.add_cache(result, **kwargs)
-                    if callback == "athina":
+                    if callback == "athina" and athinaLogger is not None:
                        deep_copy = {}
                        for k, v in self.model_call_details.items():
                            deep_copy[k] = v
@@ -1224,6 +1286,7 @@ class Logging:
                                "atranscription", False
                            )
                            is not True
+                            and customLogger is not None
                        ):  # custom logger functions
                            print_verbose(
                                f"success callbacks: Running Custom Callback Function"
@@ -1423,9 +1486,8 @@ class Logging:
                            await litellm.cache.async_add_cache(result, **kwargs)
                        else:
                            litellm.cache.add_cache(result, **kwargs)
-                    if callback == "openmeter":
-                        global openMeterLogger
-                        if self.stream == True:
+                    if callback == "openmeter" and openMeterLogger is not None:
+                        if self.stream is True:
                            if (
                                "async_complete_streaming_response"
                                in self.model_call_details
@@ -1645,33 +1707,9 @@ class Logging:
            )
            for callback in callbacks:
                try:
-                    if callback == "lite_debugger":
-                        print_verbose("reaches lite_debugger for logging!")
-                        print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
-                        result = {
-                            "model": self.model,
-                            "created": time.time(),
-                            "error": traceback_exception,
-                            "usage": {
-                                "prompt_tokens": prompt_token_calculator(
-                                    self.model, messages=self.messages
-                                ),
-                                "completion_tokens": 0,
-                            },
-                        }
-                        liteDebuggerClient.log_event(
-                            model=self.model,
-                            messages=self.messages,
-                            end_user=self.model_call_details.get("user", "default"),
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            litellm_call_id=self.litellm_call_id,
-                            print_verbose=print_verbose,
-                            call_type=self.call_type,
-                            stream=self.stream,
-                        )
-                    if callback == "lunary":
+                    if callback == "lite_debugger" and liteDebuggerClient is not None:
+                        pass
+                    elif callback == "lunary" and lunaryLogger is not None:
                        print_verbose("reaches lunary for logging error!")

                        model = self.model
@@ -1685,6 +1723,7 @@ class Logging:
                        )

                        lunaryLogger.log_event(
+                            kwargs=self.model_call_details,
                            type=_type,
                            event="error",
                            user_id=self.model_call_details.get("user", "default"),
@@ -1704,22 +1743,11 @@ class Logging:
                            print_verbose(
                                f"capture exception not initialized: {capture_exception}"
                            )
-                    elif callback == "supabase":
+                    elif callback == "supabase" and supabaseClient is not None:
                        print_verbose("reaches supabase for logging!")
                        print_verbose(f"supabaseClient: {supabaseClient}")
-                        result = {
-                            "model": model,
-                            "created": time.time(),
-                            "error": traceback_exception,
-                            "usage": {
-                                "prompt_tokens": prompt_token_calculator(
-                                    model, messages=self.messages
-                                ),
-                                "completion_tokens": 0,
-                            },
-                        }
                        supabaseClient.log_event(
-                            model=self.model,
+                            model=self.model if hasattr(self, "model") else "",
                            messages=self.messages,
                            end_user=self.model_call_details.get("user", "default"),
                            response_obj=result,
@@ -1728,7 +1756,9 @@ class Logging:
                            litellm_call_id=self.model_call_details["litellm_call_id"],
                            print_verbose=print_verbose,
                        )
-                    if callable(callback):  # custom logger functions
+                    if (
+                        callable(callback) and customLogger is not None
+                    ):  # custom logger functions
                        customLogger.log_event(
                            kwargs=self.model_call_details,
                            response_obj=result,
@@ -1809,13 +1839,13 @@ class Logging:
                            start_time=start_time,
                            end_time=end_time,
                            response_obj=None,
-                            user_id=kwargs.get("user", None),
+                            user_id=self.model_call_details.get("user", None),
                            print_verbose=print_verbose,
                            status_message=str(exception),
                            level="ERROR",
                            kwargs=self.model_call_details,
                        )
-                    if callback == "logfire":
+                    if callback == "logfire" and logfireLogger is not None:
                        verbose_logger.debug("reaches logfire for failure logging!")
                        kwargs = {}
                        for k, v in self.model_call_details.items():
@@ -1830,7 +1860,7 @@ class Logging:
                            response_obj=result,
                            start_time=start_time,
                            end_time=end_time,
-                            level=LogfireLevel.ERROR.value,
+                            level=LogfireLevel.ERROR.value,  # type: ignore
                            print_verbose=print_verbose,
                        )

@@ -1873,7 +1903,9 @@ class Logging:
                            start_time=start_time,
                            end_time=end_time,
                        )  # type: ignore
-                    if callable(callback):  # custom logger functions
+                    if (
+                        callable(callback) and customLogger is not None
+                    ):  # custom logger functions
                        await customLogger.async_log_event(
                            kwargs=self.model_call_details,
                            response_obj=result,
@@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
                )
                sentry_sdk_instance.init(
                    dsn=os.environ.get("SENTRY_DSN"),
-                    traces_sample_rate=float(sentry_trace_rate),
+                    traces_sample_rate=float(sentry_trace_rate),  # type: ignore
                )
                capture_exception = sentry_sdk_instance.capture_exception
                add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(

    saved_cache_cost: Optional[float] = None
    if cache_hit is True:
-        import time

        id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id

        saved_cache_cost = logging_obj._response_cost_calculator(
-            result=init_response_obj, cache_hit=False
+            result=init_response_obj, cache_hit=False  # type: ignore
        )

    ## Get model cost information ##
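For cache hits, the block above records saved_cache_cost -- the cost the same response would have carried had it not been served from cache -- and suffixes the request id so the hit does not collide with the original request's entry. A toy illustration of what that saved figure represents; the per-token prices are invented, and the real number comes from logging_obj._response_cost_calculator rather than this arithmetic:

# Hypothetical token counts and prices, purely to illustrate the semantics.
prompt_tokens, completion_tokens = 1_200, 300
input_price_per_token, output_price_per_token = 0.000003, 0.000015

# Cost the call would have incurred without the cache hit -> recorded as the savings.
saved_cache_cost = (
    prompt_tokens * input_price_per_token
    + completion_tokens * output_price_per_token
)
print(f"saved_cache_cost = ${saved_cache_cost:.6f}")  # $0.008100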
@@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
        model_id=_model_id,
        requester_ip_address=clean_metadata.get("requester_ip_address", None),
        messages=kwargs.get("messages"),
-        response=(
+        response=(  # type: ignore
            response_obj if len(response_obj.keys()) > 0 else init_response_obj
        ),
        model_parameters=kwargs.get("optional_params", None),