Merge branch 'main' into litellm_reset_key_budget

This commit is contained in:
Krish Dholakia 2024-01-23 18:10:32 -08:00 committed by GitHub
commit 9784d03d65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 492 additions and 78 deletions

View file

@ -115,6 +115,25 @@ jobs:
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
pip install openai
python -m pip install --upgrade pip
python -m pip install -r .circleci/requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install mypy
pip install "google-generativeai>=0.3.2"
pip install "google-cloud-aiplatform>=1.38.0"
pip install "boto3>=1.28.57"
pip install langchain
pip install "langfuse>=2.0.0"
pip install numpydoc
pip install prisma
pip install "httpx==0.24.1"
pip install "gunicorn==21.2.0"
pip install "anyio==3.7.1"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image

View file

@ -98,7 +98,7 @@ def list_models():
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
@ -151,7 +151,7 @@ def create_key():
raise e
else:
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)

View file

@ -598,9 +598,9 @@ async def track_cost_callback(
end_time=end_time,
)
else:
if (
kwargs["stream"] != True
or kwargs.get("complete_streaming_response", None) is not None
if kwargs["stream"] != True or (
kwargs["stream"] == True
and kwargs.get("complete_streaming_response") in kwargs
):
raise Exception(
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
@ -701,6 +701,7 @@ async def update_database(
valid_token.spend = new_spend
user_api_key_cache.set_cache(key=token, value=valid_token)
### UPDATE SPEND LOGS ###
async def _insert_spend_log_to_db():
# Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
@ -1438,6 +1439,28 @@ async def async_data_generator(response, user_api_key_dict):
yield f"data: {str(e)}\n\n"
def select_data_generator(response, user_api_key_dict):
try:
# since boto3 - sagemaker does not support async calls, we should use a sync data_generator
if (
hasattr(response, "custom_llm_provider")
and response.custom_llm_provider == "sagemaker"
):
return data_generator(
response=response,
)
else:
# default to async_data_generator
return async_data_generator(
response=response, user_api_key_dict=user_api_key_dict
)
except:
# worst case - use async_data_generator
return async_data_generator(
response=response, user_api_key_dict=user_api_key_dict
)
def get_litellm_model_info(model: dict = {}):
model_info = model.get("model_info", {})
model_to_lookup = model.get("litellm_params", {}).get("model", None)
@ -1679,11 +1702,12 @@ async def completion(
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
custom_headers = {"x-litellm-model-id": model_id}
selected_data_generator = select_data_generator(
response=response, user_api_key_dict=user_api_key_dict
)
return StreamingResponse(
async_data_generator(
user_api_key_dict=user_api_key_dict,
response=response,
),
selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
)
@ -1841,11 +1865,12 @@ async def chat_completion(
"stream" in data and data["stream"] == True
): # use generate_responses to stream responses
custom_headers = {"x-litellm-model-id": model_id}
selected_data_generator = select_data_generator(
response=response, user_api_key_dict=user_api_key_dict
)
return StreamingResponse(
async_data_generator(
user_api_key_dict=user_api_key_dict,
response=response,
),
selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
)
@ -2305,6 +2330,94 @@ async def info_key_fn(
)
@router.get(
"/spend/keys",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
)
async def spend_key_fn():
"""
View all keys created, ordered by spend
Example Request:
```
curl -X GET "http://0.0.0.0:8000/spend/keys" \
-H "Authorization: Bearer sk-1234"
```
"""
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
return key_info
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
@router.get(
"/spend/logs",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
)
async def view_spend_logs(
request_id: Optional[str] = fastapi.Query(
default=None,
description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests",
),
):
"""
View all spend logs, if request_id is provided, only logs for that request_id will be returned
Example Request for all logs
```
curl -X GET "http://0.0.0.0:8000/spend/logs" \
-H "Authorization: Bearer sk-1234"
```
Example Request for specific request_id
```
curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \
-H "Authorization: Bearer sk-1234"
```
"""
global prisma_client
try:
if prisma_client is None:
raise Exception(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
spend_logs = []
if request_id is not None:
spend_log = await prisma_client.get_data(
table_name="spend",
query_type="find_unique",
request_id=request_id,
)
return [spend_log]
else:
spend_logs = await prisma_client.get_data(
table_name="spend", query_type="find_all"
)
return spend_logs
return None
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": str(e)},
)
#### USER MANAGEMENT ####
@router.post(
"/user/new",

View file

@ -4,7 +4,7 @@ const openai = require('openai');
process.env.DEBUG=false;
async function runOpenAI() {
const client = new openai.OpenAI({
apiKey: 'sk-yPX56TDqBpr23W7ruFG3Yg',
apiKey: 'sk-JkKeNi6WpWDngBsghJ6B9g',
baseURL: 'http://0.0.0.0:8000'
});

View file

@ -361,7 +361,8 @@ class PrismaClient:
self,
token: Optional[str] = None,
user_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config"]] = None,
request_id: Optional[str] = None,
table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
query_type: Literal["find_unique", "find_all"] = "find_unique",
expires: Optional[datetime] = None,
reset_at: Optional[datetime] = None,
@ -411,6 +412,10 @@ class PrismaClient:
for r in response:
if isinstance(r.expires, datetime):
r.expires = r.expires.isoformat()
elif query_type == "find_all":
response = await self.db.litellm_verificationtoken.find_many(
order={"spend": "desc"},
)
print_verbose(f"PrismaClient: response={response}")
if response is not None:
return response
@ -427,6 +432,23 @@ class PrismaClient:
}
)
return response
elif table_name == "spend":
verbose_proxy_logger.debug(
f"PrismaClient: get_data: table_name == 'spend'"
)
if request_id is not None:
response = await self.db.litellm_spendlogs.find_unique( # type: ignore
where={
"request_id": request_id,
}
)
return response
else:
response = await self.db.litellm_spendlogs.find_many( # type: ignore
order={"startTime": "desc"},
)
return response
except Exception as e:
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
import traceback
@ -549,7 +571,6 @@ class PrismaClient:
db_data = self.jsonify_object(data=data)
if token is not None:
print_verbose(f"token: {token}")
if query_type == "update":
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
@ -558,7 +579,7 @@ class PrismaClient:
where={"token": token}, # type: ignore
data={**db_data}, # type: ignore
)
print_verbose(
verbose_proxy_logger.debug(
"\033[91m"
+ f"DB Token Table update succeeded {response}"
+ "\033[0m"
@ -885,10 +906,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
api_key = metadata.get("user_api_key", "")
if api_key is not None and type(api_key) == str:
if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
# hash the api_key
api_key = hash_token(api_key)
if "headers" in metadata and "authorization" in metadata["headers"]:
metadata["headers"].pop(
"authorization"
) # do not store the original `sk-..` api key in the db
payload = {
"request_id": id,
"call_type": call_type,

View file

@ -1408,9 +1408,15 @@ def test_completion_sagemaker_stream():
)
complete_streaming_response = ""
for chunk in response:
first_chunk_id, chunk_id = None, None
for i, chunk in enumerate(response):
print(chunk)
chunk_id = chunk.id
print(chunk_id)
if i == 0:
first_chunk_id = chunk_id
else:
assert chunk_id == first_chunk_id
complete_streaming_response += chunk.choices[0].delta.content or ""
# Add any assertions here to check the response
# print(response)

View file

@ -960,3 +960,29 @@ def test_router_anthropic_key_dynamic():
messages = [{"role": "user", "content": "Hey, how's it going?"}]
router.completion(model="anthropic-claude", messages=messages)
os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key
def test_router_timeout():
model_list = [
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": "os.environ/OPENAI_API_KEY",
},
}
]
router = Router(model_list=model_list)
messages = [{"role": "user", "content": "Hey, how's it going?"}]
start_time = time.time()
try:
res = router.completion(
model="gpt-3.5-turbo", messages=messages, timeout=0.0001
)
print(res)
pytest.fail("this should have timed out")
except litellm.exceptions.Timeout as e:
print("got timeout exception")
print(e)
print(vars(e))
pass

View file

@ -733,8 +733,15 @@ def test_completion_bedrock_claude_stream():
complete_response = ""
has_finish_reason = False
# Add any assertions here to check the response
first_chunk_id = None
for idx, chunk in enumerate(response):
# print
if idx == 0:
first_chunk_id = chunk.id
else:
assert (
chunk.id == first_chunk_id
), f"chunk ids do not match: {chunk.id} != first chunk id{first_chunk_id}"
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
complete_response += chunk

View file

@ -1067,7 +1067,6 @@ class Logging:
## if model in model cost map - log the response cost
## else set cost to None
verbose_logger.debug(f"Model={self.model}; result={result}")
verbose_logger.debug(f"self.stream: {self.stream}")
if (
result is not None
and (
@ -1109,6 +1108,12 @@ class Logging:
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
):
verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
result=result,
cache_hit=cache_hit,
)
# print(f"original response in success handler: {self.model_call_details['original_response']}")
try:
verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
@ -1124,6 +1129,8 @@ class Logging:
complete_streaming_response = litellm.stream_chunk_builder(
self.sync_streaming_chunks,
messages=self.model_call_details.get("messages", None),
start_time=start_time,
end_time=end_time,
)
except:
complete_streaming_response = None
@ -1137,13 +1144,19 @@ class Logging:
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
result=result,
cache_hit=cache_hit,
try:
self.model_call_details["response_cost"] = litellm.completion_cost(
completion_response=complete_streaming_response,
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
except litellm.NotFoundError as e:
verbose_logger.debug(
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@ -1487,6 +1500,19 @@ class Logging:
end_time=end_time,
)
if callable(callback): # custom logger functions
if self.stream:
if "complete_streaming_response" in self.model_call_details:
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
"complete_streaming_response"
],
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback,
)
else:
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
@ -2915,17 +2941,25 @@ def cost_per_token(
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model_with_provider in model_cost_ref:
print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
verbose_logger.debug(
f"Looking up model={model_with_provider} in model_cost_map"
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model_with_provider]["output_cost_per_token"]
* completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif "ft:gpt-3.5-turbo" in model:
print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
# fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
prompt_tokens_cost_usd_dollar = (
model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@ -2936,17 +2970,23 @@ def cost_per_token(
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_llms:
print_verbose(f"Cost Tracking: {model} is an Azure LLM")
verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
model = litellm.azure_llms[model]
verbose_logger.debug(
f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
)
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
)
verbose_logger.debug(
f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
)
completion_tokens_cost_usd_dollar = (
model_cost_ref[model]["output_cost_per_token"] * completion_tokens
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
elif model in litellm.azure_embedding_models:
print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
model = litellm.azure_embedding_models[model]
prompt_tokens_cost_usd_dollar = (
model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
@ -7061,6 +7101,7 @@ class CustomStreamWrapper:
self._hidden_params = {
"model_id": (_model_info.get("id", None))
} # returned as x-litellm-model-id response header in proxy
self.response_id = None
def __iter__(self):
return self
@ -7633,6 +7674,10 @@ class CustomStreamWrapper:
def chunk_creator(self, chunk):
model_response = ModelResponse(stream=True, model=self.model)
if self.response_id is not None:
model_response.id = self.response_id
else:
self.response_id = model_response.id
model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
model_response.choices = [StreamingChoices()]
model_response.choices[0].finish_reason = None
@ -7752,10 +7797,8 @@ class CustomStreamWrapper:
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING")
new_chunk = next(self.completion_stream)
completion_obj["content"] = new_chunk
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
completion_obj["content"] = chunk
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.sent_last_chunk:
@ -7874,7 +7917,7 @@ class CustomStreamWrapper:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
print_verbose(f"model_response: {model_response}")
print_verbose(f"returning model_response: {model_response}")
return model_response
else:
return

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "1.18.10"
version = "1.18.11"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT"
@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.18.10"
version = "1.18.11"
version_files = [
"pyproject.toml:^version"
]

View file

@ -2,15 +2,22 @@
## Tests /key endpoints.
import pytest
import asyncio
import asyncio, time
import aiohttp
from openai import AsyncOpenAI
import sys, os
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
async def generate_key(session, i):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": ["azure-models"],
"models": ["azure-models", "gpt-4"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
}
@ -82,6 +89,35 @@ async def chat_completion(session, key, model="gpt-4"):
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
async def chat_completion_streaming(session, key, model="gpt-4"):
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": f"Hello! {time.time()}"},
]
prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
data = {
"model": model,
"messages": messages,
"stream": True,
}
response = await client.chat.completions.create(**data)
content = ""
async for chunk in response:
content += chunk.choices[0].delta.content or ""
print(f"content: {content}")
completion_tokens = litellm.token_counter(
model="gpt-35-turbo", text=content, count_response_tokens=True
)
return prompt_tokens, completion_tokens
@pytest.mark.asyncio
async def test_key_update():
@ -181,3 +217,49 @@ async def test_key_info():
random_key = key_gen["key"]
status = await get_key_info(session=session, get_key=key, call_key=random_key)
assert status == 403
@pytest.mark.asyncio
async def test_key_info_spend_values():
"""
- create key
- make completion call
- assert cost is expected value
"""
async with aiohttp.ClientSession() as session:
## Test Spend Update ##
# completion
# response = await chat_completion(session=session, key=key)
# prompt_cost, completion_cost = litellm.cost_per_token(
# model="azure/gpt-35-turbo",
# prompt_tokens=response["usage"]["prompt_tokens"],
# completion_tokens=response["usage"]["completion_tokens"],
# )
# response_cost = prompt_cost + completion_cost
# await asyncio.sleep(5) # allow db log to be updated
# key_info = await get_key_info(session=session, get_key=key, call_key=key)
# print(
# f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
# )
# assert response_cost == key_info["info"]["spend"]
## streaming
key_gen = await generate_key(session=session, i=0)
new_key = key_gen["key"]
prompt_tokens, completion_tokens = await chat_completion_streaming(
session=session, key=new_key
)
print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
prompt_cost, completion_cost = litellm.cost_per_token(
model="azure/gpt-35-turbo",
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response_cost = prompt_cost + completion_cost
await asyncio.sleep(5) # allow db log to be updated
key_info = await get_key_info(
session=session, get_key=new_key, call_key=new_key
)
print(
f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
)
assert response_cost == key_info["info"]["spend"]

View file

@ -68,6 +68,7 @@ async def chat_completion(session, key):
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio

View file

@ -6,6 +6,9 @@ from dotenv import load_dotenv
load_dotenv()
import streamlit as st
import base64, os, json, uuid, requests
import pandas as pd
import plotly.express as px
import click
# Replace your_base_url with the actual URL where the proxy auth app is hosted
your_base_url = os.getenv("BASE_URL") # Example base URL
@ -75,7 +78,7 @@ def add_new_model():
and st.session_state.get("proxy_key", None) is None
):
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
model_name = st.text_input(
@ -174,10 +177,70 @@ def list_models():
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def spend_per_key():
import streamlit as st
import requests
# Check if the necessary configuration is available
if (
st.session_state.get("api_url", None) is not None
and st.session_state.get("proxy_key", None) is not None
):
# Make the GET request
try:
complete_url = ""
if isinstance(st.session_state["api_url"], str) and st.session_state[
"api_url"
].endswith("/"):
complete_url = f"{st.session_state['api_url']}/spend/keys"
else:
complete_url = f"{st.session_state['api_url']}/spend/keys"
response = requests.get(
complete_url,
headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
)
# Check if the request was successful
if response.status_code == 200:
spend_per_key = response.json()
# Create DataFrame
spend_df = pd.DataFrame(spend_per_key)
# Display the spend per key as a graph
st.header("Spend ($) per API Key:")
top_10_df = spend_df.nlargest(10, "spend")
fig = px.bar(
top_10_df,
x="token",
y="spend",
title="Top 10 Spend per Key",
height=550, # Adjust the height
width=1200, # Adjust the width)
hover_data=["token", "spend", "user_id", "team_id"],
)
st.plotly_chart(fig)
# Display the spend per key as a table
st.write("Spend per Key - Full Table:")
st.table(spend_df)
else:
st.error(f"Failed to get models. Status code: {response.status_code}")
except Exception as e:
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def spend_per_user():
pass
def create_key():
import streamlit as st
import json, requests, uuid
@ -187,7 +250,7 @@ def create_key():
and st.session_state.get("proxy_key", None) is None
):
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
@ -235,7 +298,7 @@ def update_config():
and st.session_state.get("proxy_key", None) is None
):
st.warning(
"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
st.markdown("#### Alerting")
@ -324,19 +387,25 @@ def update_config():
raise e
def admin_page(is_admin="NOT_GIVEN"):
def admin_page(is_admin="NOT_GIVEN", input_api_url=None, input_proxy_key=None):
# Display the form for the admin to set the proxy URL and allowed email subdomain
st.set_page_config(
layout="wide", # Use "wide" layout for more space
)
st.header("Admin Configuration")
st.session_state.setdefault("is_admin", is_admin)
# Add a navigation sidebar
st.sidebar.title("Navigation")
page = st.sidebar.radio(
"Go to",
(
"Connect to Proxy",
"View Spend Per Key",
"View Spend Per User",
"List Models",
"Update Config",
"Add Models",
"List Models",
"Create Key",
"End-User Auth",
),
@ -344,16 +413,23 @@ def admin_page(is_admin="NOT_GIVEN"):
# Display different pages based on navigation selection
if page == "Connect to Proxy":
# Use text inputs with intermediary variables
if input_api_url is None:
input_api_url = st.text_input(
"Proxy Endpoint",
value=st.session_state.get("api_url", ""),
placeholder="http://0.0.0.0:8000",
)
else:
st.session_state["api_url"] = input_api_url
if input_proxy_key is None:
input_proxy_key = st.text_input(
"Proxy Key",
value=st.session_state.get("proxy_key", ""),
placeholder="sk-...",
)
else:
st.session_state["proxy_key"] = input_proxy_key
# When the "Save" button is clicked, update the session state
if st.button("Save"):
st.session_state["api_url"] = input_api_url
@ -369,6 +445,21 @@ def admin_page(is_admin="NOT_GIVEN"):
list_models()
elif page == "Create Key":
create_key()
elif page == "View Spend Per Key":
spend_per_key()
elif page == "View Spend Per User":
spend_per_user()
admin_page()
# admin_page()
@click.command()
@click.option("--proxy_endpoint", type=str, help="Proxy Endpoint")
@click.option("--proxy_master_key", type=str, help="Proxy Master Key")
def main(proxy_endpoint, proxy_master_key):
admin_page(input_api_url=proxy_endpoint, input_proxy_key=proxy_master_key)
if __name__ == "__main__":
main()