forked from phoenix/litellm-mirror
fix(routing.py): update token usage on streaming
This commit is contained in:
parent 0422bba38d
commit 1976d0f7d6
4 changed files with 14 additions and 167 deletions
@@ -78,12 +78,12 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
     global feature_telemetry
     args = locals()
     if local:
-        from proxy_server import app, save_worker_config, usage_telemetry, add_keys_to_config
+        from proxy_server import app, save_worker_config, usage_telemetry
     else:
         try:
-            from .proxy_server import app, save_worker_config, usage_telemetry, add_keys_to_config
+            from .proxy_server import app, save_worker_config, usage_telemetry
         except ImportError as e:
-            from proxy_server import app, save_worker_config, usage_telemetry, add_keys_to_config
+            from proxy_server import app, save_worker_config, usage_telemetry
     feature_telemetry = usage_telemetry
     if logs is not None:
         if logs == 0: # default to 1
@@ -105,13 +105,6 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
         except:
             print("LiteLLM: No logs saved!")
         return
-    if add_key:
-        key_name, key_value = add_key.split("=")
-        add_keys_to_config(key_name, key_value)
-        with open(user_config_path) as f:
-            print(f.read())
-        print("\033[1;32mDone successfully\033[0m")
-        return
     if model and "ollama" in model:
         run_ollama_serve()
     if test != False:
@@ -184,58 +184,6 @@ async def user_api_key_auth(request: Request):
             detail={"error": "invalid user key"},
         )
 
-def add_keys_to_config(key, value):
-    #### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
-    # Check if file exists
-    if os.path.exists(user_config_path):
-        # Load existing file
-        with open(user_config_path, "rb") as f:
-            config = tomllib.load(f)
-    else:
-        # File doesn't exist, create empty config
-        config = {}
-
-    # Add new key
-    config.setdefault("keys", {})[key] = value
-
-    # Write config to file
-    with open(user_config_path, "wb") as f:
-        tomli_w.dump(config, f)
-
-
-def save_params_to_config(data: dict):
-    #### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
-    # Check if file exists
-    if os.path.exists(user_config_path):
-        # Load existing file
-        with open(user_config_path, "rb") as f:
-            config = tomllib.load(f)
-    else:
-        # File doesn't exist, create empty config
-        config = {}
-
-    config.setdefault("general", {})
-
-    ## general config
-    general_settings = data["general"]
-
-    for key, value in general_settings.items():
-        config["general"][key] = value
-
-    ## model-specific config
-    config.setdefault("model", {})
-    config["model"].setdefault(user_model, {})
-
-    user_model_config = data[user_model]
-    model_key = model_key = user_model_config.pop("alias", user_model)
-    config["model"].setdefault(model_key, {})
-    for key, value in user_model_config.items():
-        config["model"][model_key][key] = value
-
-    # Write config to file
-    with open(user_config_path, "wb") as f:
-        tomli_w.dump(config, f)
-
 def prisma_setup(database_url: Optional[str]):
     global prisma_client
     if database_url:
@@ -285,6 +233,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     if litellm_settings:
         for key, value in litellm_settings.items():
             setattr(litellm, key, value)
+            print(f"key: {key}; value: {value}")
+    print(f"success callbacks: {litellm.success_callback}")
 
     ## MODEL LIST
     model_list = config.get('model_list', None)
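
Note: for orientation, a rough sketch (values invented) of the parsed config shape these lines walk over; only the litellm_settings and model_list keys, the model_name field, and the success_callback attribute are taken from the hunks themselves.

config = {
    "litellm_settings": {
        "success_callback": ["langfuse"],   # each key/value is setattr() onto the litellm module
    },
    "model_list": [
        {"model_name": "gpt-3.5-turbo"},    # model_name is what gets printed at startup
    ],
}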
@@ -293,7 +243,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
-        print()
 
     return router, model_list, server_settings
 
@@ -341,107 +290,6 @@ async def generate_key_cli_task(duration_str):
     task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
     await task
 
-def load_config():
-    #### DEPRECATED ####
-    try:
-        global user_config, user_api_base, user_max_tokens, user_temperature, user_model, local_logging, llm_model_list, llm_router, server_settings
-
-        # Get the file extension
-        file_extension = os.path.splitext(user_config_path)[1]
-        if file_extension.lower() == ".toml":
-            # As the .env file is typically much simpler in structure, we use load_dotenv here directly
-            with open(user_config_path, "rb") as f:
-                user_config = tomllib.load(f)
-
-            ## load keys
-            if "keys" in user_config:
-                for key in user_config["keys"]:
-                    os.environ[key] = user_config["keys"][
-                        key
-                    ]  # litellm can read keys from the environment
-            ## settings
-            if "general" in user_config:
-                litellm.add_function_to_prompt = user_config["general"].get(
-                    "add_function_to_prompt", True
-                )  # by default add function to prompt if unsupported by provider
-                litellm.drop_params = user_config["general"].get(
-                    "drop_params", True
-                )  # by default drop params if unsupported by provider
-                litellm.model_fallbacks = user_config["general"].get(
-                    "fallbacks", None
-                )  # fallback models in case initial completion call fails
-                default_model = user_config["general"].get(
-                    "default_model", None
-                )  # route all requests to this model.
-
-                local_logging = user_config["general"].get("local_logging", True)
-
-                if user_model is None:  # `litellm --model <model-name>`` > default_model.
-                    user_model = default_model
-
-            ## load model config - to set this run `litellm --config`
-            model_config = None
-            if "model" in user_config:
-                if user_model in user_config["model"]:
-                    model_config = user_config["model"][user_model]
-                model_list = []
-                for model in user_config["model"]:
-                    if "model_list" in user_config["model"][model]:
-                        model_list.extend(user_config["model"][model]["model_list"])
-
-            print_verbose(f"user_config: {user_config}")
-            print_verbose(f"model_config: {model_config}")
-            print_verbose(f"user_model: {user_model}")
-            if model_config is None:
-                return
-
-            user_max_tokens = model_config.get("max_tokens", None)
-            user_temperature = model_config.get("temperature", None)
-            user_api_base = model_config.get("api_base", None)
-
-            ## custom prompt template
-            if "prompt_template" in model_config:
-                model_prompt_template = model_config["prompt_template"]
-                if (
-                    len(model_prompt_template.keys()) > 0
-                ):  # if user has initialized this at all
-                    litellm.register_prompt_template(
-                        model=user_model,
-                        initial_prompt_value=model_prompt_template.get(
-                            "MODEL_PRE_PROMPT", ""
-                        ),
-                        roles={
-                            "system": {
-                                "pre_message": model_prompt_template.get(
-                                    "MODEL_SYSTEM_MESSAGE_START_TOKEN", ""
-                                ),
-                                "post_message": model_prompt_template.get(
-                                    "MODEL_SYSTEM_MESSAGE_END_TOKEN", ""
-                                ),
-                            },
-                            "user": {
-                                "pre_message": model_prompt_template.get(
-                                    "MODEL_USER_MESSAGE_START_TOKEN", ""
-                                ),
-                                "post_message": model_prompt_template.get(
-                                    "MODEL_USER_MESSAGE_END_TOKEN", ""
-                                ),
-                            },
-                            "assistant": {
-                                "pre_message": model_prompt_template.get(
-                                    "MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""
-                                ),
-                                "post_message": model_prompt_template.get(
-                                    "MODEL_ASSISTANT_MESSAGE_END_TOKEN", ""
-                                ),
-                            },
-                        },
-                        final_prompt_value=model_prompt_template.get(
-                            "MODEL_POST_PROMPT", ""
-                        ),
-                    )
-    except:
-        pass
 
 def save_worker_config(**data):
     import json
@@ -331,8 +331,13 @@ class Router:
             custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
             if custom_llm_provider:
                 model_name = f"{custom_llm_provider}/{model_name}"
-            total_tokens = completion_response['usage']['total_tokens']
-            self._set_deployment_usage(model_name, total_tokens)
+            if kwargs["stream"] is True:
+                if kwargs.get("complete_streaming_response"):
+                    total_tokens = kwargs.get("complete_streaming_response")['usage']['total_tokens']
+                    self._set_deployment_usage(model_name, total_tokens)
+            else:
+                total_tokens = completion_response['usage']['total_tokens']
+                self._set_deployment_usage(model_name, total_tokens)
 
     def get_usage_based_available_deployment(self,
                                              model: str,
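
Note: read on its own, the new branch defers usage accounting for streamed calls until the fully assembled response (and its usage block) exists, while non-streamed calls keep the old behaviour. Below is a minimal standalone sketch of that decision logic; the record_usage helper, the usage_by_deployment dict, and the example values are invented for illustration, while stream, complete_streaming_response, and usage.total_tokens are the fields visible in the hunk above.

usage_by_deployment = {}

def record_usage(model_name, total_tokens):
    # stand-in for Router._set_deployment_usage: accumulate tokens per deployment
    usage_by_deployment[model_name] = usage_by_deployment.get(model_name, 0) + total_tokens

def on_success(kwargs, completion_response, model_name):
    if kwargs.get("stream") is True:
        # streamed call: usage only exists on the assembled final response
        complete = kwargs.get("complete_streaming_response")
        if complete:
            record_usage(model_name, complete["usage"]["total_tokens"])
    else:
        # non-streamed call: the response object itself carries usage
        record_usage(model_name, completion_response["usage"]["total_tokens"])

# example: a streamed call only reports usage via complete_streaming_response
on_success(
    kwargs={"stream": True, "complete_streaming_response": {"usage": {"total_tokens": 42}}},
    completion_response=None,
    model_name="azure/gpt-35-turbo",
)
print(usage_by_deployment)  # {'azure/gpt-35-turbo': 42}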
@@ -501,7 +501,8 @@ class Logging:
             "messages": self.messages,
             "optional_params": self.optional_params,
             "litellm_params": self.litellm_params,
-            "start_time": self.start_time
+            "start_time": self.start_time,
+            "stream": self.stream
         }
 
     def pre_call(self, input, api_key, model=None, additional_args={}):
 
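
Note: presumably the new "stream" entry in model_call_details is what feeds the streaming branch in the Router hunk above: the success handler can tell that a call streamed and wait for the assembled complete_streaming_response before reading usage, rather than trying to read token counts off individual chunks.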