feat(utils.py): support dynamic langfuse params and team settings on proxy

Krrish Dholakia 2024-02-01 21:08:24 -08:00
parent b79a6607b2
commit a301d8aa4b
6 changed files with 122 additions and 13 deletions


@@ -146,6 +146,7 @@ suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
 default_key_generate_params: Optional[Dict] = None
+default_team_settings: Optional[List] = None
 #### RELIABILITY ####
 request_timeout: Optional[float] = 6000
 num_retries: Optional[int] = None  # per model endpoint
@@ -165,9 +166,6 @@ _key_management_system: Optional[KeyManagementSystem] = None
 def get_model_cost_map(url: str):
-    verbose_logger.debug(
-        f"os.getenv('LITELLM_LOCAL_MODEL_COST_MAP', False): {os.environ['LITELLM_LOCAL_MODEL_COST_MAP']}"
-    )
     if (
         os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == True
         or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
@@ -175,7 +173,6 @@ def get_model_cost_map(url: str):
         import importlib.resources
         import json
-        verbose_logger.debug("RUNS LOCALLY")
         with importlib.resources.open_text(
             "litellm", "model_prices_and_context_window_backup.json"
         ) as f:
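The new default_team_settings list maps a team_id to per-team completion() params. A minimal sketch of what it holds, mirroring the test added in this commit ("os.environ/..." values are left as references here and resolved to real secrets at lookup time):

    import litellm

    litellm.default_team_settings = [
        {
            "team_id": "my-special-team",
            "success_callback": ["langfuse"],
            "langfuse_public_key": "os.environ/LANGFUSE_PUB_KEY_2",
            "langfuse_secret": "os.environ/LANGFUSE_PRIVATE_KEY_2",
        }
    ]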


@@ -14,7 +14,7 @@ import litellm
 class LangFuseLogger:
     # Class variables or attributes
-    def __init__(self):
+    def __init__(self, langfuse_public_key=None, langfuse_secret=None):
         try:
             from langfuse import Langfuse
         except Exception as e:
@@ -22,8 +22,8 @@
                 f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m"
             )
         # Instance variables
-        self.secret_key = os.getenv("LANGFUSE_SECRET_KEY")
-        self.public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
+        self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
+        self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
         self.langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
         self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
         self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
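Explicitly passed keys now take precedence, with the environment as fallback. A short sketch, assuming the langfuse package is installed and the module path follows the repo layout (key values are placeholders):

    from litellm.integrations.langfuse import LangFuseLogger

    # Per-call keys win over the environment ...
    logger = LangFuseLogger(
        langfuse_public_key="pk-lf-...",  # placeholder
        langfuse_secret="sk-lf-...",  # placeholder
    )

    # ... and omitting them keeps the old behavior of reading
    # LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY from the environment.
    env_logger = LangFuseLogger()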


@@ -321,6 +321,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
     aliases: Dict = {}
     config: Dict = {}
     user_id: Optional[str] = None
+    team_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
     metadata: Dict = {}
     tpm_limit: Optional[int] = None
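Because team_id defaults to None, existing keys deserialize unchanged; only keys minted for a team carry the id used for the config lookup below. A hypothetical sketch (module path and the token value are assumptions; any required fields not shown here would also need to be supplied):

    from litellm.proxy._types import LiteLLM_VerificationToken

    token = LiteLLM_VerificationToken(
        token="hashed-key-value",  # placeholder
        team_id="my-special-team",
    )
    assert token.team_id == "my-special-team"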


@@ -1024,6 +1024,24 @@ class ProxyConfig:
                     m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
             await prisma_client.insert_data(data=new_config, table_name="config")

+    async def load_team_config(self, team_id: str):
+        """
+        - for a given team id
+        - return the relevant completion() call params
+        """
+        all_teams_config = litellm.default_team_settings
+        team_config: dict = {}
+        if all_teams_config is None:
+            return team_config
+        for team in all_teams_config:
+            if team_id == team["team_id"]:
+                team_config = team
+                break
+        for k, v in team_config.items():
+            if isinstance(v, str) and v.startswith("os.environ/"):
+                team_config[k] = litellm.get_secret(v)
+        return team_config
+
     async def load_config(
         self, router: Optional[litellm.Router], config_file_path: str
     ):
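Lookup is by exact team_id match, an unknown id returns an empty dict, and any "os.environ/..." value is resolved through litellm.get_secret. Usage, as exercised by the test added in this commit (inside an async context):

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    team_config = await proxy_config.load_team_config(team_id="my-special-team")
    # -> {} if the id is not present in litellm.default_team_settings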
@@ -2040,6 +2058,21 @@ async def chat_completion(
         data["metadata"]["headers"] = _headers
         data["metadata"]["endpoint"] = str(request.url)

+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
         global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
         if user_temperature:
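Because data is spread last, values supplied on the request win over the team defaults; the team settings only fill in what the caller did not set. A self-contained illustration of the merge:

    team_config = {"langfuse_public_key": "team-key", "temperature": 0.2}
    data = {"model": "gpt-3.5-turbo", "temperature": 0.7}

    merged = {**team_config, **data}
    # temperature comes from the request, langfuse_public_key from the team
    assert merged == {
        "langfuse_public_key": "team-key",
        "model": "gpt-3.5-turbo",
        "temperature": 0.7,
    }

The same block is repeated verbatim in the embeddings and image_generation endpoints below.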
@@ -2215,6 +2248,21 @@ async def embeddings(
         data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
         data["metadata"]["endpoint"] = str(request.url)

+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
         router_model_names = (
             [m["model_name"] for m in llm_model_list]
             if llm_model_list is not None
@@ -2361,6 +2409,21 @@ async def image_generation(
         data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
         data["metadata"]["endpoint"] = str(request.url)

+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
         router_model_names = (
             [m["model_name"] for m in llm_model_list]
             if llm_model_list is not None


@@ -0,0 +1,36 @@
+#### What this tests ####
+# This tests if setting team_config actually works
+import sys, os
+import traceback
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.proxy.proxy_server import ProxyConfig
+
+
+@pytest.mark.asyncio
+async def test_team_config():
+    litellm.default_team_settings = [
+        {
+            "team_id": "my-special-team",
+            "success_callback": ["langfuse"],
+            "langfuse_public_key": "os.environ/LANGFUSE_PUB_KEY_2",
+            "langfuse_secret": "os.environ/LANGFUSE_PRIVATE_KEY_2",
+        }
+    ]
+    proxyconfig = ProxyConfig()
+    team_config = await proxyconfig.load_team_config(team_id="my-special-team")
+    assert len(team_config) > 0
+
+    data = {
+        "model": "gpt-3.5-turbo",
+        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
+    }
+    team_config.pop("team_id")
+    response = litellm.completion(**{**data, **team_config})
+    print(f"response: {response}")
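The test assumes LANGFUSE_PUB_KEY_2 and LANGFUSE_PRIVATE_KEY_2 are exported, plus whatever credentials the gpt-3.5-turbo call needs (an OpenAI key, under default routing). With those set, load_team_config resolves the "os.environ/" indirection roughly like this (illustrative values):

    # {
    #     "team_id": "my-special-team",
    #     "success_callback": ["langfuse"],
    #     "langfuse_public_key": "<value of LANGFUSE_PUB_KEY_2>",
    #     "langfuse_secret": "<value of LANGFUSE_PRIVATE_KEY_2>",
    # }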


@@ -752,6 +752,8 @@ class Logging:
         function_id,
         dynamic_success_callbacks=None,
         dynamic_async_success_callbacks=None,
+        langfuse_public_key=None,
+        langfuse_secret=None,
     ):
         if call_type not in [item.value for item in CallTypes]:
             allowed_values = ", ".join([item.value for item in CallTypes])
@@ -780,6 +782,9 @@
         self.dynamic_async_success_callbacks = (
             dynamic_async_success_callbacks or []
         )  # callbacks set for just that call
+        ## DYNAMIC LANGFUSE KEYS ##
+        self.langfuse_public_key = langfuse_public_key
+        self.langfuse_secret = langfuse_secret

     def update_environment_variables(
         self, model, user, optional_params, litellm_params, **additional_params
@@ -1211,7 +1216,9 @@
                     if "complete_streaming_response" not in kwargs:
                         break
                     else:
-                        print_verbose("reaches langfuse for streaming logging!")
+                        print_verbose(
+                            "reaches langsmith for streaming logging!"
+                        )
                         result = kwargs["complete_streaming_response"]
                         langsmithLogger.log_event(
                             kwargs=self.model_call_details,
@@ -1279,7 +1286,10 @@
                     print_verbose("reaches langfuse for streaming logging!")
                     result = kwargs["complete_streaming_response"]
                 if langFuseLogger is None:
-                    langFuseLogger = LangFuseLogger()
+                    langFuseLogger = LangFuseLogger(
+                        langfuse_public_key=self.langfuse_public_key,
+                        langfuse_secret=self.langfuse_secret,
+                    )
                 langFuseLogger.log_event(
                     kwargs=kwargs,
                     response_obj=result,
@@ -1965,7 +1975,7 @@ def client(original_function):
             # Pop the async items from success_callback in reverse order to avoid index issues
             for index in reversed(removed_async_items):
                 kwargs["success_callback"].pop(index)
-            dynamic_success_callbacks = kwargs["success_callback"]
+            dynamic_success_callbacks = kwargs.pop("success_callback")
         if add_breadcrumb:
             add_breadcrumb(
@@ -2030,6 +2040,8 @@
             start_time=start_time,
             dynamic_success_callbacks=dynamic_success_callbacks,
             dynamic_async_success_callbacks=dynamic_async_success_callbacks,
+            langfuse_public_key=kwargs.pop("langfuse_public_key", None),
+            langfuse_secret=kwargs.pop("langfuse_secret", None),
         )
         ## check if metadata is passed in
         litellm_params = {}
@@ -2041,7 +2053,7 @@
             optional_params={},
             litellm_params=litellm_params,
         )
-        return logging_obj
+        return logging_obj, kwargs
     except Exception as e:
         import logging
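Returning kwargs matters because **kwargs unpacking hands function_setup its own copy of the dict: the pops of success_callback, langfuse_public_key, and langfuse_secret only take effect once the wrapper reassigns kwargs from the return value, which keeps those params out of the underlying provider call. A minimal sketch of the caller-facing behavior this enables, with placeholder key values:

    import litellm

    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        success_callback=["langfuse"],
        langfuse_public_key="pk-lf-...",  # placeholder; routed to Logging, not the LLM call
        langfuse_secret="sk-lf-...",  # placeholder
    )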
@@ -2111,7 +2123,7 @@
         try:
             if logging_obj is None:
-                logging_obj = function_setup(start_time, *args, **kwargs)
+                logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
             kwargs["litellm_logging_obj"] = logging_obj
             # CHECK FOR 'os.environ/' in kwargs
@@ -2346,7 +2358,7 @@
         try:
             if logging_obj is None:
-                logging_obj = function_setup(start_time, *args, **kwargs)
+                logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
             kwargs["litellm_logging_obj"] = logging_obj
             # [OPTIONAL] CHECK BUDGET