fix(proxy/utils.py): support logging rejected requests to langfuse, etc.

Krrish Dholakia 2024-07-05 14:39:35 -07:00
parent d528b66db0
commit b1b21b0340
4 changed files with 36 additions and 52 deletions
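In short: requests that a proxy guardrail rejects before they reach a model (e.g. a 429 from the parallel-request limiter) previously died without hitting any logging callback; after this change they are routed through litellm's failure handlers. A minimal way to observe the new behavior, assuming a Langfuse account is configured via the usual environment variables:

    # hedged sketch: wire Langfuse as a failure callback, then send a request
    # that a pre-call hook rejects; it should now appear in Langfuse as a failure
    import litellm

    litellm.failure_callback = ["langfuse"]  # reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY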

litellm/litellm_core_utils/litellm_logging.py

@@ -153,11 +153,6 @@ class Logging:
         langfuse_secret=None,
         langfuse_host=None,
     ):
-        if call_type not in [item.value for item in CallTypes]:
-            allowed_values = ", ".join([item.value for item in CallTypes])
-            raise ValueError(
-                f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
-            )
        if messages is not None:
            if isinstance(messages, str):
                messages = [

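The validator had to go because the proxy's new rejected-request path (see litellm/proxy/utils.py below) constructs a Logging object through function_setup(original_function="IGNORE_THIS", ...), and that placeholder is not a CallTypes member. A sketch of why the old guard would have blocked it, assuming function_setup forwards the placeholder function name as the call_type:

    # "IGNORE_THIS" is not in CallTypes, so the deleted guard would have raised
    # ValueError before the failure could ever be logged
    from litellm.types.utils import CallTypes

    call_type = "IGNORE_THIS"  # placeholder used for rejected proxy requests
    assert call_type not in [item.value for item in CallTypes]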
litellm/proxy/utils.py

@@ -49,6 +49,7 @@ from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
+from litellm.types.utils import CallTypes

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -354,35 +355,6 @@ class ProxyLogging:
             print_verbose(f"final data being sent to {call_type} call: {data}")
             return data
         except Exception as e:
-            if "litellm_logging_obj" in data:
-                logging_obj: litellm.litellm_core_utils.litellm_logging.Logging = data[
-                    "litellm_logging_obj"
-                ]
-
-                ## ASYNC FAILURE HANDLER ##
-                error_message = ""
-                if isinstance(e, HTTPException):
-                    if isinstance(e.detail, str):
-                        error_message = e.detail
-                    elif isinstance(e.detail, dict):
-                        error_message = json.dumps(e.detail)
-                    else:
-                        error_message = str(e)
-                else:
-                    error_message = str(e)
-                error_raised = Exception(f"{error_message}")
-                await logging_obj.async_failure_handler(
-                    exception=error_raised,
-                    traceback_exception=traceback.format_exc(),
-                )
-
-                ## SYNC FAILURE HANDLER ##
-                try:
-                    logging_obj.failure_handler(
-                        error_raised, traceback.format_exc()
-                    )  # DO NOT MAKE THREADED - router retry fallback relies on this!
-                except Exception as error_val:
-                    pass
             raise e

     async def during_call_hook(
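The 29 lines deleted here were the old inline failure-logging in the pre-call path; the commit consolidates that work in post_call_failure_hook (next hunk). The piece worth keeping in mind is how an HTTPException's detail payload was normalized into a message string, sketched below from the removed code:

    # normalize an exception into a loggable message; detail on a FastAPI
    # HTTPException may be either a plain string or a dict
    import json

    from fastapi import HTTPException

    def extract_error_message(e: Exception) -> str:
        if isinstance(e, HTTPException):
            if isinstance(e.detail, str):
                return e.detail
            if isinstance(e.detail, dict):
                return json.dumps(e.detail)
        return str(e)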
@@ -597,12 +569,14 @@
         )

         ### LOGGING ###
+        litellm_logging_obj: Optional[Logging] = request_data.get(
+            "litellm_logging_obj", None
+        )
         if isinstance(original_exception, HTTPException):
-            litellm_logging_obj: Optional[Logging] = request_data.get(
-                "litellm_logging_obj", None
-            )
-
             if litellm_logging_obj is None:
+                import uuid
+
+                request_data["litellm_call_id"] = str(uuid.uuid4())
                 litellm_logging_obj, data = litellm.utils.function_setup(
                     original_function="IGNORE_THIS",
                     rules_obj=litellm.utils.Rules(),

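Net effect of this hunk: the logging object is now looked up for every exception, and when a request was rejected before litellm ever ran (so no logging object exists yet), one is fabricated with a fresh call id so failure callbacks still fire. A compressed sketch of that flow, assuming the remaining function_setup kwargs come from the request body (the hunk is truncated above); the helper name is hypothetical:

    # build a logging object on the fly for a rejected request, then hand the
    # exception to the async failure handler (as the deleted pre-call code did)
    import traceback
    import uuid

    import litellm

    async def log_rejected_request(request_data: dict, original_exception: Exception):
        request_data["litellm_call_id"] = str(uuid.uuid4())
        logging_obj, data = litellm.utils.function_setup(
            original_function="IGNORE_THIS",  # placeholder call type
            rules_obj=litellm.utils.Rules(),
            **request_data,  # assumption: body kwargs complete the signature
        )
        await logging_obj.async_failure_handler(
            exception=original_exception,
            traceback_exception=traceback.format_exc(),
        )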
litellm/tests/test_proxy_reject_logging.py

@@ -23,6 +23,8 @@ import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from typing import Literal
+
 import pytest
 from fastapi import Request, Response
 from starlette.datastructures import URL
@@ -51,7 +53,20 @@ class testLogger(CustomLogger):
     def __init__(self):
         self.reaches_failure_event = False

-    async def async_pre_call_check(self, deployment: dict):
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ):
         raise HTTPException(
             status_code=429, detail={"error": "Max parallel request limit reached"}
         )
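The test double previously implemented async_pre_call_check, which the proxy's request path never invokes; switching to the real async_pre_call_hook signature from CustomLogger means the 429 actually fires and every request is rejected before any model call. Registering it is enough to exercise the path under test, e.g.:

    # hypothetical wiring: every request through the proxy now gets rejected,
    # which should trip the new failure-logging machinery
    import litellm

    test_logger = testLogger()
    litellm.callbacks = [test_logger]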
@@ -92,15 +107,15 @@ router = Router(
                 ],
             },
         ),
-        # ("/v1/completions", {"model": "fake-model", "prompt": "ping"}),
-        # (
-        #     "/v1/embeddings",
-        #     {
-        #         "input": "The food was delicious and the waiter...",
-        #         "model": "text-embedding-ada-002",
-        #         "encoding_format": "float",
-        #     },
-        # ),
+        ("/v1/completions", {"model": "fake-model", "prompt": "ping"}),
+        (
+            "/v1/embeddings",
+            {
+                "input": "The food was delicious and the waiter...",
+                "model": "text-embedding-ada-002",
+                "encoding_format": "float",
+            },
+        ),
     ],
 )
 @pytest.mark.asyncio

litellm/utils.py

@@ -531,6 +531,8 @@ def function_setup(
         call_type == CallTypes.aspeech.value or call_type == CallTypes.speech.value
     ):
         messages = kwargs.get("input", "speech")
+    else:
+        messages = "default-message-value"
     stream = True if "stream" in kwargs and kwargs["stream"] == True else False
     logging_obj = litellm.litellm_core_utils.litellm_logging.Logging(
         model=model,
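function_setup assigns messages per call type; before this change, any call type outside the known branches (notably the proxy's "IGNORE_THIS" placeholder) left messages unbound, so the subsequent Logging(...) construction crashed. A hypothetical repro of the old bug:

    # with an unknown call type, no branch assigned `messages`, so reading it
    # raised UnboundLocalError; the new else-branch pins a default instead
    def old_style_setup(call_type: str, **kwargs):
        if call_type == "completion":
            messages = kwargs.get("messages")
        elif call_type == "speech":
            messages = kwargs.get("input", "speech")
        # no else here
        return messages  # UnboundLocalError for e.g. call_type="IGNORE_THIS"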
@@ -561,10 +563,8 @@
         )
         return logging_obj, kwargs
     except Exception as e:
-        import logging
-
-        logging.debug(
-            f"[Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
-        )
+        verbose_logger.error(
+            f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
+        )
         raise e
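Swapping the stdlib's logging.debug for verbose_logger.error means setup failures now surface at a level operators actually see, instead of only under debug logging, and the local import of logging becomes dead weight. Assuming verbose_logger is litellm's module-level logger from litellm._logging, no extra configuration is needed; its level can be confirmed like so:

    # sanity check: the new error path is visible at default settings
    import logging

    from litellm._logging import verbose_logger

    assert verbose_logger.isEnabledFor(logging.ERROR)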