forked from phoenix/litellm-mirror
Improve code formatting and allow configurable litellm config path via environment variable.
This commit is contained in:
parent
a9ebf1b6ab
commit
52fdfe5819
2 changed files with 16 additions and 14 deletions
|
@ -9,7 +9,7 @@ import operator
|
||||||
config_filename = "litellm.secrets.toml"
|
config_filename = "litellm.secrets.toml"
|
||||||
# Using appdirs to determine user-specific config path
|
# Using appdirs to determine user-specific config path
|
||||||
config_dir = appdirs.user_config_dir("litellm")
|
config_dir = appdirs.user_config_dir("litellm")
|
||||||
user_config_path = os.path.join(config_dir, config_filename)
|
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
from importlib import resources
|
from importlib import resources
|
||||||
|
|
|
@ -18,7 +18,8 @@ except ImportError:
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
|
subprocess.check_call(
|
||||||
|
[sys.executable, "-m", "pip", "install", "uvicorn", "fastapi", "tomli", "appdirs", "tomli-w", "backoff"])
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import fastapi
|
import fastapi
|
||||||
import tomli as tomllib
|
import tomli as tomllib
|
||||||
|
@ -26,9 +27,9 @@ except ImportError:
|
||||||
import tomli_w
|
import tomli_w
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .llm import litellm_completion
|
from .llm import litellm_completion
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
from llm import litellm_completion # type: ignore
|
from llm import litellm_completion # type: ignore
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
@ -105,7 +106,7 @@ model_router = litellm.Router()
|
||||||
config_filename = "litellm.secrets.toml"
|
config_filename = "litellm.secrets.toml"
|
||||||
config_dir = os.getcwd()
|
config_dir = os.getcwd()
|
||||||
config_dir = appdirs.user_config_dir("litellm")
|
config_dir = appdirs.user_config_dir("litellm")
|
||||||
user_config_path = os.path.join(config_dir, config_filename)
|
user_config_path = os.getenv("LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename))
|
||||||
log_file = 'api_log.json'
|
log_file = 'api_log.json'
|
||||||
|
|
||||||
|
|
||||||
|
@ -184,7 +185,7 @@ def save_params_to_config(data: dict):
|
||||||
|
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
try:
|
try:
|
||||||
global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging
|
global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging
|
||||||
# As the .env file is typically much simpler in structure, we use load_dotenv here directly
|
# As the .env file is typically much simpler in structure, we use load_dotenv here directly
|
||||||
with open(user_config_path, "rb") as f:
|
with open(user_config_path, "rb") as f:
|
||||||
|
@ -199,9 +200,9 @@ def load_config():
|
||||||
litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
|
litellm.add_function_to_prompt = user_config["general"].get("add_function_to_prompt",
|
||||||
True) # by default add function to prompt if unsupported by provider
|
True) # by default add function to prompt if unsupported by provider
|
||||||
litellm.drop_params = user_config["general"].get("drop_params",
|
litellm.drop_params = user_config["general"].get("drop_params",
|
||||||
True) # by default drop params if unsupported by provider
|
True) # by default drop params if unsupported by provider
|
||||||
litellm.model_fallbacks = user_config["general"].get("fallbacks",
|
litellm.model_fallbacks = user_config["general"].get("fallbacks",
|
||||||
None) # fallback models in case initial completion call fails
|
None) # fallback models in case initial completion call fails
|
||||||
default_model = user_config["general"].get("default_model", None) # route all requests to this model.
|
default_model = user_config["general"].get("default_model", None) # route all requests to this model.
|
||||||
|
|
||||||
local_logging = user_config["general"].get("local_logging", True)
|
local_logging = user_config["general"].get("local_logging", True)
|
||||||
|
@ -215,10 +216,10 @@ def load_config():
|
||||||
if user_model in user_config["model"]:
|
if user_model in user_config["model"]:
|
||||||
model_config = user_config["model"][user_model]
|
model_config = user_config["model"][user_model]
|
||||||
model_list = []
|
model_list = []
|
||||||
for model in user_config["model"]:
|
for model in user_config["model"]:
|
||||||
if "model_list" in user_config["model"][model]:
|
if "model_list" in user_config["model"][model]:
|
||||||
model_list.extend(user_config["model"][model]["model_list"])
|
model_list.extend(user_config["model"][model]["model_list"])
|
||||||
if len(model_list) > 0:
|
if len(model_list) > 0:
|
||||||
model_router.set_model_list(model_list=model_list)
|
model_router.set_model_list(model_list=model_list)
|
||||||
|
|
||||||
print_verbose(f"user_config: {user_config}")
|
print_verbose(f"user_config: {user_config}")
|
||||||
|
@ -254,7 +255,7 @@ def load_config():
|
||||||
},
|
},
|
||||||
final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
|
final_prompt_value=model_prompt_template.get("MODEL_POST_PROMPT", ""),
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -271,8 +272,8 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
|
||||||
if api_base: # model-specific param
|
if api_base: # model-specific param
|
||||||
user_api_base = api_base
|
user_api_base = api_base
|
||||||
dynamic_config[user_model]["api_base"] = api_base
|
dynamic_config[user_model]["api_base"] = api_base
|
||||||
if api_version:
|
if api_version:
|
||||||
os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
|
os.environ["AZURE_API_VERSION"] = api_version # set this for azure - litellm can read this from the env
|
||||||
if max_tokens: # model-specific param
|
if max_tokens: # model-specific param
|
||||||
user_max_tokens = max_tokens
|
user_max_tokens = max_tokens
|
||||||
dynamic_config[user_model]["max_tokens"] = max_tokens
|
dynamic_config[user_model]["max_tokens"] = max_tokens
|
||||||
|
@ -290,7 +291,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
|
||||||
if max_budget: # litellm-specific param
|
if max_budget: # litellm-specific param
|
||||||
litellm.max_budget = max_budget
|
litellm.max_budget = max_budget
|
||||||
dynamic_config["general"]["max_budget"] = max_budget
|
dynamic_config["general"]["max_budget"] = max_budget
|
||||||
if debug: # litellm-specific param
|
if debug: # litellm-specific param
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
if save:
|
if save:
|
||||||
save_params_to_config(dynamic_config)
|
save_params_to_config(dynamic_config)
|
||||||
|
@ -300,6 +301,7 @@ def initialize(model, alias, api_base, api_version, debug, temperature, max_toke
|
||||||
user_telemetry = telemetry
|
user_telemetry = telemetry
|
||||||
usage_telemetry(feature="local_proxy_server")
|
usage_telemetry(feature="local_proxy_server")
|
||||||
|
|
||||||
|
|
||||||
def track_cost_callback(
|
def track_cost_callback(
|
||||||
kwargs, # kwargs to completion
|
kwargs, # kwargs to completion
|
||||||
completion_response, # response from completion
|
completion_response, # response from completion
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue