fix(routing.py): update token usage on streaming

This commit is contained in:
Krrish Dholakia 2023-11-20 14:19:14 -08:00
parent 0422bba38d
commit 1976d0f7d6
4 changed files with 14 additions and 167 deletions

View file

@ -184,58 +184,6 @@ async def user_api_key_auth(request: Request):
detail={"error": "invalid user key"},
)
def add_keys_to_config(key, value):
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
# Check if file exists
if os.path.exists(user_config_path):
# Load existing file
with open(user_config_path, "rb") as f:
config = tomllib.load(f)
else:
# File doesn't exist, create empty config
config = {}
# Add new key
config.setdefault("keys", {})[key] = value
# Write config to file
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def save_params_to_config(data: dict):
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
# Check if file exists
if os.path.exists(user_config_path):
# Load existing file
with open(user_config_path, "rb") as f:
config = tomllib.load(f)
else:
# File doesn't exist, create empty config
config = {}
config.setdefault("general", {})
## general config
general_settings = data["general"]
for key, value in general_settings.items():
config["general"][key] = value
## model-specific config
config.setdefault("model", {})
config["model"].setdefault(user_model, {})
user_model_config = data[user_model]
model_key = model_key = user_model_config.pop("alias", user_model)
config["model"].setdefault(model_key, {})
for key, value in user_model_config.items():
config["model"][model_key][key] = value
# Write config to file
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def prisma_setup(database_url: Optional[str]):
global prisma_client
if database_url:
@ -285,6 +233,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if litellm_settings:
for key, value in litellm_settings.items():
setattr(litellm, key, value)
print(f"key: {key}; value: {value}")
print(f"success callbacks: {litellm.success_callback}")
## MODEL LIST
model_list = config.get('model_list', None)
@ -293,7 +243,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
for model in model_list:
print(f"\033[32m {model.get('model_name', '')}\033[0m")
print()
return router, model_list, server_settings
@ -341,107 +290,6 @@ async def generate_key_cli_task(duration_str):
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
await task
def load_config():
#### DEPRECATED ####
try:
global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging, llm_model_list, llm_router, server_settings
# Get the file extension
file_extension = os.path.splitext(user_config_path)[1]
if file_extension.lower() == ".toml":
# As the .env file is typically much simpler in structure, we use load_dotenv here directly
with open(user_config_path, "rb") as f:
user_config = tomllib.load(f)
## load keys
if "keys" in user_config:
for key in user_config["keys"]:
os.environ[key] = user_config["keys"][
key
] # litellm can read keys from the environment
## settings
if "general" in user_config:
litellm.add_function_to_prompt = user_config["general"].get(
"add_function_to_prompt", True
) # by default add function to prompt if unsupported by provider
litellm.drop_params = user_config["general"].get(
"drop_params", True
) # by default drop params if unsupported by provider
litellm.model_fallbacks = user_config["general"].get(
"fallbacks", None
) # fallback models in case initial completion call fails
default_model = user_config["general"].get(
"default_model", None
) # route all requests to this model.
local_logging = user_config["general"].get("local_logging", True)
if user_model is None: # `litellm --model <model-name>`` > default_model.
user_model = default_model
## load model config - to set this run `litellm --config`
model_config = None
if "model" in user_config:
if user_model in user_config["model"]:
model_config = user_config["model"][user_model]
model_list = []
for model in user_config["model"]:
if "model_list" in user_config["model"][model]:
model_list.extend(user_config["model"][model]["model_list"])
print_verbose(f"user_config: {user_config}")
print_verbose(f"model_config: {model_config}")
print_verbose(f"user_model: {user_model}")
if model_config is None:
return
user_max_tokens = model_config.get("max_tokens", None)
user_temperature = model_config.get("temperature", None)
user_api_base = model_config.get("api_base", None)
## custom prompt template
if "prompt_template" in model_config:
model_prompt_template = model_config["prompt_template"]
if (
len(model_prompt_template.keys()) > 0
): # if user has initialized this at all
litellm.register_prompt_template(
model=user_model,
initial_prompt_value=model_prompt_template.get(
"MODEL_PRE_PROMPT", ""
),
roles={
"system": {
"pre_message": model_prompt_template.get(
"MODEL_SYSTEM_MESSAGE_START_TOKEN", ""
),
"post_message": model_prompt_template.get(
"MODEL_SYSTEM_MESSAGE_END_TOKEN", ""
),
},
"user": {
"pre_message": model_prompt_template.get(
"MODEL_USER_MESSAGE_START_TOKEN", ""
),
"post_message": model_prompt_template.get(
"MODEL_USER_MESSAGE_END_TOKEN", ""
),
},
"assistant": {
"pre_message": model_prompt_template.get(
"MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""
),
"post_message": model_prompt_template.get(
"MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""
),
},
},
final_prompt_value=model_prompt_template.get(
"MODEL_POST_PROMPT", ""
),
)
except:
pass
def save_worker_config(**data):
import json