forked from phoenix/litellm-mirror
feat(utils.py): support dynamic langfuse params and team settings on proxy
This commit is contained in:
parent
b79a6607b2
commit
a301d8aa4b
6 changed files with 122 additions and 13 deletions
|
@ -146,6 +146,7 @@ suppress_debug_info = False
|
|||
dynamodb_table_name: Optional[str] = None
|
||||
s3_callback_params: Optional[Dict] = None
|
||||
default_key_generate_params: Optional[Dict] = None
|
||||
default_team_settings: Optional[List] = None
|
||||
#### RELIABILITY ####
|
||||
request_timeout: Optional[float] = 6000
|
||||
num_retries: Optional[int] = None # per model endpoint
|
||||
|
@ -165,9 +166,6 @@ _key_management_system: Optional[KeyManagementSystem] = None
|
|||
|
||||
|
||||
def get_model_cost_map(url: str):
|
||||
verbose_logger.debug(
|
||||
f"os.getenv('LITELLM_LOCAL_MODEL_COST_MAP', False): {os.environ['LITELLM_LOCAL_MODEL_COST_MAP']}"
|
||||
)
|
||||
if (
|
||||
os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == True
|
||||
or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
|
||||
|
@ -175,7 +173,6 @@ def get_model_cost_map(url: str):
|
|||
import importlib.resources
|
||||
import json
|
||||
|
||||
verbose_logger.debug("RUNS LOCALLY")
|
||||
with importlib.resources.open_text(
|
||||
"litellm", "model_prices_and_context_window_backup.json"
|
||||
) as f:
|
||||
|
|
|
@ -14,7 +14,7 @@ import litellm
|
|||
|
||||
class LangFuseLogger:
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
def __init__(self, langfuse_public_key=None, langfuse_secret=None):
|
||||
try:
|
||||
from langfuse import Langfuse
|
||||
except Exception as e:
|
||||
|
@ -22,8 +22,8 @@ class LangFuseLogger:
|
|||
f"\033[91mLangfuse not installed, try running 'pip install langfuse' to fix this error: {e}\033[0m"
|
||||
)
|
||||
# Instance variables
|
||||
self.secret_key = os.getenv("LANGFUSE_SECRET_KEY")
|
||||
self.public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
self.secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
|
||||
self.public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
self.langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
|
||||
self.langfuse_release = os.getenv("LANGFUSE_RELEASE")
|
||||
self.langfuse_debug = os.getenv("LANGFUSE_DEBUG")
|
||||
|
|
|
@ -321,6 +321,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
|||
aliases: Dict = {}
|
||||
config: Dict = {}
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
metadata: Dict = {}
|
||||
tpm_limit: Optional[int] = None
|
||||
|
|
|
@ -1024,6 +1024,24 @@ class ProxyConfig:
|
|||
m["litellm_params"]["api_key"] = f"os.environ/{key_name}"
|
||||
await prisma_client.insert_data(data=new_config, table_name="config")
|
||||
|
||||
async def load_team_config(self, team_id: str):
    """
    Look up the team-specific completion() call params for `team_id`.

    - Scans `litellm.default_team_settings` (a list of per-team dicts) for an
      entry whose "team_id" matches.
    - Resolves any "os.environ/..." string values through `litellm.get_secret`.

    Returns an empty dict when no team settings are configured or the team id
    is not found.

    NOTE: returns a *copy* of the matched entry. The original code returned the
    entry itself; callers then `pop("team_id")` off the result, which deleted
    the key from the shared `litellm.default_team_settings` list and broke every
    subsequent lookup for that team. Copying also keeps the unresolved
    "os.environ/..." references intact in the global config instead of
    overwriting them with resolved secret values.
    """
    all_teams_config = litellm.default_team_settings
    if all_teams_config is None:
        return {}
    team_config: dict = {}
    for team in all_teams_config:
        # .get() so a malformed entry without "team_id" is skipped, not a KeyError
        if team_id == team.get("team_id"):
            team_config = dict(team)  # shallow copy — do not hand out the shared entry
            break
    for k, v in team_config.items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            team_config[k] = litellm.get_secret(v)
    return team_config
