diff --git a/litellm/integrations/gcs_bucket.py b/litellm/integrations/gcs_bucket.py index c948668eb..46f55f8f0 100644 --- a/litellm/integrations/gcs_bucket.py +++ b/litellm/integrations/gcs_bucket.py @@ -14,7 +14,6 @@ from litellm.litellm_core_utils.logging_utils import ( ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload -from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload class RequestKwargs(TypedDict): @@ -30,8 +29,6 @@ class GCSBucketPayload(TypedDict): end_time: str response_cost: Optional[float] spend_log_metadata: str - response_cost: Optional[float] - spend_log_metadata: str class GCSBucketLogger(CustomLogger): @@ -136,10 +133,6 @@ class GCSBucketLogger(CustomLogger): get_logging_payload, ) - from litellm.proxy.spend_tracking.spend_tracking_utils import ( - get_logging_payload, - ) - request_kwargs = RequestKwargs( model=kwargs.get("model", None), messages=kwargs.get("messages", None), @@ -158,14 +151,6 @@ class GCSBucketLogger(CustomLogger): end_user_id=kwargs.get("end_user_id", None), ) - _spend_log_payload: SpendLogsPayload = get_logging_payload( - kwargs=kwargs, - response_obj=response_obj, - start_time=start_time, - end_time=end_time, - end_user_id=kwargs.get("end_user_id", None), - ) - gcs_payload: GCSBucketPayload = GCSBucketPayload( request_kwargs=request_kwargs, response_obj=response_dict, @@ -173,8 +158,6 @@ class GCSBucketLogger(CustomLogger): end_time=end_time, spend_log_metadata=_spend_log_payload["metadata"], response_cost=kwargs.get("response_cost", None), - spend_log_metadata=_spend_log_payload["metadata"], - response_cost=kwargs.get("response_cost", None), ) return gcs_payload diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index df4be3a5b..7a127f912 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -48,7 +48,7 @@ class LangFuseLogger: "secret_key": self.secret_key, "host": self.langfuse_host, "release": self.langfuse_release, - "debug": self.langfuse_debug, + "debug": True, "flush_interval": flush_interval, # flush interval in seconds } diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 13f9475c5..631f47692 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -5,7 +5,12 @@ from fastapi import Request import litellm from litellm._logging import verbose_logger, verbose_proxy_logger -from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth +from litellm.proxy._types import ( + AddTeamCallback, + CommonProxyErrors, + TeamCallbackMetadata, + UserAPIKeyAuth, +) from litellm.types.utils import SupportedCacheControls if TYPE_CHECKING: @@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request): verbose_logger.error("error checking api version in query params: %s", str(e)) +def convert_key_logging_metadata_to_callback( + data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata] +) -> TeamCallbackMetadata: + if team_callback_settings_obj is None: + team_callback_settings_obj = TeamCallbackMetadata() + if data.callback_type == "success": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + + if data.callback_name not in team_callback_settings_obj.success_callback: + team_callback_settings_obj.success_callback.append(data.callback_name) + elif data.callback_type == "failure": + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + + if data.callback_name not in team_callback_settings_obj.failure_callback: + team_callback_settings_obj.failure_callback.append(data.callback_name) + elif data.callback_type == "success_and_failure": + if team_callback_settings_obj.success_callback is None: + team_callback_settings_obj.success_callback = [] + if team_callback_settings_obj.failure_callback is None: + team_callback_settings_obj.failure_callback = [] + if data.callback_name not in team_callback_settings_obj.success_callback: + team_callback_settings_obj.success_callback.append(data.callback_name) + + if data.callback_name in team_callback_settings_obj.failure_callback: + team_callback_settings_obj.failure_callback.append(data.callback_name) + + for var, value in data.callback_vars.items(): + if team_callback_settings_obj.callback_vars is None: + team_callback_settings_obj.callback_vars = {} + team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value) + + return team_callback_settings_obj + + async def add_litellm_data_to_request( data: dict, request: Request, @@ -214,6 +255,7 @@ async def add_litellm_data_to_request( } # add the team-specific configs to the completion call # Team Callbacks controls + callback_settings_obj: Optional[TeamCallbackMetadata] = None if user_api_key_dict.team_metadata is not None: team_metadata = user_api_key_dict.team_metadata if "callback_settings" in team_metadata: @@ -231,13 +273,25 @@ async def add_litellm_data_to_request( } } """ - data["success_callback"] = callback_settings_obj.success_callback - data["failure_callback"] = callback_settings_obj.failure_callback + elif ( + user_api_key_dict.metadata is not None + and "logging" in user_api_key_dict.metadata + ): + for item in user_api_key_dict.metadata["logging"]: - if callback_settings_obj.callback_vars is not None: - # unpack callback_vars in data - for k, v in callback_settings_obj.callback_vars.items(): - data[k] = v + callback_settings_obj = convert_key_logging_metadata_to_callback( + data=AddTeamCallback(**item), + team_callback_settings_obj=callback_settings_obj, + ) + + if callback_settings_obj is not None: + data["success_callback"] = callback_settings_obj.success_callback + data["failure_callback"] = callback_settings_obj.failure_callback + + if callback_settings_obj.callback_vars is not None: + # unpack callback_vars in data + for k, v in callback_settings_obj.callback_vars.items(): + data[k] = v return data diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 757eef6d6..890446e56 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -966,3 +966,92 @@ async def test_user_info_team_list(prisma_client): pass mock_client.assert_called() + + +@pytest.mark.asyncio +async def test_add_callback_via_key(prisma_client): + """ + Test if callback specified in key, is used. + """ + global headers + import json + + from fastapi import HTTPException, Request, Response + from starlette.datastructures import URL + + from litellm.proxy.proxy_server import chat_completion + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + litellm.set_verbose = True + + try: + # Your test data + test_data = { + "model": "azure/chatgpt-v-2", + "messages": [ + {"role": "user", "content": "write 1 sentence poem"}, + ], + "max_tokens": 10, + "mock_response": "Hello world", + "api_key": "my-fake-key", + } + + request = Request(scope={"type": "http", "method": "POST", "headers": {}}) + request._url = URL(url="/chat/completions") + + json_bytes = json.dumps(test_data).encode("utf-8") + + request._body = json_bytes + + with patch.object( + litellm.litellm_core_utils.litellm_logging, + "LangFuseLogger", + new=MagicMock(), + ) as mock_client: + resp = await chat_completion( + request=request, + fastapi_response=Response(), + user_api_key_dict=UserAPIKeyAuth( + metadata={ + "logging": [ + { + "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary' + "callback_type": "success", # set, if required by integration - future improvement, have logging tools work for success + failure by default + "callback_vars": { + "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", + "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", + "langfuse_host": "https://us.cloud.langfuse.com", + }, + } + ] + } + ), + ) + print(resp) + mock_client.assert_called() + mock_client.return_value.log_event.assert_called() + args, kwargs = mock_client.return_value.log_event.call_args + print("KWARGS - {}".format(kwargs)) + kwargs = kwargs["kwargs"] + print(kwargs) + assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"] + assert ( + "logging" + in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"] + ) + checked_keys = False + for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][ + "logging" + ]: + for k, v in item["callback_vars"].items(): + print("k={}, v={}".format(k, v)) + if "key" in k: + assert "os.environ" in v + checked_keys = True + + assert checked_keys + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")