Merge branch 'main' into litellm_map_openai_auth_errors

Ishaan Jaff 2024-01-23 18:31:48 -08:00 committed by GitHub
commit a0cd4e78fc
13 changed files with 557 additions and 60 deletions

View file

@@ -42,6 +42,7 @@ jobs:
pip install "anyio==3.7.1"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "apscheduler==3.10.4"
pip install "PyGithub==1.59.1"
- save_cache:
paths:
@@ -114,6 +115,25 @@ jobs:
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
pip install openai
python -m pip install --upgrade pip
python -m pip install -r .circleci/requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install mypy
pip install "google-generativeai>=0.3.2"
pip install "google-cloud-aiplatform>=1.38.0"
pip install "boto3>=1.28.57"
pip install langchain
pip install "langfuse>=2.0.0"
pip install numpydoc
pip install prisma
pip install "httpx==0.24.1"
pip install "gunicorn==21.2.0"
pip install "anyio==3.7.1"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image

View file

@@ -135,6 +135,7 @@ class GenerateKeyRequest(LiteLLMBase):
metadata: Optional[dict] = {}
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
budget_duration: Optional[str] = None
class UpdateKeyRequest(LiteLLMBase):
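The new budget_duration field takes the same duration strings ("30s", "30m", "30h", "30d") parsed by the _duration_in_seconds helper added later in this diff. A minimal sketch of generating a key with a recurring budget, assuming a local proxy on port 8000 and a placeholder master key:

```
import requests

BASE_URL = "http://0.0.0.0:8000"               # placeholder proxy endpoint
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder master key

resp = requests.post(
    f"{BASE_URL}/key/generate",
    headers=HEADERS,
    json={
        "models": ["gpt-3.5-turbo"],
        "max_budget": 10.0,        # USD cap for this key
        "budget_duration": "30d",  # spend on the key resets every 30 days
    },
)
print(resp.json()["key"])
```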

View file

@@ -98,7 +98,7 @@ def list_models():
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
@@ -151,7 +151,7 @@ def create_key():
raise e
else:
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)

View file

@@ -19,6 +19,7 @@ try:
import yaml
import orjson
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
except ImportError as e:
raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
@@ -73,6 +74,7 @@ from litellm.proxy.utils import (
_cache_user_row,
send_email,
get_logging_payload,
reset_budget,
)
from litellm.proxy.secret_managers.google_kms import load_google_kms
import pydantic
@@ -578,7 +580,7 @@ async def track_cost_callback(
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
user_id = proxy_server_request.get("body", {}).get("user", None)
if "response_cost" in kwargs:
if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = kwargs["litellm_params"]["metadata"].get(
"user_api_key", None
@@ -604,9 +606,13 @@ async def track_cost_callback(
end_time=end_time,
)
else:
raise Exception(
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
if kwargs["stream"] != True or (
kwargs["stream"] == True
and "complete_streaming_response" in kwargs
):
raise Exception(
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
@@ -703,6 +709,7 @@ async def update_database(
valid_token.spend = new_spend
user_api_key_cache.set_cache(key=token, value=valid_token)
### UPDATE SPEND LOGS ###
async def _insert_spend_log_to_db():
# Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
@@ -1133,7 +1140,9 @@ async def generate_key_helper_fn(
config: dict,
spend: float,
key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key
key_budget_duration: Optional[str] = None,
max_budget: Optional[float] = None, # max_budget is used to Budget Per user
budget_duration: Optional[str] = None,  # budget_duration sets how often the Budget Per user resets
token: Optional[str] = None,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
@@ -1178,6 +1187,12 @@ async def generate_key_helper_fn(
duration_s = _duration_in_seconds(duration=duration)
expires = datetime.utcnow() + timedelta(seconds=duration_s)
if key_budget_duration is None: # one-time budget
key_reset_at = None
else:
duration_s = _duration_in_seconds(duration=key_budget_duration)
key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
metadata_json = json.dumps(metadata)
@@ -1213,6 +1228,8 @@ async def generate_key_helper_fn(
"metadata": metadata_json,
"tpm_limit": tpm_limit,
"rpm_limit": rpm_limit,
"budget_duration": key_budget_duration,
"budget_reset_at": key_reset_at,
}
if prisma_client is not None:
## CREATE USER (If necessary)
@@ -1533,7 +1550,7 @@ async def startup_event():
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
verbose_proxy_logger.debug(
f"custom_db_client client - Inserting master key {custom_db_client}. Master_key: {master_key}"
f"custom_db_client client {custom_db_client}. Master_key: {master_key}"
)
if custom_db_client is not None and master_key is not None:
# add master key to db
@@ -1541,6 +1558,11 @@ async def startup_event():
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
### START BUDGET SCHEDULER ###
scheduler = AsyncIOScheduler()
scheduler.add_job(reset_budget, "interval", seconds=10, args=[prisma_client])
scheduler.start()
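For reference, a self-contained sketch of the APScheduler pattern used above — the 10-second interval and the prisma_client argument mirror the add_job call, while the stub coroutine stands in for the real reset_budget from litellm.proxy.utils:

```
import asyncio
from apscheduler.schedulers.asyncio import AsyncIOScheduler

async def reset_budget(prisma_client):
    # stub standing in for litellm.proxy.utils.reset_budget
    print(f"resetting budgets via {prisma_client}")

async def main():
    scheduler = AsyncIOScheduler()
    # run the coroutine every 10 seconds on the current event loop
    scheduler.add_job(reset_budget, "interval", seconds=10, args=["prisma_client"])
    scheduler.start()
    await asyncio.sleep(25)  # keep the loop alive long enough to see two runs

asyncio.run(main())
```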
#### API ENDPOINTS ####
@router.get(
@@ -2221,11 +2243,13 @@ async def generate_key_fn(
if "max_budget" in data_json:
data_json["key_max_budget"] = data_json.pop("max_budget", None)
if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(
key=response["token"],
expires=response["expires"],
user_id=response["user_id"],
key=response["token"], expires=response["expires"], user_id=response["user_id"]
)
except Exception as e:
if isinstance(e, HTTPException):
@@ -2244,6 +2268,7 @@ async def generate_key_fn(
code=status.HTTP_400_BAD_REQUEST,
)
@router.post(
"/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
@@ -2367,6 +2392,94 @@ async def info_key_fn(
)
@router.get(
"/spend/keys",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
)
async def spend_key_fn():
"""
View all keys created, ordered by spend
Example Request:
```
curl -X GET "http://0.0.0.0:8000/spend/keys" \
-H "Authorization: Bearer sk-1234"
```
"""
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
return key_info
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.get(
"/spend/logs",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
)
async def view_spend_logs(
request_id: Optional[str] = fastapi.Query(
default=None,
description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests",
),
):
"""
View all spend logs. If request_id is provided, only logs for that request_id are returned.
Example Request for all logs
```
curl -X GET "http://0.0.0.0:8000/spend/logs" \
-H "Authorization: Bearer sk-1234"
```
Example Request for specific request_id
```
curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \
-H "Authorization: Bearer sk-1234"
```
"""
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
spend_logs = []
if request_id is not None:
spend_log = await prisma_client.get_data(
table_name="spend",
query_type="find_unique",
request_id=request_id,
)
return [spend_log]
else:
spend_logs = await prisma_client.get_data(
table_name="spend", query_type="find_all"
)
return spend_logs
return None
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
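A Python equivalent of the curl examples in the two docstrings above, reusing their placeholder endpoint, master key, and request_id:

```
import requests

BASE_URL = "http://0.0.0.0:8000"
HEADERS = {"Authorization": "Bearer sk-1234"}

keys_by_spend = requests.get(f"{BASE_URL}/spend/keys", headers=HEADERS).json()
all_logs = requests.get(f"{BASE_URL}/spend/logs", headers=HEADERS).json()
one_log = requests.get(
    f"{BASE_URL}/spend/logs",
    headers=HEADERS,
    params={"request_id": "chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e"},
).json()
print(len(keys_by_spend), len(all_logs), one_log)
```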
#### USER MANAGEMENT ####
@router.post(
"/user/new",

View file

@@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
tpm_limit BigInt?
rpm_limit BigInt?
max_budget Float? @default(0.0)
budget_duration String?
budget_reset_at DateTime?
}
model LiteLLM_Config {

View file

@@ -14,10 +14,10 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException, status
import smtplib
import smtplib, re
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime
from datetime import datetime, timedelta
def print_verbose(print_statement):
@@ -361,8 +361,11 @@ class PrismaClient:
self,
token: Optional[str] = None,
user_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config"]] = None,
request_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique",
expires: Optional[datetime] = None,
reset_at: Optional[datetime] = None,
):
try:
print_verbose("PrismaClient: get_data")
@@ -391,6 +394,28 @@ class PrismaClient:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif (
query_type == "find_all"
and expires is not None
and reset_at is not None
):
response = await self.db.litellm_verificationtoken.find_many(
where={ # type:ignore
"OR": [
{"expires": None},
{"expires": {"gt": expires}},
],
"budget_reset_at": {"lt": reset_at},
}
)
if response is not None and len(response) > 0:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif query_type == "find_all":
response = await self.db.litellm_verificationtoken.find_many(
order={"spend": "desc"},
)
print_verbose(f"PrismaClient: response={response}")
if response is not None:
return response
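In plain Python, the new find_many filter selects keys that are still live (no expiry, or expiry in the future) but whose reset timestamp has already passed. A sketch of the equivalent predicate, using hypothetical dict rows rather than the actual Prisma objects:

```
from datetime import datetime, timedelta

def budget_reset_due(key: dict, now: datetime) -> bool:
    not_expired = key["expires"] is None or key["expires"] > now
    reset_due = key["budget_reset_at"] is not None and key["budget_reset_at"] < now
    return not_expired and reset_due

now = datetime.utcnow()
live_key = {"expires": None, "budget_reset_at": now - timedelta(hours=1)}
dead_key = {"expires": now - timedelta(days=1), "budget_reset_at": now - timedelta(hours=1)}
print(budget_reset_due(live_key, now), budget_reset_due(dead_key, now))  # True False
```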
@@ -407,6 +432,23 @@ class PrismaClient:
}
)
return response
elif table_name == "spend":
verbose_proxy_logger.debug(
f"PrismaClient: get_data: table_name == 'spend'"
)
if request_id is not None:
response = await self.db.litellm_spendlogs.find_unique( # type: ignore
where={
"request_id": request_id,
}
)
return response
else:
response = await self.db.litellm_spendlogs.find_many( # type: ignore
order={"startTime": "desc"},
)
return response
except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
import traceback
@@ -517,7 +559,10 @@ class PrismaClient:
self,
token: Optional[str] = None,
data: dict = {},
data_list: Optional[List] = None,
user_id: Optional[str] = None,
query_type: Literal["update", "update_many"] = "update",
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
):
"""
Update existing data
@@ -534,7 +579,7 @@ class PrismaClient:
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose(
verbose_proxy_logger.debug(
"\033[91m"
+ f"DB Token Table update succeeded {response}"
+ "\033[0m"
@@ -566,6 +611,33 @@ class PrismaClient:
+ "\033[0m"
)
return {"user_id": user_id, "data": db_data}
elif (
table_name is not None
and table_name == "key"
and query_type == "update_many"
and data_list is not None
and isinstance(data_list, list)
):
"""
Batch write update queries
"""
batcher = self.db.batch_()
for idx, t in enumerate(data_list):
# check if plain text or hash
if t.token.startswith("sk-"): # type: ignore
t.token = self.hash_token(token=t.token) # type: ignore
try:
data_json = self.jsonify_object(data=t.model_dump())
except:
data_json = self.jsonify_object(data=t.dict())
batcher.litellm_verificationtoken.update(
where={"token": t.token}, # type: ignore
data={**data_json}, # type: ignore
)
await batcher.commit()
print_verbose(
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
)
except Exception as e:
asyncio.create_task(
self.proxy_logging_obj.failure_handler(original_exception=e)
@@ -834,10 +906,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
api_key = metadata.get("user_api_key", "")
if api_key is not None and type(api_key) == str:
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
# hash the api_key
api_key = hash_token(api_key)
if "headers" in metadata and "authorization" in metadata["headers"]:
metadata["headers"].pop(
"authorization"
) # do not store the original `sk-..` api key in the db
payload = {
"request_id": id,
"call_type": call_type,
@@ -886,3 +963,48 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
payload[param] = str(payload[param])
return payload
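get_logging_payload now hashes the key and drops the authorization header so the raw sk-... value never reaches the spend table. A minimal sketch of what a hash_token-style helper typically does — an assumption for illustration, not necessarily litellm's exact implementation:

```
import hashlib

def hash_token(token: str) -> str:
    # one-way fingerprint: stable enough for lookups, not reversible to the raw key
    return hashlib.sha256(token.encode()).hexdigest()

print(hash_token("sk-1234"))
```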
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
if not match:
raise ValueError("Invalid duration format")
value, unit = match.groups()
value = int(value)
if unit == "s":
return value
elif unit == "m":
return value * 60
elif unit == "h":
return value * 3600
elif unit == "d":
return value * 86400
else:
raise ValueError("Unsupported duration unit")
async def reset_budget(prisma_client: PrismaClient):
"""
Gets all the non-expired keys in the db whose spend needs to be reset
Resets their spend
Updates db
"""
if prisma_client is not None:
now = datetime.utcnow()
keys_to_reset = await prisma_client.get_data(
table_name="key", query_type="find_all", expires=now, reset_at=now
)
for key in keys_to_reset:
key.spend = 0.0
duration_s = _duration_in_seconds(duration=key.budget_duration)
key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
if len(keys_to_reset) > 0:
await prisma_client.update_data(
query_type="update_many", data_list=keys_to_reset, table_name="key"
)
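A worked example of one reset pass for a hypothetical key with a 30-day window (dates chosen for illustration):

```
from datetime import datetime, timedelta

spend = 42.0
budget_duration = "30d"                  # parses to 30 * 86400 seconds
budget_reset_at = datetime(2024, 1, 1)
now = datetime(2024, 1, 23)

if budget_reset_at < now:                # the same test the find_many filter encodes
    spend = 0.0                          # wipe spend for the new window
    budget_reset_at += timedelta(days=30)

print(spend, budget_reset_at)            # 0.0 2024-01-31 00:00:00
```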

View file

@@ -1067,10 +1067,14 @@ class Logging:
## if model in model cost map - log the response cost
## else set cost to None
verbose_logger.debug(f"Model={self.model}; result={result}")
if result is not None and (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
):
if (
result is not None
and (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
)
and self.stream != True
): # handle streaming separately
try:
self.model_call_details["response_cost"] = litellm.completion_cost(
completion_response=result,
@@ -1104,6 +1108,12 @@
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
):
verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
result=result,
cache_hit=cache_hit,
)
# print(f"original response in success handler: {self.model_call_details['original_response']}")
try:
verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
@@ -1119,26 +1129,34 @@ class Logging:
complete_streaming_response = litellm.stream_chunk_builder(
self.sync_streaming_chunks,
messages=self.model_call_details.get("messages", None),
start_time=start_time,
end_time=end_time,
)
except:
complete_streaming_response = None
else:
self.sync_streaming_chunks.append(result)
if complete_streaming_response:
if complete_streaming_response is not None:
verbose_logger.debug(
f"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
try:
self.model_call_details["response_cost"] = litellm.completion_cost(
completion_response=complete_streaming_response,
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
except litellm.NotFoundError as e:
verbose_logger.debug(
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
result=result,
cache_hit=cache_hit,
)
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@@ -1418,11 +1436,23 @@ class Logging:
complete_streaming_response = None
else:
self.streaming_chunks.append(result)
if complete_streaming_response:
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
try:
self.model_call_details["response_cost"] = litellm.completion_cost(
completion_response=complete_streaming_response,
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
except litellm.NotFoundError as e:
verbose_logger.debug(
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
for callback in litellm._async_success_callback:
try:
@@ -1470,14 +1500,27 @@ class Logging:
end_time=end_time,
)
if callable(callback): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback,
)
if self.stream:
if "complete_streaming_response" in self.model_call_details:
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
"complete_streaming_response"
],
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback,
)
else:
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback,
)
if callback == "dynamodb":
global dynamoLogger
if dynamoLogger is None:
@@ -2867,6 +2910,9 @@ def cost_per_token(
if model in model_cost_ref:
verbose_logger.debug(f"Success: model={model} in model_cost_map")
verbose_logger.debug(
f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"
)
if (
model_cost_ref[model].get("input_cost_per_token", None) is not None
and model_cost_ref[model].get("output_cost_per_token", None) is not None
@@ -2895,17 +2941,25 @@ def cost_per_token(
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model_with_provider in model_cost_ref:
print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
verbose_logger.debug(
f"Looking up model={model_with_provider} in model_cost_map"
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:gpt-3.5-turbo" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@@ -2916,17 +2970,23 @@ def cost_per_token(
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_llms:
print_verbose(f"Cost Tracking: {model} is an Azure LLM")
verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
model = litellm.azure_llms[model]
verbose_logger.debug(
f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_embedding_models:
print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
model = litellm.azure_embedding_models[model]
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens

View file

@@ -25,6 +25,7 @@ backoff = {version = "*", optional = true}
pyyaml = {version = "^6.0.1", optional = true}
rq = {version = "*", optional = true}
orjson = {version = "^3.9.7", optional = true}
apscheduler = {version = "^3.10.4", optional = true}
streamlit = {version = "^1.29.0", optional = true}
[tool.poetry.extras]
@@ -36,6 +37,7 @@ proxy = [
"pyyaml",
"rq",
"orjson",
"apscheduler"
]
extra_proxy = [

View file

@@ -16,6 +16,7 @@ async_generator==1.10.0 # for async ollama calls
traceloop-sdk==0.5.3 # for open telemetry logging
langfuse>=2.6.3 # for langfuse self-hosted logging
orjson==3.9.7 # fast /embedding responses
apscheduler==3.10.4 # for resetting budget in background
### LITELLM PACKAGE DEPENDENCIES
python-dotenv>=0.2.0 # for env
tiktoken>=0.4.0 # for calculating usage

View file

@@ -34,6 +34,8 @@ model LiteLLM_VerificationToken {
tpm_limit BigInt?
rpm_limit BigInt?
max_budget Float? @default(0.0)
budget_duration String?
budget_reset_at DateTime?
}
model LiteLLM_Config {
@@ -43,8 +45,8 @@ model LiteLLM_Config {
model LiteLLM_SpendLogs {
request_id String @unique
api_key String @default ("")
call_type String
api_key String @default ("")
spend Float @default(0.0)
startTime DateTime // Assuming start_time is a DateTime field
endTime DateTime // Assuming end_time is a DateTime field
@@ -56,4 +58,4 @@ model LiteLLM_SpendLogs {
usage Json @default("{}")
metadata Json @default("{}")
cache_hit String @default("")
}
}

View file

@@ -2,15 +2,22 @@
## Tests /key endpoints.
import pytest
import asyncio
import asyncio, time
import aiohttp
from openai import AsyncOpenAI
import sys, os
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
async def generate_key(session, i):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": ["azure-models"],
"models": ["azure-models", "gpt-4"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
}
@@ -82,6 +89,35 @@ async def chat_completion(session, key, model="gpt-4"):
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
async def chat_completion_streaming(session, key, model="gpt-4"):
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": f"Hello! {time.time()}"},
]
prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
data = {
"model": model,
"messages": messages,
"stream": True,
}
response = await client.chat.completions.create(**data)
content = ""
async for chunk in response:
content += chunk.choices[0].delta.content or ""
print(f"content: {content}")
completion_tokens = litellm.token_counter(
model="gpt-35-turbo", text=content, count_response_tokens=True
)
return prompt_tokens, completion_tokens
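Streaming responses carry no usage block, which is why the helper above recounts tokens locally. The same counting in isolation (exact counts depend on the model's tokenizer):

```
import litellm

messages = [{"role": "user", "content": "Hello!"}]
prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
completion_tokens = litellm.token_counter(
    model="gpt-35-turbo", text="Hi there, how can I help?", count_response_tokens=True
)
print(prompt_tokens, completion_tokens)
```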
@pytest.mark.asyncio
async def test_key_update():
@@ -181,3 +217,49 @@ async def test_key_info():
random_key = key_gen["key"]
status = await get_key_info(session=session, get_key=key, call_key=random_key)
assert status == 403
@pytest.mark.asyncio
async def test_key_info_spend_values():
"""
- create key
- make completion call
- assert cost is expected value
"""
async with aiohttp.ClientSession() as session:
## Test Spend Update ##
# completion
# response = await chat_completion(session=session, key=key)
# prompt_cost, completion_cost = litellm.cost_per_token(
# model="azure/gpt-35-turbo",
# prompt_tokens=response["usage"]["prompt_tokens"],
# completion_tokens=response["usage"]["completion_tokens"],
# )
# response_cost = prompt_cost + completion_cost
# await asyncio.sleep(5) # allow db log to be updated
# key_info = await get_key_info(session=session, get_key=key, call_key=key)
# print(
# f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
# )
# assert response_cost == key_info["info"]["spend"]
## streaming
key_gen = await generate_key(session=session, i=0)
new_key = key_gen["key"]
prompt_tokens, completion_tokens = await chat_completion_streaming(
session=session, key=new_key
)
print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
prompt_cost, completion_cost = litellm.cost_per_token(
model="azure/gpt-35-turbo",
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response_cost = prompt_cost + completion_cost
await asyncio.sleep(5) # allow db log to be updated
key_info = await get_key_info(
session=session, get_key=new_key, call_key=new_key
)
print(
f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
)
assert response_cost == key_info["info"]["spend"]
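For intuition, the assertion reduces to simple per-token arithmetic. A back-of-the-envelope version with made-up token counts and assumed gpt-3.5-turbo-era rates ($0.0015/1K input, $0.002/1K output — assumptions, not the cost map's authoritative values):

```
prompt_tokens, completion_tokens = 21, 17
input_cost_per_token, output_cost_per_token = 0.0015 / 1000, 0.002 / 1000

response_cost = (prompt_tokens * input_cost_per_token
                 + completion_tokens * output_cost_per_token)
print(f"{response_cost:.8f}")  # 0.00006550 -- what key_info['info']['spend'] should equal
```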

View file

@@ -68,6 +68,7 @@ async def chat_completion(session, key):
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio

View file

@@ -6,6 +6,9 @@ from dotenv import load_dotenv
load_dotenv()
import streamlit as st
import base64, os, json, uuid, requests
import pandas as pd
import plotly.express as px
import click
# Replace your_base_url with the actual URL where the proxy auth app is hosted
your_base_url = os.getenv("BASE_URL") # Example base URL
@@ -75,7 +78,7 @@ def add_new_model():
and st.session_state.get("proxy_key", None) is None
):
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
model_name = st.text_input(
@@ -174,10 +177,70 @@ def list_models():
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def spend_per_key():
import streamlit as st
import requests
# Check if the necessary configuration is available
if (
st.session_state.get("api_url", None) is not None
and st.session_state.get("proxy_key", None) is not None
):
# Make the GET request
try:
complete_url = ""
if isinstance(st.session_state["api_url"], str) and st.session_state[
"api_url"
].endswith("/"):
complete_url = f"{st.session_state['api_url']}/spend/keys"
else:
complete_url = f"{st.session_state['api_url']}/spend/keys"
response = requests.get(
complete_url,
headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
)
# Check if the request was successful
if response.status_code == 200:
spend_per_key = response.json()
# Create DataFrame
spend_df = pd.DataFrame(spend_per_key)
# Display the spend per key as a graph
st.header("Spend ($) per API Key:")
top_10_df = spend_df.nlargest(10, "spend")
fig = px.bar(
top_10_df,
x="token",
y="spend",
title="Top 10 Spend per Key",
height=550, # Adjust the height
width=1200, # Adjust the width
hover_data=["token", "spend", "user_id", "team_id"],
)
st.plotly_chart(fig)
# Display the spend per key as a table
st.write("Spend per Key - Full Table:")
st.table(spend_df)
else:
st.error(f"Failed to get models. Status code: {response.status_code}")
except Exception as e:
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def spend_per_user():
pass
def create_key():
import streamlit as st
import json, requests, uuid
@@ -187,7 +250,7 @@ def create_key():
and st.session_state.get("proxy_key", None) is None
):
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
@@ -235,7 +298,7 @@ def update_config():
and st.session_state.get("proxy_key", None) is None
):
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
st.markdown("#### Alerting")
@@ -324,19 +387,25 @@ def update_config():
raise e
def admin_page(is_admin="NOT_GIVEN"):
def admin_page(is_admin="NOT_GIVEN", input_api_url=None, input_proxy_key=None):
# Display the form for the admin to set the proxy URL and allowed email subdomain
st.set_page_config(
layout="wide", # Use "wide" layout for more space
)
st.header("Admin Configuration")
st.session_state.setdefault("is_admin", is_admin)
# Add a navigation sidebar
st.sidebar.title("Navigation")
page = st.sidebar.radio(
"Go to",
(
"Connect to Proxy",
"View Spend Per Key",
"View Spend Per User",
"List Models",
"Update Config",
"Add Models",
"List Models",
"Create Key",
"End-User Auth",
),
@@ -344,16 +413,23 @@ def admin_page(is_admin="NOT_GIVEN"):
# Display different pages based on navigation selection
if page == "Connect to Proxy":
# Use text inputs with intermediary variables
input_api_url = st.text_input(
"Proxy Endpoint",
value=st.session_state.get("api_url", ""),
placeholder="http://0.0.0.0:8000",
)
input_proxy_key = st.text_input(
"Proxy Key",
value=st.session_state.get("proxy_key", ""),
placeholder="sk-...",
)
if input_api_url is None:
input_api_url = st.text_input(
"Proxy Endpoint",
value=st.session_state.get("api_url", ""),
placeholder="http://0.0.0.0:8000",
)
else:
st.session_state["api_url"] = input_api_url
if input_proxy_key is None:
input_proxy_key = st.text_input(
"Proxy Key",
value=st.session_state.get("proxy_key", ""),
placeholder="sk-...",
)
else:
st.session_state["proxy_key"] = input_proxy_key
# When the "Save" button is clicked, update the session state
if st.button("Save"):
st.session_state["api_url"] = input_api_url
@@ -369,6 +445,21 @@ def admin_page(is_admin="NOT_GIVEN"):
list_models()
elif page == "Create Key":
create_key()
elif page == "View Spend Per Key":
spend_per_key()
elif page == "View Spend Per User":
spend_per_user()
admin_page()
# admin_page()
@click.command()
@click.option("--proxy_endpoint", type=str, help="Proxy Endpoint")
@click.option("--proxy_master_key", type=str, help="Proxy Master Key")
def main(proxy_endpoint, proxy_master_key):
admin_page(input_api_url=proxy_endpoint, input_proxy_key=proxy_master_key)
if __name__ == "__main__":
main()