(feat) - allow using os.environ/ vars for any value on config.yaml (#6276)
* add check for os.environ vars when reading config.yaml
* use base class for reading from config.yaml
* fix import
* fix linting
* add unit tests for base config class
* fix order of reading elements from config.yaml
* unit tests for reading configs from files
* fix user_config_file_path
* use simpler implementation
* use helper to get_config
* working unit tests for reading configs
parent a0d45ba516
commit 19eff1a4b4
4 changed files with 289 additions and 19 deletions
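What the change means in practice: any string value in config.yaml written as "os.environ/<VAR_NAME>" is now swapped for the value of that environment variable when the proxy loads the config, at any nesting depth, rather than only in the specific fields the proxy already resolved. A minimal standalone sketch of the contract (using os.environ directly; the proxy itself resolves values through litellm's get_secret helper, as the diff below shows):

import os
import yaml  # PyYAML, already a litellm dependency

os.environ["AZURE_GPT_4O"] = "my-azure-key"  # placeholder value

raw = yaml.safe_load('litellm_params: {api_key: "os.environ/AZURE_GPT_4O"}')
# before the substitution pass, the literal reference is in the dict:
assert raw["litellm_params"]["api_key"] == "os.environ/AZURE_GPT_4O"
# after the proxy's os.environ/ pass, the loaded config holds the actual value:
resolved = os.environ[raw["litellm_params"]["api_key"].removeprefix("os.environ/")]
assert resolved == "my-azure-key"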
@@ -1,15 +1,48 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  ################################################################################
+  # Azure
+  - model_name: gpt-4o-mini
     litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  - model_name: db-openai-endpoint
+      model: azure/gpt-4o-mini
+      api_base: https://amazin-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o-mini
+  - model_name: gpt-4o
     litellm_params:
-      model: openai/gpt-5
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      model: azure/gpt-4o
+      api_base: https://very-cool-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o
 
-litellm_settings:
-  callbacks: ["arize"]
+  ################################################################################
+  # Fireworks
+  - model_name: fireworks-llama-v3p1-405b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-405b-instruct
+      api_key: "os.environ/FIREWORKS"
+  - model_name: fireworks-llama-v3p1-70b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct
+      api_key: "os.environ/FIREWORKS"
+
+general_settings:
+  alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
+litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
+  success_callback: ["prometheus"]
+  service_callback: ["prometheus_system"]
+  drop_params: False # Raise an exception if the openai param being passed in isn't supported.
+  cache: false
+  default_internal_user_params:
+    user_role: os.environ/DEFAULT_USER_ROLE
+
+  success_callback: ["s3"]
+  s3_callback_params:
+    s3_bucket_name: logs-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
+    s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets
+
+router_settings:
+  routing_strategy: simple-shuffle # "simple-shuffle" shown to result in highest throughput. https://docs.litellm.ai/docs/proxy/configs#load-balancing
litellm/proxy/proxy_server.py

@@ -464,7 +464,7 @@ user_temperature = None
 user_telemetry = True
 user_config = None
 user_headers = None
-user_config_file_path = f"config_{int(time.time())}.yaml"
+user_config_file_path: Optional[str] = None
 local_logging = True # writes logs to a local api_log.json file for debugging
 experimental = False
 #### GLOBAL VARIABLES ####
@@ -1373,7 +1373,19 @@ class ProxyConfig:
         _, file_extension = os.path.splitext(config_file_path)
         return file_extension.lower() == ".yaml" or file_extension.lower() == ".yml"
 
-    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
+    async def _get_config_from_file(
+        self, config_file_path: Optional[str] = None
+    ) -> dict:
+        """
+        Given a config file path, load the config from the file.
+
+        If `store_model_in_db` is True, then read the DB and update the config with the DB values.
+
+        Args:
+            config_file_path (str): path to the config file
+        Returns:
+            dict: config
+        """
         global prisma_client, user_config_file_path
 
         file_path = config_file_path or user_config_file_path
@@ -1384,6 +1396,8 @@ class ProxyConfig:
         if os.path.exists(f"{file_path}"):
             with open(f"{file_path}", "r") as config_file:
                 config = yaml.safe_load(config_file)
+        elif file_path is not None:
+            raise Exception(f"Config file not found: {file_path}")
         else:
             config = {
                 "model_list": [],
@@ -1449,6 +1463,43 @@ class ProxyConfig:
         with open(f"{user_config_file_path}", "w") as config_file:
             yaml.dump(new_config, config_file, default_flow_style=False)
 
+    def _check_for_os_environ_vars(
+        self, config: dict, depth: int = 0, max_depth: int = 10
+    ) -> dict:
+        """
+        Check for os.environ/ variables in the config and replace them with the actual values.
+        Includes a depth limit to prevent infinite recursion.
+
+        Args:
+            config (dict): The configuration dictionary to process.
+            depth (int): Current recursion depth.
+            max_depth (int): Maximum allowed recursion depth.
+
+        Returns:
+            dict: Processed configuration dictionary.
+        """
+        if depth > max_depth:
+            verbose_proxy_logger.warning(
+                f"Maximum recursion depth ({max_depth}) reached while processing config."
+            )
+            return config
+
+        for key, value in config.items():
+            if isinstance(value, dict):
+                config[key] = self._check_for_os_environ_vars(
+                    config=value, depth=depth + 1, max_depth=max_depth
+                )
+            elif isinstance(value, list):
+                for item in value:
+                    if isinstance(item, dict):
+                        item = self._check_for_os_environ_vars(
+                            config=item, depth=depth + 1, max_depth=max_depth
+                        )
+            # if the value is a string and starts with "os.environ/" - then it's an environment variable
+            elif isinstance(value, str) and value.startswith("os.environ/"):
+                config[key] = get_secret(value)
+        return config
+
     async def load_team_config(self, team_id: str):
         """
         - for a given team id
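Two properties of this traversal are worth noting: dicts nested inside lists are still processed (rebinding `item` is harmless because the nested dict object is mutated in place), while bare strings sitting directly inside lists are left untouched. A standalone replica that demonstrates both, substituting a plain os.environ lookup for litellm's get_secret:

import os

def check_for_os_environ_vars(config: dict, depth: int = 0, max_depth: int = 10) -> dict:
    # Standalone replica of ProxyConfig._check_for_os_environ_vars,
    # with os.environ.get in place of litellm's get_secret helper.
    if depth > max_depth:
        return config
    for key, value in config.items():
        if isinstance(value, dict):
            config[key] = check_for_os_environ_vars(value, depth + 1, max_depth)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    # the nested dict is mutated in place, so no reassignment needed
                    check_for_os_environ_vars(item, depth + 1, max_depth)
        elif isinstance(value, str) and value.startswith("os.environ/"):
            config[key] = os.environ.get(value.split("/", 1)[1])
    return config

os.environ["MY_KEY"] = "resolved"
cfg = {
    "model_list": [{"litellm_params": {"api_key": "os.environ/MY_KEY"}}],
    "callbacks": ["os.environ/MY_KEY"],  # bare string inside a list
}
check_for_os_environ_vars(cfg)
assert cfg["model_list"][0]["litellm_params"]["api_key"] == "resolved"
assert cfg["callbacks"][0] == "os.environ/MY_KEY"  # left as the literal string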
@@ -1492,14 +1543,21 @@ class ProxyConfig:
         ## INIT PROXY REDIS USAGE CLIENT ##
         redis_usage_cache = litellm.cache.cache
 
-    async def load_config(  # noqa: PLR0915
-        self, router: Optional[litellm.Router], config_file_path: str
-    ):
+    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
         """
-        Load config values into proxy global state
+        Load config file
+        Supports reading from:
+        - .yaml file paths
+        - LiteLLM connected DB
+        - GCS
+        - S3
+
+        Args:
+            config_file_path (str): path to the config file
+        Returns:
+            dict: config
+
         """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
-
         # Load existing config
         if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
             bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
@@ -1521,7 +1579,7 @@ class ProxyConfig:
             raise Exception("Unable to load config from given source.")
         else:
             # default to file
-            config = await self.get_config(config_file_path=config_file_path)
+            config = await self._get_config_from_file(config_file_path=config_file_path)
         ## PRINT YAML FOR CONFIRMING IT WORKS
         printed_yaml = copy.deepcopy(config)
         printed_yaml.pop("environment_variables", None)
@@ -1530,6 +1588,20 @@ class ProxyConfig:
             f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
         )
 
+        config = self._check_for_os_environ_vars(config=config)
+
+        return config
+
+    async def load_config(  # noqa: PLR0915
+        self, router: Optional[litellm.Router], config_file_path: str
+    ):
+        """
+        Load config values into proxy global state
+        """
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
+
+        config: dict = await self.get_config(config_file_path=config_file_path)
+
         ## ENVIRONMENT VARIABLES
         environment_variables = config.get("environment_variables", None)
         if environment_variables:
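With this refactor, get_config is the single public entry point: it fetches the raw config (a file path, or GCS/S3 when LITELLM_CONFIG_BUCKET_NAME is set), runs the os.environ/ substitution pass, and returns the resolved dict; load_config now just consumes its output. A usage sketch mirroring the new unit tests (the "config.yaml" path is a placeholder):

import asyncio

from litellm.proxy.proxy_server import ProxyConfig

async def main():
    proxy_config = ProxyConfig()
    # any os.environ/ references in the file are already resolved
    # by the time this dict is returned
    config = await proxy_config.get_config(config_file_path="config.yaml")
    print(config.get("litellm_settings"))

asyncio.run(main())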
tests/local_testing/example_config_yaml/config_with_env_vars.yaml (new file — the path read by the unit test below)

@@ -0,0 +1,48 @@
+model_list:
+  ################################################################################
+  # Azure
+  - model_name: gpt-4o-mini
+    litellm_params:
+      model: azure/gpt-4o-mini
+      api_base: https://amazin-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o-mini
+  - model_name: gpt-4o
+    litellm_params:
+      model: azure/gpt-4o
+      api_base: https://very-cool-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o
+
+  ################################################################################
+  # Fireworks
+  - model_name: fireworks-llama-v3p1-405b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-405b-instruct
+      api_key: "os.environ/FIREWORKS"
+  - model_name: fireworks-llama-v3p1-70b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct
+      api_key: "os.environ/FIREWORKS"
+
+general_settings:
+  alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
+litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
+  success_callback: ["prometheus"]
+  service_callback: ["prometheus_system"]
+  drop_params: False # Raise an exception if the openai param being passed in isn't supported.
+  cache: false
+  default_internal_user_params:
+    user_role: os.environ/DEFAULT_USER_ROLE
+
+  success_callback: ["s3"]
+  s3_callback_params:
+    s3_bucket_name: logs-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
+    s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/cloudflare s3 buckets
+
+router_settings:
+  routing_strategy: simple-shuffle # "simple-shuffle" shown to result in highest throughput. https://docs.litellm.ai/docs/proxy/configs#load-balancing
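This example config dereferences five environment variables, which must exist before the file is loaded (the test below sets them); a sketch of the setup, with all values as hypothetical placeholders:

import os

# placeholder values for illustration only; real deployments export real secrets
os.environ.update(
    {
        "AZURE_GPT_4O": "<azure-api-key>",
        "FIREWORKS": "<fireworks-api-key>",
        "DEFAULT_USER_ROLE": "<default-role>",
        "AWS_ACCESS_KEY_ID": "<aws-access-key-id>",
        "AWS_SECRET_ACCESS_KEY": "<aws-secret-access-key>",
    }
)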
tests/local_testing/test_proxy_config_unit_test.py (new file, 117 lines)
@@ -0,0 +1,117 @@
+import os
+import sys
+import traceback
+from unittest import mock
+import pytest
+
+from dotenv import load_dotenv
+
+import litellm.proxy
+import litellm.proxy.proxy_server
+
+load_dotenv()
+import io
+import os
+
+# this file is to test litellm/proxy
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import logging
+
+from litellm.proxy.proxy_server import ProxyConfig
+
+
+@pytest.mark.asyncio
+async def test_basic_reading_configs_from_files():
+    """
+    Test that the config is read correctly from the files in the example_config_yaml folder
+    """
+    proxy_config_instance = ProxyConfig()
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    example_config_yaml_path = os.path.join(current_path, "example_config_yaml")
+
+    # get all the files from example_config_yaml
+    files = os.listdir(example_config_yaml_path)
+    print(files)
+
+    for file in files:
+        config_path = os.path.join(example_config_yaml_path, file)
+        config = await proxy_config_instance.get_config(config_file_path=config_path)
+        print(config)
+
+
+@pytest.mark.asyncio
+async def test_read_config_from_bad_file_path():
+    """
+    Raise an exception if the file path is not valid
+    """
+    proxy_config_instance = ProxyConfig()
+    config_path = "non-existent-file.yaml"
+    with pytest.raises(Exception):
+        config = await proxy_config_instance.get_config(config_file_path=config_path)
+
+
+@pytest.mark.asyncio
+async def test_read_config_file_with_os_environ_vars():
+    """
+    Ensures os.environ variables are read correctly from config.yaml
+    Following vars are set as os.environ variables in the config.yaml file
+    - DEFAULT_USER_ROLE
+    - AWS_ACCESS_KEY_ID
+    - AWS_SECRET_ACCESS_KEY
+    - AZURE_GPT_4O
+    - FIREWORKS
+    """
+
+    _env_vars_for_testing = {
+        "DEFAULT_USER_ROLE": "admin",
+        "AWS_ACCESS_KEY_ID": "1234567890",
+        "AWS_SECRET_ACCESS_KEY": "1234567890",
+        "AZURE_GPT_4O": "1234567890",
+        "FIREWORKS": "1234567890",
+    }
+
+    _old_env_vars = {}
+    for key, value in _env_vars_for_testing.items():
+        if key in os.environ:
+            _old_env_vars[key] = os.environ.get(key)
+        os.environ[key] = value
+
+    # Read config
+    proxy_config_instance = ProxyConfig()
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    config_path = os.path.join(
+        current_path, "example_config_yaml", "config_with_env_vars.yaml"
+    )
+    config = await proxy_config_instance.get_config(config_file_path=config_path)
+    print(config)
+
+    # Add assertions
+    assert (
+        config["litellm_settings"]["default_internal_user_params"]["user_role"]
+        == "admin"
+    )
+    assert (
+        config["litellm_settings"]["s3_callback_params"]["s3_aws_access_key_id"]
+        == "1234567890"
+    )
+    assert (
+        config["litellm_settings"]["s3_callback_params"]["s3_aws_secret_access_key"]
+        == "1234567890"
+    )
+
+    for model in config["model_list"]:
+        if "azure" in model["litellm_params"]["model"]:
+            assert model["litellm_params"]["api_key"] == "1234567890"
+        elif "fireworks" in model["litellm_params"]["model"]:
+            assert model["litellm_params"]["api_key"] == "1234567890"
+
+    # cleanup
+    for key, value in _env_vars_for_testing.items():
+        if key in _old_env_vars:
+            os.environ[key] = _old_env_vars[key]
+        else:
+            del os.environ[key]
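The manual save/patch/restore in that last test works, and `from unittest import mock` is already imported; an equivalent scoping with mock.patch.dict would restore the prior environment automatically, even if an assertion fails midway. A sketch of an excerpt to drop into the async test body:

# alternative env-var scoping for the test above (not the commit's approach)
with mock.patch.dict(
    os.environ,
    {
        "DEFAULT_USER_ROLE": "admin",
        "AWS_ACCESS_KEY_ID": "1234567890",
        "AWS_SECRET_ACCESS_KEY": "1234567890",
        "AZURE_GPT_4O": "1234567890",
        "FIREWORKS": "1234567890",
    },
):
    config = await proxy_config_instance.get_config(config_file_path=config_path)
    # ... assertions go here; the environment is restored on exit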