From 19eff1a4b4a499f3d35a721f60eeb391d02757a1 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 19 Oct 2024 09:00:27 +0530
Subject: [PATCH] (feat) - allow using os.environ/ vars for any value on config.yaml (#6276)

* add check for os.environ vars when reading config.yaml

* use base class for reading from config.yaml

* fix import

* fix linting

* add unit tests for base config class

* fix order of reading elements from config.yaml

* unit tests for reading configs from files

* fix user_config_file_path

* use simpler implementation

* use helper to get_config

* working unit tests for reading configs
---
 litellm/proxy/proxy_config.yaml               |  53 ++++++--
 litellm/proxy/proxy_server.py                 |  90 ++++++++++++--
 .../config_with_env_vars.yaml                 |  48 +++++++
 .../test_proxy_config_unit_test.py            | 117 ++++++++++++++++++
 4 files changed, 289 insertions(+), 19 deletions(-)
 create mode 100644 tests/local_testing/example_config_yaml/config_with_env_vars.yaml
 create mode 100644 tests/local_testing/test_proxy_config_unit_test.py

diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 707f76ce8..bae738c73 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,15 +1,48 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  ################################################################################
+  # Azure
+  - model_name: gpt-4o-mini
     litellm_params:
-      model: openai/fake
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
-  - model_name: db-openai-endpoint
+      model: azure/gpt-4o-mini
+      api_base: https://amazin-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o-mini
+  - model_name: gpt-4o
     litellm_params:
-      model: openai/gpt-5
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/
+      model: azure/gpt-4o
+      api_base: https://very-cool-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o
 
-litellm_settings:
-  callbacks: ["arize"]
+  ################################################################################
+  # Fireworks
+  - model_name: fireworks-llama-v3p1-405b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-405b-instruct
+      api_key: "os.environ/FIREWORKS"
+  - model_name: fireworks-llama-v3p1-70b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct
+      api_key: "os.environ/FIREWORKS"
+
+general_settings:
+  alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
+litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
+  success_callback: ["prometheus"]
+  service_callback: ["prometheus_system"]
+  drop_params: False # Raise an exception if the openai param being passed in isn't supported.
+  cache: false
+  default_internal_user_params:
+    user_role: os.environ/DEFAULT_USER_ROLE
+
+  success_callback: ["s3"]
+  s3_callback_params:
+    s3_bucket_name: logs-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/ to pass environment variables. This is the AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
+    s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/Cloudflare S3 buckets
+
+router_settings:
+  routing_strategy: simple-shuffle # "simple-shuffle" shown to result in highest throughput. https://docs.litellm.ai/docs/proxy/configs#load-balancing
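For reference, the convention the updated config relies on: any string value of the form `os.environ/SOME_VAR` is replaced at load time with the value of that environment variable. Below is a minimal sketch of the resolution rule, assuming only the prefix convention shown above; the proxy itself routes these lookups through `get_secret`, as the proxy_server.py diff that follows shows.

```python
import os

# Sketch of the "os.environ/" convention (not LiteLLM's implementation):
# a string value "os.environ/AZURE_GPT_4O" resolves to os.environ["AZURE_GPT_4O"].
def resolve_env_reference(value):
    prefix = "os.environ/"
    if isinstance(value, str) and value.startswith(prefix):
        env_var = value[len(prefix):]           # e.g. "AZURE_GPT_4O"
        return os.environ.get(env_var, value)   # fall back to the literal string
    return value

os.environ["AZURE_GPT_4O"] = "sk-example"
print(resolve_env_reference("os.environ/AZURE_GPT_4O"))  # -> sk-example
print(resolve_env_reference("plain-value"))              # -> unchanged
```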
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b3739e491..4d431bd87 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -464,7 +464,7 @@ user_temperature = None
 user_telemetry = True
 user_config = None
 user_headers = None
-user_config_file_path = f"config_{int(time.time())}.yaml"
+user_config_file_path: Optional[str] = None
 local_logging = True  # writes logs to a local api_log.json file for debugging
 experimental = False
 #### GLOBAL VARIABLES ####
@@ -1373,7 +1373,19 @@ class ProxyConfig:
         _, file_extension = os.path.splitext(config_file_path)
         return file_extension.lower() == ".yaml" or file_extension.lower() == ".yml"
 
-    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
+    async def _get_config_from_file(
+        self, config_file_path: Optional[str] = None
+    ) -> dict:
+        """
+        Given a config file path, load the config from the file.
+
+        If `store_model_in_db` is True, then read the DB and update the config with the DB values.
+
+        Args:
+            config_file_path (str): path to the config file
+        Returns:
+            dict: config
+        """
         global prisma_client, user_config_file_path
 
         file_path = config_file_path or user_config_file_path
@@ -1384,6 +1396,8 @@ class ProxyConfig:
         if os.path.exists(f"{file_path}"):
             with open(f"{file_path}", "r") as config_file:
                 config = yaml.safe_load(config_file)
+        elif file_path is not None:
+            raise Exception(f"Config file not found: {file_path}")
         else:
             config = {
                 "model_list": [],
@@ -1449,6 +1463,43 @@ class ProxyConfig:
         with open(f"{user_config_file_path}", "w") as config_file:
             yaml.dump(new_config, config_file, default_flow_style=False)
 
+    def _check_for_os_environ_vars(
+        self, config: dict, depth: int = 0, max_depth: int = 10
+    ) -> dict:
+        """
+        Check for os.environ/ variables in the config and replace them with the actual values.
+        Includes a depth limit to prevent infinite recursion.
+
+        Args:
+            config (dict): The configuration dictionary to process.
+            depth (int): Current recursion depth.
+            max_depth (int): Maximum allowed recursion depth.
+
+        Returns:
+            dict: Processed configuration dictionary.
+        """
+        if depth > max_depth:
+            verbose_proxy_logger.warning(
+                f"Maximum recursion depth ({max_depth}) reached while processing config."
+            )
+            return config
+
+        for key, value in config.items():
+            if isinstance(value, dict):
+                config[key] = self._check_for_os_environ_vars(
+                    config=value, depth=depth + 1, max_depth=max_depth
+                )
+            elif isinstance(value, list):
+                for item in value:
+                    if isinstance(item, dict):
+                        item = self._check_for_os_environ_vars(
+                            config=item, depth=depth + 1, max_depth=max_depth
+                        )
+            # if the value is a string and starts with "os.environ/" - then it's an environment variable
+            elif isinstance(value, str) and value.startswith("os.environ/"):
+                config[key] = get_secret(value)
+        return config
+
     async def load_team_config(self, team_id: str):
         """
         - for a given team id
@@ -1492,14 +1543,21 @@ class ProxyConfig:
         ## INIT PROXY REDIS USAGE CLIENT ##
         redis_usage_cache = litellm.cache.cache
 
-    async def load_config(  # noqa: PLR0915
-        self, router: Optional[litellm.Router], config_file_path: str
-    ):
+    async def get_config(self, config_file_path: Optional[str] = None) -> dict:
         """
-        Load config values into proxy global state
-        """
-        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
+        Load config file
+        Supports reading from:
+        - .yaml file paths
+        - LiteLLM connected DB
+        - GCS
+        - S3
+
+        Args:
+            config_file_path (str): path to the config file
+        Returns:
+            dict: config
+
+        """
         # Load existing config
         if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
             bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
@@ -1521,7 +1579,7 @@ class ProxyConfig:
                 raise Exception("Unable to load config from given source.")
         else:
             # default to file
-            config = await self.get_config(config_file_path=config_file_path)
+            config = await self._get_config_from_file(config_file_path=config_file_path)
+
         ## PRINT YAML FOR CONFIRMING IT WORKS
         printed_yaml = copy.deepcopy(config)
         printed_yaml.pop("environment_variables", None)
@@ -1530,6 +1588,20 @@ class ProxyConfig:
             f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}"
         )
 
+        config = self._check_for_os_environ_vars(config=config)
+
+        return config
+
+    async def load_config(  # noqa: PLR0915
+        self, router: Optional[litellm.Router], config_file_path: str
+    ):
+        """
+        Load config values into proxy global state
+        """
+        global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
+
+        config: dict = await self.get_config(config_file_path=config_file_path)
+
         ## ENVIRONMENT VARIABLES
         environment_variables = config.get("environment_variables", None)
         if environment_variables:
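The recursion in `_check_for_os_environ_vars` is depth-capped rather than unbounded: past `max_depth`, the remaining subtree is returned untouched and a warning is logged. A standalone sketch of that behavior, simplified to dicts only and a plain `os.environ` lookup in place of `get_secret`:

```python
import os

os.environ["DEEP_SECRET"] = "resolved"

# Simplified, depth-capped walk in the spirit of _check_for_os_environ_vars
# (illustration only; the proxy's version also handles lists and uses get_secret).
def walk(cfg, depth=0, max_depth=10):
    if depth > max_depth:
        return cfg  # too deep: leave "os.environ/..." strings unresolved
    for key, value in cfg.items():
        if isinstance(value, dict):
            cfg[key] = walk(value, depth + 1, max_depth)
        elif isinstance(value, str) and value.startswith("os.environ/"):
            cfg[key] = os.environ.get(value[len("os.environ/"):], value)
    return cfg

print(walk({"api_key": "os.environ/DEEP_SECRET"}))
# -> {'api_key': 'resolved'}
print(walk({"a": {"b": {"api_key": "os.environ/DEEP_SECRET"}}}, max_depth=1))
# -> inner api_key stays "os.environ/DEEP_SECRET": the walk stops at depth 2
```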
diff --git a/tests/local_testing/example_config_yaml/config_with_env_vars.yaml b/tests/local_testing/example_config_yaml/config_with_env_vars.yaml
new file mode 100644
index 000000000..bae738c73
--- /dev/null
+++ b/tests/local_testing/example_config_yaml/config_with_env_vars.yaml
@@ -0,0 +1,48 @@
+model_list:
+  ################################################################################
+  # Azure
+  - model_name: gpt-4o-mini
+    litellm_params:
+      model: azure/gpt-4o-mini
+      api_base: https://amazin-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o-mini
+  - model_name: gpt-4o
+    litellm_params:
+      model: azure/gpt-4o
+      api_base: https://very-cool-prod.openai.azure.com
+      api_key: "os.environ/AZURE_GPT_4O"
+      deployment_id: gpt-4o
+
+  ################################################################################
+  # Fireworks
+  - model_name: fireworks-llama-v3p1-405b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-405b-instruct
+      api_key: "os.environ/FIREWORKS"
+  - model_name: fireworks-llama-v3p1-70b-instruct
+    litellm_params:
+      model: fireworks_ai/accounts/fireworks/models/llama-v3p1-70b-instruct
+      api_key: "os.environ/FIREWORKS"
+
+general_settings:
+  alerting_threshold: 300 # sends alerts if requests hang for 5min+ and responses take 5min+
+litellm_settings: # module level litellm settings - https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py
+  success_callback: ["prometheus"]
+  service_callback: ["prometheus_system"]
+  drop_params: False # Raise an exception if the openai param being passed in isn't supported.
+  cache: false
+  default_internal_user_params:
+    user_role: os.environ/DEFAULT_USER_ROLE
+
+  success_callback: ["s3"]
+  s3_callback_params:
+    s3_bucket_name: logs-bucket-litellm # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/ to pass environment variables. This is the AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+    s3_path: my-test-path # [OPTIONAL] set path in bucket you want to write logs to
+    s3_endpoint_url: https://s3.amazonaws.com # [OPTIONAL] S3 endpoint URL, if you want to use Backblaze/Cloudflare S3 buckets
+
+router_settings:
+  routing_strategy: simple-shuffle # "simple-shuffle" shown to result in highest throughput. https://docs.litellm.ai/docs/proxy/configs#load-balancing
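End to end, loading this example file through the new `get_config` should hand back a dict with every `os.environ/` reference already resolved. A quick sketch (the env values are placeholders, and the relative path assumes the repo root as the working directory):

```python
import asyncio
import os

from litellm.proxy.proxy_server import ProxyConfig

# Placeholder values for the variables the example config references.
os.environ["AZURE_GPT_4O"] = "test-azure-key"
os.environ["FIREWORKS"] = "test-fireworks-key"
os.environ["DEFAULT_USER_ROLE"] = "admin"
os.environ["AWS_ACCESS_KEY_ID"] = "test-access-key"
os.environ["AWS_SECRET_ACCESS_KEY"] = "test-secret-key"

async def main():
    config = await ProxyConfig().get_config(
        config_file_path="tests/local_testing/example_config_yaml/config_with_env_vars.yaml"
    )
    # Nested values arrive resolved, e.g. the S3 credentials:
    print(config["litellm_settings"]["s3_callback_params"]["s3_aws_access_key_id"])
    # -> test-access-key

asyncio.run(main())
```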
diff --git a/tests/local_testing/test_proxy_config_unit_test.py b/tests/local_testing/test_proxy_config_unit_test.py
new file mode 100644
index 000000000..bb51ce726
--- /dev/null
+++ b/tests/local_testing/test_proxy_config_unit_test.py
@@ -0,0 +1,117 @@
+import os
+import sys
+import traceback
+from unittest import mock
+
+import pytest
+from dotenv import load_dotenv
+
+import litellm.proxy
+import litellm.proxy.proxy_server
+
+load_dotenv()
+import io
+import os
+
+# this file is to test litellm/proxy
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import logging
+
+from litellm.proxy.proxy_server import ProxyConfig
+
+
+@pytest.mark.asyncio
+async def test_basic_reading_configs_from_files():
+    """
+    Test that the config is read correctly from the files in the example_config_yaml folder
+    """
+    proxy_config_instance = ProxyConfig()
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    example_config_yaml_path = os.path.join(current_path, "example_config_yaml")
+
+    # get all the files from example_config_yaml
+    files = os.listdir(example_config_yaml_path)
+    print(files)
+
+    for file in files:
+        config_path = os.path.join(example_config_yaml_path, file)
+        config = await proxy_config_instance.get_config(config_file_path=config_path)
+        print(config)
+
+
+@pytest.mark.asyncio
+async def test_read_config_from_bad_file_path():
+    """
+    Raise an exception if the file path is not valid
+    """
+    proxy_config_instance = ProxyConfig()
+    config_path = "non-existent-file.yaml"
+    with pytest.raises(Exception):
+        config = await proxy_config_instance.get_config(config_file_path=config_path)
+
+
+@pytest.mark.asyncio
+async def test_read_config_file_with_os_environ_vars():
+    """
+    Ensures os.environ variables are read correctly from config.yaml
+    Following vars are set as os.environ variables in the config.yaml file
+    - DEFAULT_USER_ROLE
+    - AWS_ACCESS_KEY_ID
+    - AWS_SECRET_ACCESS_KEY
+    - AZURE_GPT_4O
+    - FIREWORKS
+    """
+
+    _env_vars_for_testing = {
+        "DEFAULT_USER_ROLE": "admin",
+        "AWS_ACCESS_KEY_ID": "1234567890",
+        "AWS_SECRET_ACCESS_KEY": "1234567890",
+        "AZURE_GPT_4O": "1234567890",
+        "FIREWORKS": "1234567890",
+    }
+
+    _old_env_vars = {}
+    for key, value in _env_vars_for_testing.items():
+        if key in os.environ:
+            _old_env_vars[key] = os.environ.get(key)
+        os.environ[key] = value
+
+    # Read config
+    proxy_config_instance = ProxyConfig()
+    current_path = os.path.dirname(os.path.abspath(__file__))
+    config_path = os.path.join(
+        current_path, "example_config_yaml", "config_with_env_vars.yaml"
+    )
+    config = await proxy_config_instance.get_config(config_file_path=config_path)
+    print(config)
+
+    # Add assertions
+    assert (
+        config["litellm_settings"]["default_internal_user_params"]["user_role"]
+        == "admin"
+    )
+    assert (
+        config["litellm_settings"]["s3_callback_params"]["s3_aws_access_key_id"]
+        == "1234567890"
+    )
+    assert (
+        config["litellm_settings"]["s3_callback_params"]["s3_aws_secret_access_key"]
+        == "1234567890"
+    )
+
+    for model in config["model_list"]:
+        if "azure" in model["litellm_params"]["model"]:
+            assert model["litellm_params"]["api_key"] == "1234567890"
+        elif "fireworks" in model["litellm_params"]["model"]:
+            assert model["litellm_params"]["api_key"] == "1234567890"
+
+    # cleanup
+    for key, value in _env_vars_for_testing.items():
+        if key in _old_env_vars:
+            os.environ[key] = _old_env_vars[key]
+        else:
+            del os.environ[key]
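An aside on the manual `_old_env_vars` save/restore in the last test: pytest's built-in `monkeypatch` fixture does the same bookkeeping automatically, undoing every `setenv` after the test. A sketch of the equivalent setup (not part of this patch):

```python
import os
import pytest

# Sketch: monkeypatch.setenv sets a variable for the duration of the test and
# restores the prior environment afterwards, replacing the manual bookkeeping above.
def test_env_vars_with_monkeypatch(monkeypatch):
    monkeypatch.setenv("DEFAULT_USER_ROLE", "admin")
    monkeypatch.setenv("AZURE_GPT_4O", "1234567890")
    assert os.environ["DEFAULT_USER_ROLE"] == "admin"
    # ...then read the config and assert on resolved values, as in the test above
```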