Merge branch 'main' into litellm_map_openai_auth_errors

commit a0cd4e78fc
13 changed files with 557 additions and 60 deletions
@@ -42,6 +42,7 @@ jobs:
          pip install "anyio==3.7.1"
          pip install "aiodynamo==23.10.1"
          pip install "asyncio==3.4.3"
          pip install "apscheduler==3.10.4"
          pip install "PyGithub==1.59.1"
      - save_cache:
          paths:
@@ -114,6 +115,25 @@ jobs:
          pip install "pytest==7.3.1"
          pip install "pytest-asyncio==0.21.1"
          pip install aiohttp
          pip install openai
          python -m pip install --upgrade pip
          python -m pip install -r .circleci/requirements.txt
          pip install "pytest==7.3.1"
          pip install "pytest-asyncio==0.21.1"
          pip install mypy
          pip install "google-generativeai>=0.3.2"
          pip install "google-cloud-aiplatform>=1.38.0"
          pip install "boto3>=1.28.57"
          pip install langchain
          pip install "langfuse>=2.0.0"
          pip install numpydoc
          pip install prisma
          pip install "httpx==0.24.1"
          pip install "gunicorn==21.2.0"
          pip install "anyio==3.7.1"
          pip install "aiodynamo==23.10.1"
          pip install "asyncio==3.4.3"
          pip install "PyGithub==1.59.1"
          # Run pytest and generate JUnit XML report
      - run:
          name: Build Docker image
@@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase):
    metadata: Optional[dict] = {}
    tpm_limit: Optional[int] = None
    rpm_limit: Optional[int] = None
    budget_duration: Optional[str] = None


class UpdateKeyRequest(LiteLLMBase):
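Note: `budget_duration` on GenerateKeyRequest takes the same "30s" / "30m" / "30h" / "30d" style strings as key `duration`. A minimal sketch of exercising the new fields against a locally running proxy (address and master key are the placeholder values used in the tests in this diff; field values are made up for illustration):

```
import requests

resp = requests.post(
    "http://0.0.0.0:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "models": ["azure-models", "gpt-4"],
        "max_budget": 10.0,        # mapped to key_max_budget server-side
        "budget_duration": "30d",  # mapped to key_budget_duration; spend resets each window
        "tpm_limit": 1000,
        "rpm_limit": 100,
    },
)
print(resp.json())  # GenerateKeyResponse: key, expires, user_id
```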
@@ -98,7 +98,7 @@ def list_models():
            st.error(f"An error occurred while requesting models: {e}")
    else:
        st.warning(
-           "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+           f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )
@@ -151,7 +151,7 @@ def create_key():
            raise e
    else:
        st.warning(
-           "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+           f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )
@@ -19,6 +19,7 @@ try:
    import yaml
    import orjson
    import logging
    from apscheduler.schedulers.asyncio import AsyncIOScheduler
except ImportError as e:
    raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")

@@ -73,6 +74,7 @@ from litellm.proxy.utils import (
    _cache_user_row,
    send_email,
    get_logging_payload,
    reset_budget,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
@@ -578,7 +580,7 @@ async def track_cost_callback(
        litellm_params = kwargs.get("litellm_params", {}) or {}
        proxy_server_request = litellm_params.get("proxy_server_request") or {}
        user_id = proxy_server_request.get("body", {}).get("user", None)
-       if "response_cost" in kwargs:
+       if kwargs.get("response_cost", None) is not None:
            response_cost = kwargs["response_cost"]
            user_api_key = kwargs["litellm_params"]["metadata"].get(
                "user_api_key", None
@@ -604,9 +606,13 @@ async def track_cost_callback(
                    end_time=end_time,
                )
            else:
-               raise Exception(
-                   f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
-               )
+               if kwargs["stream"] != True or (
+                   kwargs["stream"] == True
+                   and kwargs.get("complete_streaming_response") in kwargs
+               ):
+                   raise Exception(
+                       f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                   )
    except Exception as e:
        verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
@@ -703,6 +709,7 @@ async def update_database(
            valid_token.spend = new_spend
            user_api_key_cache.set_cache(key=token, value=valid_token)

    ### UPDATE SPEND LOGS ###
    async def _insert_spend_log_to_db():
        # Helper to generate payload to log
        verbose_proxy_logger.debug("inserting spend log to db")
@@ -1133,7 +1140,9 @@ async def generate_key_helper_fn(
    config: dict,
    spend: float,
    key_max_budget: Optional[float] = None,  # key_max_budget is used to Budget Per key
    key_budget_duration: Optional[str] = None,
    max_budget: Optional[float] = None,  # max_budget is used to Budget Per user
    budget_duration: Optional[str] = None,  # max_budget is used to Budget Per user
    token: Optional[str] = None,
    user_id: Optional[str] = None,
    team_id: Optional[str] = None,
@@ -1178,6 +1187,12 @@ async def generate_key_helper_fn(
        duration_s = _duration_in_seconds(duration=duration)
        expires = datetime.utcnow() + timedelta(seconds=duration_s)

    if key_budget_duration is None:  # one-time budget
        key_reset_at = None
    else:
        duration_s = _duration_in_seconds(duration=key_budget_duration)
        key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)

    aliases_json = json.dumps(aliases)
    config_json = json.dumps(config)
    metadata_json = json.dumps(metadata)
@@ -1213,6 +1228,8 @@ async def generate_key_helper_fn(
        "metadata": metadata_json,
        "tpm_limit": tpm_limit,
        "rpm_limit": rpm_limit,
        "budget_duration": key_budget_duration,
        "budget_reset_at": key_reset_at,
    }
    if prisma_client is not None:
        ## CREATE USER (If necessary)
@@ -1533,7 +1550,7 @@ async def startup_event():
            duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
        )
        verbose_proxy_logger.debug(
-           f"custom_db_client client - Inserting master key {custom_db_client}. Master_key: {master_key}"
+           f"custom_db_client client {custom_db_client}. Master_key: {master_key}"
        )
    if custom_db_client is not None and master_key is not None:
        # add master key to db
@@ -1541,6 +1558,11 @@ async def startup_event():
            duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
        )

    ### START BUDGET SCHEDULER ###
    scheduler = AsyncIOScheduler()
    scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client])
    scheduler.start()


#### API ENDPOINTS ####
@router.get(
@@ -2221,11 +2243,13 @@ async def generate_key_fn(
        if "max_budget" in data_json:
            data_json["key_max_budget"] = data_json.pop("max_budget", None)

        if "budget_duration" in data_json:
            data_json["key_budget_duration"] = data_json.pop("budget_duration", None)

        response = await generate_key_helper_fn(**data_json)
        return GenerateKeyResponse(
-           key=response["token"],
-           expires=response["expires"],
-           user_id=response["user_id"],
+           key=response["token"], expires=response["expires"], user_id=response["user_id"]
        )
    except Exception as e:
        if isinstance(e, HTTPException):
@@ -2244,6 +2268,7 @@ async def generate_key_fn(
            code=status.HTTP_400_BAD_REQUEST,
        )


@router.post(
    "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
@@ -2367,6 +2392,94 @@ async def info_key_fn(
        )


@router.get(
    "/spend/keys",
    tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def spend_key_fn():
    """
    View all keys created, ordered by spend

    Example Request:
    ```
    curl -X GET "http://0.0.0.0:8000/spend/keys" \
    -H "Authorization: Bearer sk-1234"
    ```
    """
    global prisma_client
    try:
        if prisma_client is None:
            raise Exception(
                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )

        key_info = await prisma_client.get_data(table_name="key", query_type="find_all")

        return key_info

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        )


@router.get(
    "/spend/logs",
    tags=["Budget & Spend Tracking"],
    dependencies=[Depends(user_api_key_auth)],
)
async def view_spend_logs(
    request_id: Optional[str] = fastapi.Query(
        default=None,
        description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests",
    ),
):
    """
    View all spend logs, if request_id is provided, only logs for that request_id will be returned

    Example Request for all logs
    ```
    curl -X GET "http://0.0.0.0:8000/spend/logs" \
    -H "Authorization: Bearer sk-1234"
    ```

    Example Request for specific request_id
    ```
    curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \
    -H "Authorization: Bearer sk-1234"
    ```
    """
    global prisma_client
    try:
        if prisma_client is None:
            raise Exception(
                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        spend_logs = []
        if request_id is not None:
            spend_log = await prisma_client.get_data(
                table_name="spend",
                query_type="find_unique",
                request_id=request_id,
            )
            return [spend_log]
        else:
            spend_logs = await prisma_client.get_data(
                table_name="spend", query_type="find_all"
            )
            return spend_logs

        return None

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": str(e)},
        )


#### USER MANAGEMENT ####
@router.post(
    "/user/new",
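The two new endpoints mirror the curl examples in their docstrings; a small Python sketch of the same calls (proxy address and key are the example values from the docstrings, and the request_id is the sample id shown above):

```
import requests

BASE_URL = "http://0.0.0.0:8000"
HEADERS = {"Authorization": "Bearer sk-1234"}

# /spend/keys: every key row, ordered by spend (descending)
keys = requests.get(f"{BASE_URL}/spend/keys", headers=HEADERS).json()

# /spend/logs: all spend logs, or one log when request_id is passed
all_logs = requests.get(f"{BASE_URL}/spend/logs", headers=HEADERS).json()
one_log = requests.get(
    f"{BASE_URL}/spend/logs",
    params={"request_id": "chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e"},
    headers=HEADERS,
).json()
```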
@@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
  tpm_limit BigInt?
  rpm_limit BigInt?
  max_budget Float? @default(0.0)
  budget_duration String?
  budget_reset_at DateTime?
}

model LiteLLM_Config {
@@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
-import smtplib
+import smtplib, re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
-from datetime import datetime
+from datetime import datetime, timedelta


def print_verbose(print_statement):
@@ -361,8 +361,11 @@ class PrismaClient:
        self,
        token: Optional[str] = None,
        user_id: Optional[str] = None,
-       table_name: Optional[Literal["user", "key", "config"]] = None,
+       request_id: Optional[str] = None,
+       table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
        query_type: Literal["find_unique", "find_all"] = "find_unique",
        expires: Optional[datetime] = None,
        reset_at: Optional[datetime] = None,
    ):
        try:
            print_verbose("PrismaClient: get_data")
@@ -391,6 +394,28 @@ class PrismaClient:
                    for r in response:
                        if isinstance(r.expires, datetime):
                            r.expires = r.expires.isoformat()
                elif (
                    query_type == "find_all"
                    and expires is not None
                    and reset_at is not None
                ):
                    response = await self.db.litellm_verificationtoken.find_many(
                        where={  # type:ignore
                            "OR": [
                                {"expires": None},
                                {"expires": {"gt": expires}},
                            ],
                            "budget_reset_at": {"lt": reset_at},
                        }
                    )
                    if response is not None and len(response) > 0:
                        for r in response:
                            if isinstance(r.expires, datetime):
                                r.expires = r.expires.isoformat()
                elif query_type == "find_all":
                    response = await self.db.litellm_verificationtoken.find_many(
                        order={"spend": "desc"},
                    )
                print_verbose(f"PrismaClient: response={response}")
                if response is not None:
                    return response
@@ -407,6 +432,23 @@ class PrismaClient:
                        }
                    )
                    return response
                elif table_name == "spend":
                    verbose_proxy_logger.debug(
                        f"PrismaClient: get_data: table_name == 'spend'"
                    )
                    if request_id is not None:
                        response = await self.db.litellm_spendlogs.find_unique(  # type: ignore
                            where={
                                "request_id": request_id,
                            }
                        )
                        return response
                    else:
                        response = await self.db.litellm_spendlogs.find_many(  # type: ignore
                            order={"startTime": "desc"},
                        )
                        return response

        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            import traceback
@@ -517,7 +559,10 @@ class PrismaClient:
        self,
        token: Optional[str] = None,
        data: dict = {},
        data_list: Optional[List] = None,
        user_id: Optional[str] = None,
        query_type: Literal["update", "update_many"] = "update",
        table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
    ):
        """
        Update existing data
@@ -534,7 +579,7 @@ class PrismaClient:
                where={"token": token},  # type: ignore
                data={**db_data},  # type: ignore
            )
-           print_verbose(
+           verbose_proxy_logger.debug(
                "\033[91m"
                + f"DB Token Table update succeeded {response}"
                + "\033[0m"
@@ -566,6 +611,33 @@ class PrismaClient:
                + "\033[0m"
            )
            return {"user_id": user_id, "data": db_data}
        elif (
            table_name is not None
            and table_name == "key"
            and query_type == "update_many"
            and data_list is not None
            and isinstance(data_list, list)
        ):
            """
            Batch write update queries
            """
            batcher = self.db.batch_()
            for idx, t in enumerate(data_list):
                # check if plain text or hash
                if t.token.startswith("sk-"):  # type: ignore
                    t.token = self.hash_token(token=t.token)  # type: ignore
                try:
                    data_json = self.jsonify_object(data=t.model_dump())
                except:
                    data_json = self.jsonify_object(data=t.dict())
                batcher.litellm_verificationtoken.update(
                    where={"token": t.token},  # type: ignore
                    data={**data_json},  # type: ignore
                )
            await batcher.commit()
            print_verbose(
                "\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
            )
    except Exception as e:
        asyncio.create_task(
            self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -834,10 +906,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
    usage = response_obj["usage"]
    id = response_obj.get("id", str(uuid.uuid4()))
    api_key = metadata.get("user_api_key", "")
-   if api_key is not None and type(api_key) == str:
+   if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
        # hash the api_key
        api_key = hash_token(api_key)

    if "headers" in metadata and "authorization" in metadata["headers"]:
        metadata["headers"].pop(
            "authorization"
        )  # do not store the original `sk-..` api key in the db

    payload = {
        "request_id": id,
        "call_type": call_type,
@@ -886,3 +963,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
            payload[param] = str(payload[param])

    return payload


def _duration_in_seconds(duration: str):
    match = re.match(r"(\d+)([smhd]?)", duration)
    if not match:
        raise ValueError("Invalid duration format")

    value, unit = match.groups()
    value = int(value)

    if unit == "s":
        return value
    elif unit == "m":
        return value * 60
    elif unit == "h":
        return value * 3600
    elif unit == "d":
        return value * 86400
    else:
        raise ValueError("Unsupported duration unit")


async def reset_budget(prisma_client: PrismaClient):
    """
    Gets all the non-expired keys for a db, which need spend to be reset

    Resets their spend

    Updates db
    """
    if prisma_client is not None:
        now = datetime.utcnow()
        keys_to_reset = await prisma_client.get_data(
            table_name="key", query_type="find_all", expires=now, reset_at=now
        )

        for key in keys_to_reset:
            key.spend = 0.0
            duration_s = _duration_in_seconds(duration=key.budget_duration)
            key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)

        if len(keys_to_reset) > 0:
            await prisma_client.update_data(
                query_type="update_many", data_list=keys_to_reset, table_name="key"
            )
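`_duration_in_seconds` is the same parser already used for key expiry, now reused for budget windows, and `reset_budget` is the coroutine the APScheduler job in `startup_event` runs every 10 seconds: keys whose budget_reset_at has passed get their spend zeroed and the reset time pushed forward by one window. A quick sketch of the values the parser produces (inputs chosen for illustration):

```
# "30s" -> 30, "30m" -> 1800, "24h" -> 86400, "30d" -> 2592000
for d in ["30s", "30m", "24h", "30d"]:
    print(d, _duration_in_seconds(duration=d))
```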
litellm/utils.py (108 changed lines)

@@ -1067,10 +1067,14 @@ class Logging:
            ## if model in model cost map - log the response cost
            ## else set cost to None
            verbose_logger.debug(f"Model={self.model}; result={result}")
-           if result is not None and (
-               isinstance(result, ModelResponse)
-               or isinstance(result, EmbeddingResponse)
-           ):
+           if (
+               result is not None
+               and (
+                   isinstance(result, ModelResponse)
+                   or isinstance(result, EmbeddingResponse)
+               )
+               and self.stream != True
+           ):  # handle streaming separately
                try:
                    self.model_call_details["response_cost"] = litellm.completion_cost(
                        completion_response=result,
@@ -1104,6 +1108,12 @@ class Logging:
        self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
    ):
        verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
        start_time, end_time, result = self._success_handler_helper_fn(
            start_time=start_time,
            end_time=end_time,
            result=result,
            cache_hit=cache_hit,
        )
        # print(f"original response in success handler: {self.model_call_details['original_response']}")
        try:
            verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
@@ -1119,26 +1129,34 @@ class Logging:
                        complete_streaming_response = litellm.stream_chunk_builder(
                            self.sync_streaming_chunks,
                            messages=self.model_call_details.get("messages", None),
                            start_time=start_time,
                            end_time=end_time,
                        )
                    except:
                        complete_streaming_response = None
                else:
                    self.sync_streaming_chunks.append(result)

-           if complete_streaming_response:
+           if complete_streaming_response is not None:
                verbose_logger.debug(
                    f"Logging Details LiteLLM-Success Call streaming complete"
                )
                self.model_call_details[
                    "complete_streaming_response"
                ] = complete_streaming_response
                try:
                    self.model_call_details["response_cost"] = litellm.completion_cost(
                        completion_response=complete_streaming_response,
                    )
                    verbose_logger.debug(
                        f"Model={self.model}; cost={self.model_call_details['response_cost']}"
                    )
                except litellm.NotFoundError as e:
                    verbose_logger.debug(
                        f"Model={self.model} not found in completion cost map."
                    )
                    self.model_call_details["response_cost"] = None

            start_time, end_time, result = self._success_handler_helper_fn(
                start_time=start_time,
                end_time=end_time,
                result=result,
                cache_hit=cache_hit,
            )
            for callback in litellm.success_callback:
                try:
                    if callback == "lite_debugger":
@@ -1418,11 +1436,23 @@ class Logging:
                    complete_streaming_response = None
            else:
                self.streaming_chunks.append(result)
-       if complete_streaming_response:
+       if complete_streaming_response is not None:
            print_verbose("Async success callbacks: Got a complete streaming response")
            self.model_call_details[
                "complete_streaming_response"
            ] = complete_streaming_response
            try:
                self.model_call_details["response_cost"] = litellm.completion_cost(
                    completion_response=complete_streaming_response,
                )
                verbose_logger.debug(
                    f"Model={self.model}; cost={self.model_call_details['response_cost']}"
                )
            except litellm.NotFoundError as e:
                verbose_logger.debug(
                    f"Model={self.model} not found in completion cost map."
                )
                self.model_call_details["response_cost"] = None

        for callback in litellm._async_success_callback:
            try:
@@ -1470,14 +1500,27 @@ class Logging:
                            end_time=end_time,
                        )
                    if callable(callback):  # custom logger functions
-                       await customLogger.async_log_event(
-                           kwargs=self.model_call_details,
-                           response_obj=result,
-                           start_time=start_time,
-                           end_time=end_time,
-                           print_verbose=print_verbose,
-                           callback_func=callback,
-                       )
+                       if self.stream:
+                           if "complete_streaming_response" in self.model_call_details:
+                               await customLogger.async_log_event(
+                                   kwargs=self.model_call_details,
+                                   response_obj=self.model_call_details[
+                                       "complete_streaming_response"
+                                   ],
+                                   start_time=start_time,
+                                   end_time=end_time,
+                                   print_verbose=print_verbose,
+                                   callback_func=callback,
+                               )
+                       else:
+                           await customLogger.async_log_event(
+                               kwargs=self.model_call_details,
+                               response_obj=result,
+                               start_time=start_time,
+                               end_time=end_time,
+                               print_verbose=print_verbose,
+                               callback_func=callback,
+                           )
                    if callback == "dynamodb":
                        global dynamoLogger
                        if dynamoLogger is None:
@@ -2867,6 +2910,9 @@ def cost_per_token(

    if model in model_cost_ref:
        verbose_logger.debug(f"Success: model={model} in model_cost_map")
        verbose_logger.debug(
            f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
        )
        if (
            model_cost_ref[model].get("input_cost_per_token", None) is not None
            and model_cost_ref[model].get("output_cost_per_token", None) is not None
@@ -2895,17 +2941,25 @@ def cost_per_token(
        )
        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
    elif model_with_provider in model_cost_ref:
-       print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
+       verbose_logger.debug(
+           f"Looking up model={model_with_provider} in model_cost_map"
+       )
+       verbose_logger.debug(
+           f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
+       )
        prompt_tokens_cost_usd_dollar = (
            model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
        )
+       verbose_logger.debug(
+           f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
+       )
        completion_tokens_cost_usd_dollar = (
            model_cost_ref[model_with_provider]["output_cost_per_token"]
            * completion_tokens
        )
        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
    elif "ft:gpt-3.5-turbo" in model:
-       print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
+       verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
        # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
        prompt_tokens_cost_usd_dollar = (
            model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@@ -2916,17 +2970,23 @@ def cost_per_token(
        )
        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
    elif model in litellm.azure_llms:
-       print_verbose(f"Cost Tracking: {model} is an Azure LLM")
+       verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
        model = litellm.azure_llms[model]
+       verbose_logger.debug(
+           f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
+       )
        prompt_tokens_cost_usd_dollar = (
            model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
        )
+       verbose_logger.debug(
+           f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
+       )
        completion_tokens_cost_usd_dollar = (
            model_cost_ref[model]["output_cost_per_token"] * completion_tokens
        )
        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
    elif model in litellm.azure_embedding_models:
-       print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
+       verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
        model = litellm.azure_embedding_models[model]
        prompt_tokens_cost_usd_dollar = (
            model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
@@ -25,6 +25,7 @@ backoff = {version = "*", optional = true}
pyyaml = {version = "^6.0.1", optional = true}
rq = {version = "*", optional = true}
orjson = {version = "^3.9.7", optional = true}
apscheduler = {version = "^3.10.4", optional = true}
streamlit = {version = "^1.29.0", optional = true}

[tool.poetry.extras]

@@ -36,6 +37,7 @@ proxy = [
    "pyyaml",
    "rq",
    "orjson",
    "apscheduler"
]

extra_proxy = [
@@ -16,6 +16,7 @@ async_generator==1.10.0 # for async ollama calls
traceloop-sdk==0.5.3 # for open telemetry logging
langfuse>=2.6.3 # for langfuse self-hosted logging
orjson==3.9.7 # fast /embedding responses
apscheduler==3.10.4 # for resetting budget in background
### LITELLM PACKAGE DEPENDENCIES
python-dotenv>=0.2.0 # for env
tiktoken>=0.4.0 # for calculating usage
@@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
  tpm_limit BigInt?
  rpm_limit BigInt?
  max_budget Float? @default(0.0)
  budget_duration String?
  budget_reset_at DateTime?
}

model LiteLLM_Config {

@@ -43,8 +45,8 @@ model LiteLLM_Config {

model LiteLLM_SpendLogs {
  request_id String @unique
- api_key String @default ("")
  call_type String
+ api_key String @default ("")
  spend Float @default(0.0)
  startTime DateTime // Assuming start_time is a DateTime field
  endTime DateTime // Assuming end_time is a DateTime field

@@ -56,4 +58,4 @@ model LiteLLM_SpendLogs {
  usage Json @default("{}")
  metadata Json @default("{}")
  cache_hit String @default("")
}
@@ -2,15 +2,22 @@
## Tests /key endpoints.

import pytest
-import asyncio
+import asyncio, time
import aiohttp
from openai import AsyncOpenAI
import sys, os

sys.path.insert(
    0, os.path.abspath("../")
)  # Adds the parent directory to the system path
import litellm


async def generate_key(session, i):
    url = "http://0.0.0.0:4000/key/generate"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {
-       "models": ["azure-models"],
+       "models": ["azure-models", "gpt-4"],
        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
        "duration": None,
    }
@@ -82,6 +89,35 @@ async def chat_completion(session, key, model="gpt-4"):
    if status != 200:
        raise Exception(f"Request did not return a 200 status code: {status}")

    return await response.json()


async def chat_completion_streaming(session, key, model="gpt-4"):
    client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": f"Hello! {time.time()}"},
    ]
    prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
    data = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    response = await client.chat.completions.create(**data)

    content = ""
    async for chunk in response:
        content += chunk.choices[0].delta.content or ""

    print(f"content: {content}")

    completion_tokens = litellm.token_counter(
        model="gpt-35-turbo", text=content, count_response_tokens=True
    )

    return prompt_tokens, completion_tokens


@pytest.mark.asyncio
async def test_key_update():
@@ -181,3 +217,49 @@ async def test_key_info():
        random_key = key_gen["key"]
        status = await get_key_info(session=session, get_key=key, call_key=random_key)
        assert status == 403


@pytest.mark.asyncio
async def test_key_info_spend_values():
    """
    - create key
    - make completion call
    - assert cost is expected value
    """
    async with aiohttp.ClientSession() as session:
        ## Test Spend Update ##
        # completion
        # response = await chat_completion(session=session, key=key)
        # prompt_cost, completion_cost = litellm.cost_per_token(
        #     model="azure/gpt-35-turbo",
        #     prompt_tokens=response["usage"]["prompt_tokens"],
        #     completion_tokens=response["usage"]["completion_tokens"],
        # )
        # response_cost = prompt_cost + completion_cost
        # await asyncio.sleep(5)  # allow db log to be updated
        # key_info = await get_key_info(session=session, get_key=key, call_key=key)
        # print(
        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
        # )
        # assert response_cost == key_info["info"]["spend"]
        ## streaming
        key_gen = await generate_key(session=session, i=0)
        new_key = key_gen["key"]
        prompt_tokens, completion_tokens = await chat_completion_streaming(
            session=session, key=new_key
        )
        print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
        prompt_cost, completion_cost = litellm.cost_per_token(
            model="azure/gpt-35-turbo",
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        response_cost = prompt_cost + completion_cost
        await asyncio.sleep(5)  # allow db log to be updated
        key_info = await get_key_info(
            session=session, get_key=new_key, call_key=new_key
        )
        print(
            f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
        )
        assert response_cost == key_info["info"]["spend"]
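The final assertion holds because the proxy records spend with the same cost function the test calls directly; a quick sketch of that arithmetic (token counts below are hypothetical):

```
import litellm

prompt_tokens, completion_tokens = 21, 48  # hypothetical counts for one streamed call

prompt_cost, completion_cost = litellm.cost_per_token(
    model="azure/gpt-35-turbo",
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
)
response_cost = prompt_cost + completion_cost  # expected value of key_info["info"]["spend"]
```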
@@ -68,6 +68,7 @@ async def chat_completion(session, key):

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()


@pytest.mark.asyncio
ui/admin.py (125 changed lines)

@@ -6,6 +6,9 @@ from dotenv import load_dotenv
load_dotenv()
import streamlit as st
import base64, os, json, uuid, requests
import pandas as pd
import plotly.express as px
import click

# Replace your_base_url with the actual URL where the proxy auth app is hosted
your_base_url = os.getenv("BASE_URL")  # Example base URL
@@ -75,7 +78,7 @@ def add_new_model():
        and st.session_state.get("proxy_key", None) is None
    ):
        st.warning(
-           "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+           f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )

    model_name = st.text_input(
@@ -174,10 +177,70 @@ def list_models():
            st.error(f"An error occurred while requesting models: {e}")
    else:
        st.warning(
-           "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+           f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )


def spend_per_key():
    import streamlit as st
    import requests

    # Check if the necessary configuration is available
    if (
        st.session_state.get("api_url", None) is not None
        and st.session_state.get("proxy_key", None) is not None
    ):
        # Make the GET request
        try:
            complete_url = ""
            if isinstance(st.session_state["api_url"], str) and st.session_state[
                "api_url"
            ].endswith("/"):
                complete_url = f"{st.session_state['api_url']}/spend/keys"
            else:
                complete_url = f"{st.session_state['api_url']}/spend/keys"
            response = requests.get(
                complete_url,
                headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
            )
            # Check if the request was successful
            if response.status_code == 200:
                spend_per_key = response.json()
                # Create DataFrame
                spend_df = pd.DataFrame(spend_per_key)

                # Display the spend per key as a graph
                st.header("Spend ($) per API Key:")
                top_10_df = spend_df.nlargest(10, "spend")
                fig = px.bar(
                    top_10_df,
                    x="token",
                    y="spend",
                    title="Top 10 Spend per Key",
                    height=550,  # Adjust the height
                    width=1200,  # Adjust the width)
                    hover_data=["token", "spend", "user_id", "team_id"],
                )
                st.plotly_chart(fig)

                # Display the spend per key as a table
                st.write("Spend per Key - Full Table:")
                st.table(spend_df)

            else:
                st.error(f"Failed to get models. Status code: {response.status_code}")
        except Exception as e:
            st.error(f"An error occurred while requesting models: {e}")
    else:
        st.warning(
            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )


def spend_per_user():
    pass


def create_key():
    import streamlit as st
    import json, requests, uuid
@@ -187,7 +250,7 @@ def create_key():
        and st.session_state.get("proxy_key", None) is None
    ):
        st.warning(
-           "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+           f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )

    duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
@@ -235,7 +298,7 @@ def update_config():
        and st.session_state.get("proxy_key", None) is None
    ):
        st.warning(
-           "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+           f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )

    st.markdown("#### Alerting")
@@ -324,19 +387,25 @@ def update_config():
        raise e


-def admin_page(is_admin="NOT_GIVEN"):
+def admin_page(is_admin="NOT_GIVEN", input_api_url=None, input_proxy_key=None):
    # Display the form for the admin to set the proxy URL and allowed email subdomain
    st.set_page_config(
        layout="wide",  # Use "wide" layout for more space
    )
    st.header("Admin Configuration")
    st.session_state.setdefault("is_admin", is_admin)
    # Add a navigation sidebar
    st.sidebar.title("Navigation")

    page = st.sidebar.radio(
        "Go to",
        (
            "Connect to Proxy",
+           "View Spend Per Key",
+           "View Spend Per User",
+           "List Models",
            "Update Config",
            "Add Models",
-           "List Models",
            "Create Key",
            "End-User Auth",
        ),
@@ -344,16 +413,23 @@ def admin_page(is_admin="NOT_GIVEN"):
    # Display different pages based on navigation selection
    if page == "Connect to Proxy":
        # Use text inputs with intermediary variables
-       input_api_url = st.text_input(
-           "Proxy Endpoint",
-           value=st.session_state.get("api_url", ""),
-           placeholder="http://0.0.0.0:8000",
-       )
-       input_proxy_key = st.text_input(
-           "Proxy Key",
-           value=st.session_state.get("proxy_key", ""),
-           placeholder="sk-...",
-       )
+       if input_api_url is None:
+           input_api_url = st.text_input(
+               "Proxy Endpoint",
+               value=st.session_state.get("api_url", ""),
+               placeholder="http://0.0.0.0:8000",
+           )
+       else:
+           st.session_state["api_url"] = input_api_url
+
+       if input_proxy_key is None:
+           input_proxy_key = st.text_input(
+               "Proxy Key",
+               value=st.session_state.get("proxy_key", ""),
+               placeholder="sk-...",
+           )
+       else:
+           st.session_state["proxy_key"] = input_proxy_key
        # When the "Save" button is clicked, update the session state
        if st.button("Save"):
            st.session_state["api_url"] = input_api_url
@@ -369,6 +445,21 @@ def admin_page(is_admin="NOT_GIVEN"):
        list_models()
    elif page == "Create Key":
        create_key()
    elif page == "View Spend Per Key":
        spend_per_key()
    elif page == "View Spend Per User":
        spend_per_user()


-admin_page()
+# admin_page()


@click.command()
@click.option("--proxy_endpoint", type=str, help="Proxy Endpoint")
@click.option("--proxy_master_key", type=str, help="Proxy Master Key")
def main(proxy_endpoint, proxy_master_key):
    admin_page(input_api_url=proxy_endpoint, input_proxy_key=proxy_master_key)


if __name__ == "__main__":
    main()