fix(litellm_pre_call_utils.py): support routing to logging project by api key

Krrish Dholakia 2024-08-12 21:21:40 -07:00
parent f322ffc413
commit 93a1335e46
4 changed files with 151 additions and 25 deletions
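
In effect, this change lets an individual virtual key carry its own logging configuration: a "logging" list in the key's metadata is converted, per request, into a TeamCallbackMetadata object whose success/failure callbacks are injected into the request body, and whose callback_vars are resolved through litellm.get_secret. A minimal sketch of the key metadata shape involved, mirroring the test added at the bottom of this commit (the Langfuse values are illustrative):

# Per-key logging metadata, as exercised by the new test in this commit.
# "os.environ/..." values are resolved via litellm.get_secret at request time.
key_metadata = {
    "logging": [
        {
            "callback_name": "langfuse",  # which logger to attach for this key
            "callback_type": "success",  # or "failure" / "success_and_failure"
            "callback_vars": {
                "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                "langfuse_host": "https://us.cloud.langfuse.com",
            },
        }
    ]
}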

View file

@@ -14,7 +14,6 @@ from litellm.litellm_core_utils.logging_utils import (
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload
-from litellm.proxy._types import CommonProxyErrors, SpendLogsMetadata, SpendLogsPayload


 class RequestKwargs(TypedDict):
@@ -30,8 +29,6 @@ class GCSBucketPayload(TypedDict):
     end_time: str
     response_cost: Optional[float]
     spend_log_metadata: str
-    response_cost: Optional[float]
-    spend_log_metadata: str


 class GCSBucketLogger(CustomLogger):
@@ -136,10 +133,6 @@ class GCSBucketLogger(CustomLogger):
             get_logging_payload,
         )
-        from litellm.proxy.spend_tracking.spend_tracking_utils import (
-            get_logging_payload,
-        )
-
         request_kwargs = RequestKwargs(
             model=kwargs.get("model", None),
             messages=kwargs.get("messages", None),
@@ -158,14 +151,6 @@ class GCSBucketLogger(CustomLogger):
             end_user_id=kwargs.get("end_user_id", None),
         )
-        _spend_log_payload: SpendLogsPayload = get_logging_payload(
-            kwargs=kwargs,
-            response_obj=response_obj,
-            start_time=start_time,
-            end_time=end_time,
-            end_user_id=kwargs.get("end_user_id", None),
-        )
-
         gcs_payload: GCSBucketPayload = GCSBucketPayload(
             request_kwargs=request_kwargs,
             response_obj=response_dict,
@@ -173,8 +158,6 @@ class GCSBucketLogger(CustomLogger):
             end_time=end_time,
             spend_log_metadata=_spend_log_payload["metadata"],
             response_cost=kwargs.get("response_cost", None),
-            spend_log_metadata=_spend_log_payload["metadata"],
-            response_cost=kwargs.get("response_cost", None),
         )

         return gcs_payload

View file

@@ -48,7 +48,7 @@ class LangFuseLogger:
             "secret_key": self.secret_key,
             "host": self.langfuse_host,
             "release": self.langfuse_release,
-            "debug": self.langfuse_debug,
+            "debug": True,
             "flush_interval": flush_interval,  # flush interval in seconds
         }

View file

@@ -5,7 +5,12 @@ from fastapi import Request

 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
-from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth
+from litellm.proxy._types import (
+    AddTeamCallback,
+    CommonProxyErrors,
+    TeamCallbackMetadata,
+    UserAPIKeyAuth,
+)
 from litellm.types.utils import SupportedCacheControls

 if TYPE_CHECKING:
@@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
         verbose_logger.error("error checking api version in query params: %s", str(e))


+def convert_key_logging_metadata_to_callback(
+    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
+) -> TeamCallbackMetadata:
+    if team_callback_settings_obj is None:
+        team_callback_settings_obj = TeamCallbackMetadata()
+    if data.callback_type == "success":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+    elif data.callback_type == "failure":
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+    elif data.callback_type == "success_and_failure":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+
+    for var, value in data.callback_vars.items():
+        if team_callback_settings_obj.callback_vars is None:
+            team_callback_settings_obj.callback_vars = {}
+        team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+
+    return team_callback_settings_obj


 async def add_litellm_data_to_request(
     data: dict,
     request: Request,
@@ -214,6 +255,7 @@ async def add_litellm_data_to_request(
         }  # add the team-specific configs to the completion call

     # Team Callbacks controls
+    callback_settings_obj: Optional[TeamCallbackMetadata] = None
     if user_api_key_dict.team_metadata is not None:
         team_metadata = user_api_key_dict.team_metadata
         if "callback_settings" in team_metadata:
@@ -231,6 +273,18 @@ async def add_litellm_data_to_request(
                 }
             }
             """
+    elif (
+        user_api_key_dict.metadata is not None
+        and "logging" in user_api_key_dict.metadata
+    ):
+        for item in user_api_key_dict.metadata["logging"]:
+            callback_settings_obj = convert_key_logging_metadata_to_callback(
+                data=AddTeamCallback(**item),
+                team_callback_settings_obj=callback_settings_obj,
+            )
+    if callback_settings_obj is not None:
+        data["success_callback"] = callback_settings_obj.success_callback
+        data["failure_callback"] = callback_settings_obj.failure_callback

View file

@@ -966,3 +966,92 @@ async def test_user_info_team_list(prisma_client):
         pass

     mock_client.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_add_callback_via_key(prisma_client):
+    """
+    Test if callback specified in key, is used.
+    """
+    global headers
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import chat_completion
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    litellm.set_verbose = True
+
+    try:
+        # Your test data
+        test_data = {
+            "model": "azure/chatgpt-v-2",
+            "messages": [
+                {"role": "user", "content": "write 1 sentence poem"},
+            ],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        }
+
+        request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+        request._url = URL(url="/chat/completions")
+
+        json_bytes = json.dumps(test_data).encode("utf-8")
+        request._body = json_bytes
+
+        with patch.object(
+            litellm.litellm_core_utils.litellm_logging,
+            "LangFuseLogger",
+            new=MagicMock(),
+        ) as mock_client:
+            resp = await chat_completion(
+                request=request,
+                fastapi_response=Response(),
+                user_api_key_dict=UserAPIKeyAuth(
+                    metadata={
+                        "logging": [
+                            {
+                                "callback_name": "langfuse",  # 'otel', 'langfuse', 'lunary'
+                                "callback_type": "success",  # set, if required by integration - future improvement, have logging tools work for success + failure by default
+                                "callback_vars": {
+                                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+                                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+                                    "langfuse_host": "https://us.cloud.langfuse.com",
+                                },
+                            }
+                        ]
+                    }
+                ),
+            )
+            print(resp)
+            mock_client.assert_called()
+            mock_client.return_value.log_event.assert_called()
+            args, kwargs = mock_client.return_value.log_event.call_args
+            print("KWARGS - {}".format(kwargs))
+            kwargs = kwargs["kwargs"]
+            print(kwargs)
+            assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
+            assert (
+                "logging"
+                in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
+            )
+
+            checked_keys = False
+            for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
+                "logging"
+            ]:
+                for k, v in item["callback_vars"].items():
+                    print("k={}, v={}".format(k, v))
+                    if "key" in k:
+                        assert "os.environ" in v
+                        checked_keys = True
+            assert checked_keys
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
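
Outside of tests, the same "logging" block would normally be attached when the key is created rather than via a hand-built UserAPIKeyAuth; a hypothetical end-to-end sketch against a locally running proxy (the /key/generate endpoint and its metadata field belong to the existing proxy API and are not part of this diff):

# Hypothetical usage: create a virtual key whose requests log to a dedicated
# Langfuse project. Assumes a proxy on localhost:4000 with master key sk-1234.
import requests

resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "metadata": {
            "logging": [
                {
                    "callback_name": "langfuse",
                    "callback_type": "success",
                    "callback_vars": {
                        "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                        "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                        "langfuse_host": "https://us.cloud.langfuse.com",
                    },
                }
            ]
        }
    },
    timeout=30,
)
print(resp.json()["key"])  # requests made with this key route logs per the metadata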