Merge remote-tracking branch 'src/main'

This commit is contained in:
Sébastien Campion 2024-01-27 19:24:35 +01:00
commit 4dd18b553a
29 changed files with 550 additions and 170 deletions

View file

@ -198,7 +198,14 @@ class ProxyLogging:
max_budget = user_info["max_budget"]
spend = user_info["spend"]
user_email = user_info["user_email"]
user_info = f"""\nUser ID: {user_id}\nMax Budget: {max_budget}\nSpend: {spend}\nUser Email: {user_email}"""
user_info = f"""\nUser ID: {user_id}\nMax Budget: ${max_budget}\nSpend: ${spend}\nUser Email: {user_email}"""
elif type == "token_budget":
token_info = dict(user_info)
token = token_info["token"]
spend = token_info["spend"]
max_budget = token_info["max_budget"]
user_id = token_info["user_id"]
user_info = f"""\nToken: {token}\nSpend: ${spend}\nMax Budget: ${max_budget}\nUser ID: {user_id}"""
else:
user_info = str(user_info)
# percent of max_budget left to spend
@ -814,7 +821,13 @@ class PrismaClient:
Allow user to delete a key(s)
"""
try:
hashed_tokens = [self.hash_token(token=token) for token in tokens]
hashed_tokens = []
for token in tokens:
if isinstance(token, str) and token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
else:
hashed_token = token
hashed_tokens.append(hashed_token)
await self.db.litellm_verificationtoken.delete_many(
where={"token": {"in": hashed_tokens}}
)
@ -1060,10 +1073,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj["usage"]
if type(usage) == litellm.Usage:
usage = dict(usage)
id = response_obj.get("id", str(uuid.uuid4()))
api_key = metadata.get("user_api_key", "")
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
@ -1091,10 +1105,11 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
"endTime": end_time,
"model": kwargs.get("model", ""),
"user": kwargs.get("user", ""),
"modelParameters": optional_params,
"usage": usage,
"metadata": metadata,
"cache_key": cache_key,
"total_tokens": usage.get("total_tokens", 0),
"prompt_tokens": usage.get("prompt_tokens", 0),
"completion_tokens": usage.get("completion_tokens", 0),
}
json_fields = [
@ -1119,8 +1134,6 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
payload[param] = payload[param].model_dump_json()
if type(payload[param]) == litellm.EmbeddingResponse:
payload[param] = payload[param].model_dump_json()
elif type(payload[param]) == litellm.Usage:
payload[param] = payload[param].model_dump_json()
else:
payload[param] = json.dumps(payload[param])