Merge pull request #5176 from BerriAI/litellm_key_logging

Allow specifying langfuse project for logging in key metadata
Krish Dholakia 2024-08-14 12:55:07 -07:00 committed by GitHub
commit 22243c6571
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 318 additions and 58 deletions
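
In short: an API key can now carry per-key logging settings in its metadata, and the proxy wires those up as callbacks for requests made with that key. A minimal sketch of the metadata shape the new code path reads, reconstructed from the tests in this commit (the key-generation request that would carry it is not shown here):

    key_metadata = {
        "logging": [
            {
                # which logging integration to attach ('langfuse', 'otel', 'lunary', ...)
                "callback_name": "langfuse",
                # attach on "success", "failure", or "success_and_failure"
                "callback_type": "success",
                # integration-specific settings; "os.environ/..." values are
                # secret references, resolved server-side via litellm.get_secret()
                "callback_vars": {
                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
                    "langfuse_host": "https://us.cloud.langfuse.com",
                },
            }
        ]
    }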


@@ -5,7 +5,12 @@ from fastapi import Request

 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
-from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth
+from litellm.proxy._types import (
+    AddTeamCallback,
+    CommonProxyErrors,
+    TeamCallbackMetadata,
+    UserAPIKeyAuth,
+)
 from litellm.types.utils import SupportedCacheControls

 if TYPE_CHECKING:
@@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
         verbose_logger.error("error checking api version in query params: %s", str(e))


+def convert_key_logging_metadata_to_callback(
+    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
+) -> TeamCallbackMetadata:
+    if team_callback_settings_obj is None:
+        team_callback_settings_obj = TeamCallbackMetadata()
+
+    if data.callback_type == "success":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+    elif data.callback_type == "failure":
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+    elif data.callback_type == "success_and_failure":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+
+    for var, value in data.callback_vars.items():
+        if team_callback_settings_obj.callback_vars is None:
+            team_callback_settings_obj.callback_vars = {}
+        # values like "os.environ/LANGFUSE_PUBLIC_KEY" are resolved here
+        team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+
+    return team_callback_settings_obj
+
+
 async def add_litellm_data_to_request(
     data: dict,
     request: Request,
@@ -224,6 +265,7 @@ async def add_litellm_data_to_request(
         }  # add the team-specific configs to the completion call

         # Team Callbacks controls
+        callback_settings_obj: Optional[TeamCallbackMetadata] = None
         if user_api_key_dict.team_metadata is not None:
             team_metadata = user_api_key_dict.team_metadata
             if "callback_settings" in team_metadata:
@@ -241,6 +283,18 @@ async def add_litellm_data_to_request(
                     }
                 }
                 """
+        elif (
+            user_api_key_dict.metadata is not None
+            and "logging" in user_api_key_dict.metadata
+        ):
+            for item in user_api_key_dict.metadata["logging"]:
+                callback_settings_obj = convert_key_logging_metadata_to_callback(
+                    data=AddTeamCallback(**item),
+                    team_callback_settings_obj=callback_settings_obj,
+                )
+
+        if callback_settings_obj is not None:
             data["success_callback"] = callback_settings_obj.success_callback
             data["failure_callback"] = callback_settings_obj.failure_callback


@@ -1,23 +1,27 @@
 # What is this?
 ## Test to make sure function call response always works with json.loads() -> no extra parsing required. Relevant issue - https://github.com/BerriAI/litellm/issues/2654
-import sys, os
+import os
+import sys
 import traceback

 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest
-import litellm
 import json
 import warnings
-from litellm import completion
 from typing import List

+import pytest
+
+import litellm
+from litellm import completion
+

 # Just a stub to keep the sample code simple
 class Trade:

@@ -78,6 +82,7 @@ def trade(model_name: str) -> List[Trade]:
             },
         }

+    try:
         response = completion(
             model_name,
             [

@@ -129,7 +134,8 @@ def trade(model_name: str) -> List[Trade]:
                 "function": {"name": tool_spec["function"]["name"]},  # type: ignore
             },
         )
+    except litellm.InternalServerError:
+        pass

     calls = response.choices[0].message.tool_calls
     trades = [trade for call in calls for trade in parse_call(call)]
     return trades


@@ -966,3 +966,203 @@ async def test_user_info_team_list(prisma_client):
         pass

     mock_client.assert_called()

+
+@pytest.mark.skip(reason="Local test")
+@pytest.mark.asyncio
+async def test_add_callback_via_key(prisma_client):
+    """
+    Test if callback specified in key, is used.
+    """
+    global headers
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import chat_completion
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    litellm.set_verbose = True
+
+    try:
+        # Your test data
+        test_data = {
+            "model": "azure/chatgpt-v-2",
+            "messages": [
+                {"role": "user", "content": "write 1 sentence poem"},
+            ],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        }
+
+        request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+        request._url = URL(url="/chat/completions")
+
+        json_bytes = json.dumps(test_data).encode("utf-8")
+        request._body = json_bytes
+
+        with patch.object(
+            litellm.litellm_core_utils.litellm_logging,
+            "LangFuseLogger",
+            new=MagicMock(),
+        ) as mock_client:
+            resp = await chat_completion(
+                request=request,
+                fastapi_response=Response(),
+                user_api_key_dict=UserAPIKeyAuth(
+                    metadata={
+                        "logging": [
+                            {
+                                "callback_name": "langfuse",  # 'otel', 'langfuse', 'lunary'
+                                "callback_type": "success",  # set, if required by integration - future improvement, have logging tools work for success + failure by default
+                                "callback_vars": {
+                                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+                                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+                                    "langfuse_host": "https://us.cloud.langfuse.com",
+                                },
+                            }
+                        ]
+                    }
+                ),
+            )
+            print(resp)
+
+            mock_client.assert_called()
+            mock_client.return_value.log_event.assert_called()
+            args, kwargs = mock_client.return_value.log_event.call_args
+            kwargs = kwargs["kwargs"]
+            assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
+            assert (
+                "logging"
+                in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
+            )
+
+            checked_keys = False
+            for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
+                "logging"
+            ]:
+                for k, v in item["callback_vars"].items():
+                    print("k={}, v={}".format(k, v))
+                    if "key" in k:
+                        assert "os.environ" in v
+                        checked_keys = True
+            assert checked_keys
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/chat/completions")
+
+    test_data = {
+        "model": "azure/chatgpt-v-2",
+        "messages": [
+            {"role": "user", "content": "write 1 sentence poem"},
+        ],
+        "max_tokens": 10,
+        "mock_response": "Hello world",
+        "api_key": "my-fake-key",
+    }
+
+    json_bytes = json.dumps(test_data).encode("utf-8")
+    request._body = json_bytes
+
+    data = {
+        "data": {
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        },
+        "request": request,
+        "user_api_key_dict": UserAPIKeyAuth(
+            token=None,
+            key_name=None,
+            key_alias=None,
+            spend=0.0,
+            max_budget=None,
+            expires=None,
+            models=[],
+            aliases={},
+            config={},
+            user_id=None,
+            team_id=None,
+            max_parallel_requests=None,
+            metadata={
+                "logging": [
+                    {
+                        "callback_name": "langfuse",
+                        "callback_type": "success",
+                        "callback_vars": {
+                            "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+                            "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+                            "langfuse_host": "https://us.cloud.langfuse.com",
+                        },
+                    }
+                ]
+            },
+            tpm_limit=None,
+            rpm_limit=None,
+            budget_duration=None,
+            budget_reset_at=None,
+            allowed_cache_controls=[],
+            permissions={},
+            model_spend={},
+            model_max_budget={},
+            soft_budget_cooldown=False,
+            litellm_budget_table=None,
+            org_id=None,
+            team_spend=None,
+            team_alias=None,
+            team_tpm_limit=None,
+            team_rpm_limit=None,
+            team_max_budget=None,
+            team_models=[],
+            team_blocked=False,
+            soft_budget=None,
+            team_model_aliases=None,
+            team_member_spend=None,
+            team_metadata=None,
+            end_user_id=None,
+            end_user_tpm_limit=None,
+            end_user_rpm_limit=None,
+            end_user_max_budget=None,
+            last_refreshed_at=None,
+            api_key=None,
+            user_role=None,
+            allowed_model_region=None,
+            parent_otel_span=None,
+        ),
+        "proxy_config": proxy_config,
+        "general_settings": {},
+        "version": "0.0.0",
+    }
+
+    new_data = await add_litellm_data_to_request(**data)
+
+    assert "success_callback" in new_data
+    assert new_data["success_callback"] == ["langfuse"]
+    assert "langfuse_public_key" in new_data
+    assert "langfuse_secret_key" in new_data