fix(routing.py): update token usage on streaming

This commit is contained in:
Krrish Dholakia 2023-11-20 14:19:14 -08:00
parent 0422bba38d
commit 1976d0f7d6
4 changed files with 14 additions and 167 deletions

View file

@ -78,12 +78,12 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
global feature_telemetry
args = locals()
if local:
from proxy_server import app, save_worker_config, usage_telemetry, add_keys_to_config
from proxy_server import app, save_worker_config, usage_telemetry
else:
try:
from .proxy_server import app, save_worker_config, usage_telemetry, add_keys_to_config
from .proxy_server import app, save_worker_config, usage_telemetry
except ImportError as e:
from proxy_server import app, save_worker_config, usage_telemetry, add_keys_to_config
from proxy_server import app, save_worker_config, usage_telemetry
feature_telemetry = usage_telemetry
if logs is not None:
if logs == 0: # default to 1
@ -105,13 +105,6 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
except:
print("LiteLLM: No logs saved!")
return
if add_key:
key_name, key_value = add_key.split("=")
add_keys_to_config(key_name, key_value)
with open(user_config_path) as f:
print(f.read())
print("\033[1;32mDone successfully\033[0m")
return
if model and "ollama" in model:
run_ollama_serve()
if test != False:

View file

@ -184,58 +184,6 @@ async def user_api_key_auth(request: Request):
detail={"error": "invalid user key"},
)
def add_keys_to_config(key, value):
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
# Check if file exists
if os.path.exists(user_config_path):
# Load existing file
with open(user_config_path, "rb") as f:
config = tomllib.load(f)
else:
# File doesn't exist, create empty config
config = {}
# Add new key
config.setdefault("keys", {})[key] = value
# Write config to file
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def save_params_to_config(data: dict):
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
# Check if file exists
if os.path.exists(user_config_path):
# Load existing file
with open(user_config_path, "rb") as f:
config = tomllib.load(f)
else:
# File doesn't exist, create empty config
config = {}
config.setdefault("general", {})
## general config
general_settings = data["general"]
for key, value in general_settings.items():
config["general"][key] = value
## model-specific config
config.setdefault("model", {})
config["model"].setdefault(user_model, {})
user_model_config = data[user_model]
model_key = model_key = user_model_config.pop("alias", user_model)
config["model"].setdefault(model_key, {})
for key, value in user_model_config.items():
config["model"][model_key][key] = value
# Write config to file
with open(user_config_path, "wb") as f:
tomli_w.dump(config, f)
def prisma_setup(database_url: Optional[str]):
global prisma_client
if database_url:
@ -285,6 +233,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if litellm_settings:
for key, value in litellm_settings.items():
setattr(litellm, key, value)
print(f"key: {key}; value: {value}")
print(f"success callbacks: {litellm.success_callback}")
## MODEL LIST
model_list = config.get('model_list', None)
@ -293,7 +243,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
for model in model_list:
print(f"\033[32m {model.get('model_name', '')}\033[0m")
print()
return router, model_list, server_settings
@ -341,107 +290,6 @@ async def generate_key_cli_task(duration_str):
task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
await task
def load_config():
#### DEPRECATED ####
try:
global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging, llm_model_list, llm_router, server_settings
# Get the file extension
file_extension = os.path.splitext(user_config_path)[1]
if file_extension.lower() == ".toml":
# As the .env file is typically much simpler in structure, we use load_dotenv here directly
with open(user_config_path, "rb") as f:
user_config = tomllib.load(f)
## load keys
if "keys" in user_config:
for key in user_config["keys"]:
os.environ[key] = user_config["keys"][
key
] # litellm can read keys from the environment
## settings
if "general" in user_config:
litellm.add_function_to_prompt = user_config["general"].get(
"add_function_to_prompt", True
) # by default add function to prompt if unsupported by provider
litellm.drop_params = user_config["general"].get(
"drop_params", True
) # by default drop params if unsupported by provider
litellm.model_fallbacks = user_config["general"].get(
"fallbacks", None
) # fallback models in case initial completion call fails
default_model = user_config["general"].get(
"default_model", None
) # route all requests to this model.
local_logging = user_config["general"].get("local_logging", True)
if user_model is None: # `litellm --model <model-name>`` > default_model.
user_model = default_model
## load model config - to set this run `litellm --config`
model_config = None
if "model" in user_config:
if user_model in user_config["model"]:
model_config = user_config["model"][user_model]
model_list = []
for model in user_config["model"]:
if "model_list" in user_config["model"][model]:
model_list.extend(user_config["model"][model]["model_list"])
print_verbose(f"user_config: {user_config}")
print_verbose(f"model_config: {model_config}")
print_verbose(f"user_model: {user_model}")
if model_config is None:
return
user_max_tokens = model_config.get("max_tokens", None)
user_temperature = model_config.get("temperature", None)
user_api_base = model_config.get("api_base", None)
## custom prompt template
if "prompt_template" in model_config:
model_prompt_template = model_config["prompt_template"]
if (
len(model_prompt_template.keys()) > 0
): # if user has initialized this at all
litellm.register_prompt_template(
model=user_model,
initial_prompt_value=model_prompt_template.get(
"MODEL_PRE_PROMPT", ""
),
roles={
"system": {
"pre_message": model_prompt_template.get(
"MODEL_SYSTEM_MESSAGE_START_TOKEN", ""
),
"post_message": model_prompt_template.get(
"MODEL_SYSTEM_MESSAGE_END_TOKEN", ""
),
},
"user": {
"pre_message": model_prompt_template.get(
"MODEL_USER_MESSAGE_START_TOKEN", ""
),
"post_message": model_prompt_template.get(
"MODEL_USER_MESSAGE_END_TOKEN", ""
),
},
"assistant": {
"pre_message": model_prompt_template.get(
"MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""
),
"post_message": model_prompt_template.get(
"MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""
),
},
},
final_prompt_value=model_prompt_template.get(
"MODEL_POST_PROMPT", ""
),
)
except:
pass
def save_worker_config(**data):
import json

View file

@ -331,8 +331,13 @@ class Router:
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
if custom_llm_provider:
model_name = f"{custom_llm_provider}/{model_name}"
total_tokens = completion_response['usage']['total_tokens']
self._set_deployment_usage(model_name, total_tokens)
if kwargs["stream"] is True:
if kwargs.get("complete_streaming_response"):
total_tokens = kwargs.get("complete_streaming_response")['usage']['total_tokens']
self._set_deployment_usage(model_name, total_tokens)
else:
total_tokens = completion_response['usage']['total_tokens']
self._set_deployment_usage(model_name, total_tokens)
def get_usage_based_available_deployment(self,
model: str,

View file

@ -501,7 +501,8 @@ class Logging:
"messages": self.messages,
"optional_params": self.optional_params,
"litellm_params": self.litellm_params,
"start_time": self.start_time
"start_time": self.start_time,
"stream": self.stream
}
def pre_call(self, input, api_key, model=None, additional_args={}):