|
||||
|
||||
async def load_config(
|
||||
self, router: Optional[litellm.Router], config_file_path: str
|
||||
):
|
||||
|
@ -2040,6 +2058,21 @@ async def chat_completion(
|
|||
data["metadata"]["headers"] = _headers
|
||||
data["metadata"]["endpoint"] = str(request.url)
|
||||
|
||||
### TEAM-SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.team_id is not None:
|
||||
team_config = await proxy_config.load_team_config(
|
||||
team_id=user_api_key_dict.team_id
|
||||
)
|
||||
if len(team_config) == 0:
|
||||
pass
|
||||
else:
|
||||
team_id = team_config.pop("team_id", None)
|
||||
data["metadata"]["team_id"] = team_id
|
||||
data = {
|
||||
**team_config,
|
||||
**data,
|
||||
} # add the team-specific configs to the completion call
|
||||
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
|
@ -2215,6 +2248,21 @@ async def embeddings(
|
|||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["endpoint"] = str(request.url)
|
||||
|
||||
### TEAM-SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.team_id is not None:
|
||||
team_config = await proxy_config.load_team_config(
|
||||
team_id=user_api_key_dict.team_id
|
||||
)
|
||||
if len(team_config) == 0:
|
||||
pass
|
||||
else:
|
||||
team_id = team_config.pop("team_id", None)
|
||||
data["metadata"]["team_id"] = team_id
|
||||
data = {
|
||||
**team_config,
|
||||
**data,
|
||||
} # add the team-specific configs to the completion call
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
|
@ -2361,6 +2409,21 @@ async def image_generation(
|
|||
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||
data["metadata"]["endpoint"] = str(request.url)
|
||||
|
||||
### TEAM-SPECIFIC PARAMS ###
|
||||
if user_api_key_dict.team_id is not None:
|
||||
team_config = await proxy_config.load_team_config(
|
||||
team_id=user_api_key_dict.team_id
|
||||
)
|
||||
if len(team_config) == 0:
|
||||
pass
|
||||
else:
|
||||
team_id = team_config.pop("team_id", None)
|
||||
data["metadata"]["team_id"] = team_id
|
||||
data = {
|
||||
**team_config,
|
||||
**data,
|
||||
} # add the team-specific configs to the completion call
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
|
|
36
litellm/tests/test_team_config.py
Normal file
36
litellm/tests/test_team_config.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
#### What this tests ####
# Verifies that litellm.default_team_settings is picked up by
# ProxyConfig.load_team_config and that the resulting per-team params
# can be forwarded into a completion() call.
import os
import sys
import traceback

import pytest

# Make the repo root importable when running this file directly.
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.proxy.proxy_server import ProxyConfig


@pytest.mark.asyncio
async def test_team_config():
    special_team = {
        "team_id": "my-special-team",
        "success_callback": ["langfuse"],
        "langfuse_public_key": "os.environ/LANGFUSE_PUB_KEY_2",
        "langfuse_secret": "os.environ/LANGFUSE_PRIVATE_KEY_2",
    }
    litellm.default_team_settings = [special_team]

    proxy_config = ProxyConfig()
    team_config = await proxy_config.load_team_config(team_id="my-special-team")
    assert len(team_config) > 0

    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    }
    # team_id is proxy metadata, not a completion() kwarg — strip it before merging.
    team_config.pop("team_id")
    call_kwargs = {**data, **team_config}
    response = litellm.completion(**call_kwargs)

    print(f"response: {response}")
|
|
@ -752,6 +752,8 @@ class Logging:
|
|||
function_id,
|
||||
dynamic_success_callbacks=None,
|
||||
dynamic_async_success_callbacks=None,
|
||||
langfuse_public_key=None,
|
||||
langfuse_secret=None,
|
||||
):
|
||||
if call_type not in [item.value for item in CallTypes]:
|
||||
allowed_values = ", ".join([item.value for item in CallTypes])
|
||||
|
@ -780,6 +782,9 @@ class Logging:
|
|||
self.dynamic_async_success_callbacks = (
|
||||
dynamic_async_success_callbacks or []
|
||||
) # callbacks set for just that call
|
||||
## DYNAMIC LANGFUSE KEYS ##
|
||||
self.langfuse_public_key = langfuse_public_key
|
||||
self.langfuse_secret = langfuse_secret
|
||||
|
||||
def update_environment_variables(
|
||||
self, model, user, optional_params, litellm_params, **additional_params
|
||||
|
@ -1211,7 +1216,9 @@ class Logging:
|
|||
if "complete_streaming_response" not in kwargs:
|
||||
break
|
||||
else:
|
||||
print_verbose("reaches langfuse for streaming logging!")
|
||||
print_verbose(
|
||||
"reaches langsmith for streaming logging!"
|
||||
)
|
||||
result = kwargs["complete_streaming_response"]
|
||||
langsmithLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -1279,7 +1286,10 @@ class Logging:
|
|||
print_verbose("reaches langfuse for streaming logging!")
|
||||
result = kwargs["complete_streaming_response"]
|
||||
if langFuseLogger is None:
|
||||
langFuseLogger = LangFuseLogger()
|
||||
langFuseLogger = LangFuseLogger(
|
||||
langfuse_public_key=self.langfuse_public_key,
|
||||
langfuse_secret=self.langfuse_secret,
|
||||
)
|
||||
langFuseLogger.log_event(
|
||||
kwargs=kwargs,
|
||||
response_obj=result,
|
||||
|
@ -1965,7 +1975,7 @@ def client(original_function):
|
|||
# Pop the async items from success_callback in reverse order to avoid index issues
|
||||
for index in reversed(removed_async_items):
|
||||
kwargs["success_callback"].pop(index)
|
||||
dynamic_success_callbacks = kwargs["success_callback"]
|
||||
dynamic_success_callbacks = kwargs.pop("success_callback")
|
||||
|
||||
if add_breadcrumb:
|
||||
add_breadcrumb(
|
||||
|
@ -2030,6 +2040,8 @@ def client(original_function):
|
|||
start_time=start_time,
|
||||
dynamic_success_callbacks=dynamic_success_callbacks,
|
||||
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
|
||||
langfuse_public_key=kwargs.pop("langfuse_public_key", None),
|
||||
langfuse_secret=kwargs.pop("langfuse_secret", None),
|
||||
)
|
||||
## check if metadata is passed in
|
||||
litellm_params = {}
|
||||
|
@ -2041,7 +2053,7 @@ def client(original_function):
|
|||
optional_params={},
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
return logging_obj
|
||||
return logging_obj, kwargs
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
|
@ -2111,7 +2123,7 @@ def client(original_function):
|
|||
|
||||
try:
|
||||
if logging_obj is None:
|
||||
logging_obj = function_setup(start_time, *args, **kwargs)
|
||||
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
|
||||
# CHECK FOR 'os.environ/' in kwargs
|
||||
|
@ -2346,7 +2358,7 @@ def client(original_function):
|
|||
|
||||
try:
|
||||
if logging_obj is None:
|
||||
logging_obj = function_setup(start_time, *args, **kwargs)
|
||||
logging_obj, kwargs = function_setup(start_time, *args, **kwargs)
|
||||
kwargs["litellm_logging_obj"] = logging_obj
|
||||
|
||||
# [OPTIONAL] CHECK BUDGET
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue