feat(utils.py): support dynamic langfuse params and team settings on proxy

This commit is contained in:
Krrish Dholakia 2024-02-01 21:08:24 -08:00
parent b79a6607b2
commit a301d8aa4b
6 changed files with 122 additions and 13 deletions

View file

@ -146,6 +146,7 @@ suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
default_key_generate_params: Optional[Dict] = None
default_team_settings: Optional[List] = None
#### RELIABILITY ####
request_timeout: Optional[float] = 6000
num_retries: Optional[int] = None # per model endpoint
@ -165,9 +166,6 @@ _key_management_system: Optional[KeyManagementSystem] = None
def get_model_cost_map(url: str):
verbose_logger.debug(
f"os.getenv('LITELLM_LOCAL_MODEL_COST_MAP', False): {os.environ['LITELLM_LOCAL_MODEL_COST_MAP']}"
)
if (
os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == True
or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
@ -175,7 +173,6 @@ def get_model_cost_map(url: str):
import importlib.resources
import json
verbose_logger.debug("RUNS LOCALLY")
with importlib.resources.open_text(
"litellm", "model_prices_and_context_window_backup.json"
) as f:

View file

@ -14,7 +14,7 @@ import litellm
class LangFuseLogger:
# Class variables or attributes
def __init__(self):
def __init__(self, langfuse_public_key=None, langfuse_secret=None):
try:
from langfuse import Langfuse
except Exception as e:
@ -22,8 +22,8 @@ class LangFuseLogger:
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m"
)
# Instance variables
self.secret_key = os.getenv("LANGFUSE_SECRET_KEY")
self.public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
self.langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")

View file

@ -321,6 +321,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
aliases: Dict = {}
config: Dict = {}
user_id: Optional[str] = None
team_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Dict = {}
tpm_limit: Optional[int] = None

View file

@ -1024,6 +1024,24 @@ class ProxyConfig:
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
await prisma_client.insert_data(data=new_config, table_name="config")
async def load_team_config(self, team_id: str):
    """
    Return the team-specific completion() call params for a given team id.

    Looks up `team_id` in `litellm.default_team_settings` and resolves any
    "os.environ/..." secret references via `litellm.get_secret`.

    Returns an empty dict when no team settings are configured or the id
    does not match any entry.

    NOTE: returns a *copy* of the matched entry. The original code aliased
    the dict stored in `litellm.default_team_settings`, so resolving secrets
    (and callers doing `team_config.pop("team_id")`) mutated the shared
    global settings list — breaking subsequent lookups for the same team.
    """
    all_teams_config = litellm.default_team_settings
    team_config: dict = {}
    if all_teams_config is None:
        return team_config
    for team in all_teams_config:
        # .get() instead of ["team_id"]: a malformed entry without a
        # "team_id" key should be skipped, not raise KeyError
        if team_id == team.get("team_id"):
            # shallow copy so downstream mutation never touches the
            # shared default_team_settings entry
            team_config = dict(team)
            break
    for k, v in team_config.items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            # resolve "os.environ/<NAME>" references to actual secrets
            team_config[k] = litellm.get_secret(v)
    return team_config
async def load_config(
self, router: Optional[litellm.Router], config_file_path: str
):
@ -2040,6 +2058,21 @@ async def chat_completion(
data["metadata"]["headers"] = _headers
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
if user_temperature:
@ -2215,6 +2248,21 @@ async def embeddings(
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
@ -2361,6 +2409,21 @@ async def image_generation(
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None

View file

@ -0,0 +1,36 @@
#### What this tests ####
# This tests if setting team_config actually works
import sys, os
import traceback
import pytest

# Make the package root importable when running from the tests directory
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.proxy.proxy_server import ProxyConfig


@pytest.mark.asyncio
async def test_team_config():
    # Register a single team whose langfuse credentials come from env-var refs
    litellm.default_team_settings = [
        {
            "team_id": "my-special-team",
            "success_callback": ["langfuse"],
            "langfuse_public_key": "os.environ/LANGFUSE_PUB_KEY_2",
            "langfuse_secret": "os.environ/LANGFUSE_PRIVATE_KEY_2",
        }
    ]
    config_obj = ProxyConfig()
    settings = await config_obj.load_team_config(team_id="my-special-team")
    # The configured team must be found
    assert len(settings) > 0

    # team_id itself is not a completion() kwarg — drop it before the call
    settings.pop("team_id")
    completion_kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    }
    completion_kwargs.update(settings)
    response = litellm.completion(**completion_kwargs)
    print(f"response: {response}")

View file

@ -752,6 +752,8 @@ class Logging:
function_id,
dynamic_success_callbacks=None,
dynamic_async_success_callbacks=None,
langfuse_public_key=None,
langfuse_secret=None,
):
if call_type not in [item.value for item in CallTypes]:
allowed_values = ", ".join([item.value for item in CallTypes])
@ -780,6 +782,9 @@ class Logging:
self.dynamic_async_success_callbacks = (
dynamic_async_success_callbacks or []
) # callbacks set for just that call
## DYNAMIC LANGFUSE KEYS ##
self.langfuse_public_key = langfuse_public_key
self.langfuse_secret = langfuse_secret
def update_environment_variables(
self, model, user, optional_params, litellm_params, **additional_params
@ -1211,7 +1216,9 @@ class Logging:
if "complete_streaming_response" not in kwargs:
break
else:
print_verbose("reaches langfuse for streaming logging!")
print_verbose(
"reaches langsmith for streaming logging!"
)
result = kwargs["complete_streaming_response"]
langsmithLogger.log_event(
kwargs=self.model_call_details,
@ -1279,7 +1286,10 @@ class Logging:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
if langFuseLogger is None:
langFuseLogger = LangFuseLogger()
langFuseLogger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
)
langFuseLogger.log_event(
kwargs=kwargs,
response_obj=result,
@ -1965,7 +1975,7 @@ def client(original_function):
# Pop the async items from success_callback in reverse order to avoid index issues
for index in reversed(removed_async_items):
kwargs["success_callback"].pop(index)
dynamic_success_callbacks = kwargs["success_callback"]
dynamic_success_callbacks = kwargs.pop("success_callback")
if add_breadcrumb:
add_breadcrumb(
@ -2030,6 +2040,8 @@ def client(original_function):
start_time=start_time,
dynamic_success_callbacks=dynamic_success_callbacks,
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
langfuse_public_key=kwargs.pop("langfuse_public_key", None),
langfuse_secret=kwargs.pop("langfuse_secret", None),
)
## check if metadata is passed in
litellm_params = {}
@ -2041,7 +2053,7 @@ def client(original_function):
optional_params={},
litellm_params=litellm_params,
)
return logging_obj
return logging_obj, kwargs
except Exception as e:
import logging
@ -2111,7 +2123,7 @@ def client(original_function):
try:
if logging_obj is None:
logging_obj = function_setup(start_time, *args, **kwargs)
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# CHECK FOR 'os.environ/' in kwargs
@ -2346,7 +2358,7 @@ def client(original_function):
try:
if logging_obj is None:
logging_obj = function_setup(start_time, *args, **kwargs)
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# [OPTIONAL] CHECK BUDGET