Merge branch 'main' into litellm_reset_key_budget

Commit: 9784d03d65
13 changed files with 492 additions and 78 deletions
@@ -115,6 +115,25 @@ jobs:
             pip install "pytest==7.3.1"
             pip install "pytest-asyncio==0.21.1"
             pip install aiohttp
+            pip install openai
+            python -m pip install --upgrade pip
+            python -m pip install -r .circleci/requirements.txt
+            pip install "pytest==7.3.1"
+            pip install "pytest-asyncio==0.21.1"
+            pip install mypy
+            pip install "google-generativeai>=0.3.2"
+            pip install "google-cloud-aiplatform>=1.38.0"
+            pip install "boto3>=1.28.57"
+            pip install langchain
+            pip install "langfuse>=2.0.0"
+            pip install numpydoc
+            pip install prisma
+            pip install "httpx==0.24.1"
+            pip install "gunicorn==21.2.0"
+            pip install "anyio==3.7.1"
+            pip install "aiodynamo==23.10.1"
+            pip install "asyncio==3.4.3"
+            pip install "PyGithub==1.59.1"
       # Run pytest and generate JUnit XML report
       - run:
           name: Build Docker image
@@ -98,7 +98,7 @@ def list_models():
         st.error(f"An error occurred while requesting models: {e}")
     else:
         st.warning(
-            "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
         )

@@ -151,7 +151,7 @@ def create_key():
         raise e
     else:
         st.warning(
-            "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
         )

@@ -598,9 +598,9 @@ async def track_cost_callback(
                 end_time=end_time,
             )
         else:
-            if (
-                kwargs["stream"] != True
-                or kwargs.get("complete_streaming_response", None) is not None
+            if kwargs["stream"] != True or (
+                kwargs["stream"] == True
+                and kwargs.get("complete_streaming_response") in kwargs
             ):
                 raise Exception(
                     f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
@@ -701,6 +701,7 @@ async def update_database(
             valid_token.spend = new_spend
             user_api_key_cache.set_cache(key=token, value=valid_token)

+        ### UPDATE SPEND LOGS ###
         async def _insert_spend_log_to_db():
             # Helper to generate payload to log
             verbose_proxy_logger.debug("inserting spend log to db")
@@ -1438,6 +1439,28 @@ async def async_data_generator(response, user_api_key_dict):
         yield f"data: {str(e)}\n\n"


+def select_data_generator(response, user_api_key_dict):
+    try:
+        # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
+        if (
+            hasattr(response, "custom_llm_provider")
+            and response.custom_llm_provider == "sagemaker"
+        ):
+            return data_generator(
+                response=response,
+            )
+        else:
+            # default to async_data_generator
+            return async_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
+            )
+    except:
+        # worst case - use async_data_generator
+        return async_data_generator(
+            response=response, user_api_key_dict=user_api_key_dict
+        )
+
+
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
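Note: the new select_data_generator helper keeps streaming synchronous only for SageMaker (boto3 has no async client) and defaults to the async generator everywhere else. Below is a minimal standalone sketch of the same dispatch pattern, with stub generators and a SimpleNamespace standing in for the proxy's real response object; the names here are illustrative, not the proxy's actual wiring.

```python
from types import SimpleNamespace


def data_generator(response):
    # stand-in for the proxy's sync generator
    for chunk in response.chunks:
        yield f"data: {chunk}\n\n"


async def async_data_generator(response, user_api_key_dict):
    # stand-in for the proxy's async generator
    for chunk in response.chunks:
        yield f"data: {chunk}\n\n"


def select_data_generator(response, user_api_key_dict):
    # sagemaker streams are sync-only; everything else stays on the async path
    if getattr(response, "custom_llm_provider", None) == "sagemaker":
        return data_generator(response=response)
    return async_data_generator(response=response, user_api_key_dict=user_api_key_dict)


sagemaker_resp = SimpleNamespace(custom_llm_provider="sagemaker", chunks=["a", "b"])
print(list(select_data_generator(sagemaker_resp, user_api_key_dict={})))
```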
@@ -1679,11 +1702,12 @@ async def completion(
         "stream" in data and data["stream"] == True
     ):  # use generate_responses to stream responses
         custom_headers = {"x-litellm-model-id": model_id}
+        selected_data_generator = select_data_generator(
+            response=response, user_api_key_dict=user_api_key_dict
+        )
+
         return StreamingResponse(
-            async_data_generator(
-                user_api_key_dict=user_api_key_dict,
-                response=response,
-            ),
+            selected_data_generator,
             media_type="text/event-stream",
             headers=custom_headers,
         )
@@ -1841,11 +1865,12 @@ async def chat_completion(
         "stream" in data and data["stream"] == True
     ):  # use generate_responses to stream responses
         custom_headers = {"x-litellm-model-id": model_id}
+        selected_data_generator = select_data_generator(
+            response=response, user_api_key_dict=user_api_key_dict
+        )
+
         return StreamingResponse(
-            async_data_generator(
-                user_api_key_dict=user_api_key_dict,
-                response=response,
-            ),
+            selected_data_generator,
             media_type="text/event-stream",
             headers=custom_headers,
         )
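Note: both streaming endpoints now hand StreamingResponse whatever generator select_data_generator picks. A hedged client-side sketch of consuming that stream through the proxy with the standard OpenAI SDK; the base URL and key are placeholders matching the examples used elsewhere in this diff.

```python
import openai

# placeholders: point the client at your LiteLLM proxy and a valid virtual key
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:8000")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```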
@@ -2305,6 +2330,94 @@ async def info_key_fn(
         )


+@router.get(
+    "/spend/keys",
+    tags=["Budget & Spend Tracking"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def spend_key_fn():
+    """
+    View all keys created, ordered by spend
+
+    Example Request:
+    ```
+    curl -X GET "http://0.0.0.0:8000/spend/keys" \
+    -H "Authorization: Bearer sk-1234"
+    ```
+    """
+    global prisma_client
+    try:
+        if prisma_client is None:
+            raise Exception(
+                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+            )
+
+        key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
+
+        return key_info
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": str(e)},
+        )
+
+
+@router.get(
+    "/spend/logs",
+    tags=["Budget & Spend Tracking"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def view_spend_logs(
+    request_id: Optional[str] = fastapi.Query(
+        default=None,
+        description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests",
+    ),
+):
+    """
+    View all spend logs, if request_id is provided, only logs for that request_id will be returned
+
+    Example Request for all logs
+    ```
+    curl -X GET "http://0.0.0.0:8000/spend/logs" \
+    -H "Authorization: Bearer sk-1234"
+    ```
+
+    Example Request for specific request_id
+    ```
+    curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \
+    -H "Authorization: Bearer sk-1234"
+    ```
+    """
+    global prisma_client
+    try:
+        if prisma_client is None:
+            raise Exception(
+                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+            )
+        spend_logs = []
+        if request_id is not None:
+            spend_log = await prisma_client.get_data(
+                table_name="spend",
+                query_type="find_unique",
+                request_id=request_id,
+            )
+            return [spend_log]
+        else:
+            spend_logs = await prisma_client.get_data(
+                table_name="spend", query_type="find_all"
+            )
+            return spend_logs
+
+        return None
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": str(e)},
+        )
+
+
 #### USER MANAGEMENT ####
 @router.post(
     "/user/new",
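Note: a small sketch of calling the two new endpoints from Python with requests, equivalent to the curl examples in the docstrings above; the proxy URL, key, and request_id are placeholders.

```python
import requests

BASE_URL = "http://0.0.0.0:8000"               # placeholder proxy URL
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder virtual/master key

# all keys, ordered by spend
keys = requests.get(f"{BASE_URL}/spend/keys", headers=HEADERS).json()

# all spend logs, or a single request's log via the request_id query param
logs = requests.get(f"{BASE_URL}/spend/logs", headers=HEADERS).json()
one_log = requests.get(
    f"{BASE_URL}/spend/logs", headers=HEADERS, params={"request_id": "chatcmpl-..."}
).json()

print(len(keys), len(logs))
```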
@@ -4,7 +4,7 @@ const openai = require('openai');
 process.env.DEBUG=false;
 async function runOpenAI() {
   const client = new openai.OpenAI({
-    apiKey: 'sk-yPX56TDqBpr23W7ruFG3Yg',
+    apiKey: 'sk-JkKeNi6WpWDngBsghJ6B9g',
     baseURL: 'http://0.0.0.0:8000'
   });

@@ -361,7 +361,8 @@ class PrismaClient:
         self,
         token: Optional[str] = None,
         user_id: Optional[str] = None,
-        table_name: Optional[Literal["user", "key", "config"]] = None,
+        request_id: Optional[str] = None,
+        table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
         query_type: Literal["find_unique", "find_all"] = "find_unique",
         expires: Optional[datetime] = None,
         reset_at: Optional[datetime] = None,
@@ -411,6 +412,10 @@ class PrismaClient:
                 for r in response:
                     if isinstance(r.expires, datetime):
                         r.expires = r.expires.isoformat()
+            elif query_type == "find_all":
+                response = await self.db.litellm_verificationtoken.find_many(
+                    order={"spend": "desc"},
+                )
             print_verbose(f"PrismaClient: response={response}")
             if response is not None:
                 return response
@@ -427,6 +432,23 @@ class PrismaClient:
                     }
                 )
                 return response
+            elif table_name == "spend":
+                verbose_proxy_logger.debug(
+                    f"PrismaClient: get_data: table_name == 'spend'"
+                )
+                if request_id is not None:
+                    response = await self.db.litellm_spendlogs.find_unique(  # type: ignore
+                        where={
+                            "request_id": request_id,
+                        }
+                    )
+                    return response
+                else:
+                    response = await self.db.litellm_spendlogs.find_many(  # type: ignore
+                        order={"startTime": "desc"},
+                    )
+                    return response
+
         except Exception as e:
             print_verbose(f"LiteLLM Prisma Client Exception: {e}")
             import traceback
@@ -549,21 +571,20 @@ class PrismaClient:
             db_data = self.jsonify_object(data=data)
             if token is not None:
                 print_verbose(f"token: {token}")
-                if query_type == "update":
-                    # check if plain text or hash
-                    if token.startswith("sk-"):
-                        token = self.hash_token(token=token)
-                    db_data["token"] = token
-                    response = await self.db.litellm_verificationtoken.update(
-                        where={"token": token},  # type: ignore
-                        data={**db_data},  # type: ignore
-                    )
-                    print_verbose(
-                        "\033[91m"
-                        + f"DB Token Table update succeeded {response}"
-                        + "\033[0m"
-                    )
-                    return {"token": token, "data": db_data}
+                # check if plain text or hash
+                if token.startswith("sk-"):
+                    token = self.hash_token(token=token)
+                db_data["token"] = token
+                response = await self.db.litellm_verificationtoken.update(
+                    where={"token": token},  # type: ignore
+                    data={**db_data},  # type: ignore
+                )
+                verbose_proxy_logger.debug(
+                    "\033[91m"
+                    + f"DB Token Table update succeeded {response}"
+                    + "\033[0m"
+                )
+                return {"token": token, "data": db_data}
             elif user_id is not None:
                 """
                 If data['spend'] + data['user'], update the user table with spend info as well
@@ -885,10 +906,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
     usage = response_obj["usage"]
     id = response_obj.get("id", str(uuid.uuid4()))
     api_key = metadata.get("user_api_key", "")
-    if api_key is not None and type(api_key) == str:
+    if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
         # hash the api_key
         api_key = hash_token(api_key)

+    if "headers" in metadata and "authorization" in metadata["headers"]:
+        metadata["headers"].pop(
+            "authorization"
+        )  # do not store the original `sk-..` api key in the db
+
     payload = {
         "request_id": id,
         "call_type": call_type,
@@ -1408,9 +1408,15 @@ def test_completion_sagemaker_stream():
         )

         complete_streaming_response = ""
-        for chunk in response:
+        first_chunk_id, chunk_id = None, None
+        for i, chunk in enumerate(response):
             print(chunk)
+            chunk_id = chunk.id
+            print(chunk_id)
+            if i == 0:
+                first_chunk_id = chunk_id
+            else:
+                assert chunk_id == first_chunk_id
             complete_streaming_response += chunk.choices[0].delta.content or ""
         # Add any assertions here to check the response
         # print(response)
@@ -960,3 +960,29 @@ def test_router_anthropic_key_dynamic():
     messages = [{"role": "user", "content": "Hey, how's it going?"}]
     router.completion(model="anthropic-claude", messages=messages)
     os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key
+
+
+def test_router_timeout():
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "api_key": "os.environ/OPENAI_API_KEY",
+            },
+        }
+    ]
+    router = Router(model_list=model_list)
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    start_time = time.time()
+    try:
+        res = router.completion(
+            model="gpt-3.5-turbo", messages=messages, timeout=0.0001
+        )
+        print(res)
+        pytest.fail("this should have timed out")
+    except litellm.exceptions.Timeout as e:
+        print("got timeout exception")
+        print(e)
+        print(vars(e))
+        pass
@@ -733,8 +733,15 @@ def test_completion_bedrock_claude_stream():
         complete_response = ""
         has_finish_reason = False
         # Add any assertions here to check the response
+        first_chunk_id = None
         for idx, chunk in enumerate(response):
             # print
+            if idx == 0:
+                first_chunk_id = chunk.id
+            else:
+                assert (
+                    chunk.id == first_chunk_id
+                ), f"chunk ids do not match: {chunk.id} != first chunk id{first_chunk_id}"
             chunk, finished = streaming_format_tests(idx, chunk)
             has_finish_reason = finished
             complete_response += chunk
@@ -1067,7 +1067,6 @@ class Logging:
             ## if model in model cost map - log the response cost
             ## else set cost to None
             verbose_logger.debug(f"Model={self.model}; result={result}")
-            verbose_logger.debug(f"self.stream: {self.stream}")
             if (
                 result is not None
                 and (
@@ -1109,6 +1108,12 @@ class Logging:
         self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
     ):
         verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
+        start_time, end_time, result = self._success_handler_helper_fn(
+            start_time=start_time,
+            end_time=end_time,
+            result=result,
+            cache_hit=cache_hit,
+        )
         # print(f"original response in success handler: {self.model_call_details['original_response']}")
         try:
             verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
@@ -1124,6 +1129,8 @@ class Logging:
                 complete_streaming_response = litellm.stream_chunk_builder(
                     self.sync_streaming_chunks,
                     messages=self.model_call_details.get("messages", None),
+                    start_time=start_time,
+                    end_time=end_time,
                 )
             except:
                 complete_streaming_response = None
@@ -1137,13 +1144,19 @@ class Logging:
                     self.model_call_details[
                         "complete_streaming_response"
                     ] = complete_streaming_response
+                    try:
+                        self.model_call_details["response_cost"] = litellm.completion_cost(
+                            completion_response=complete_streaming_response,
+                        )
+                        verbose_logger.debug(
+                            f"Model={self.model}; cost={self.model_call_details['response_cost']}"
+                        )
+                    except litellm.NotFoundError as e:
+                        verbose_logger.debug(
+                            f"Model={self.model} not found in completion cost map."
+                        )
+                        self.model_call_details["response_cost"] = None

-        start_time, end_time, result = self._success_handler_helper_fn(
-            start_time=start_time,
-            end_time=end_time,
-            result=result,
-            cache_hit=cache_hit,
-        )
         for callback in litellm.success_callback:
             try:
                 if callback == "lite_debugger":
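Note: the success handler now prices a streamed call from the reassembled response instead of skipping it. A rough sketch of the same two public helpers used above, litellm.stream_chunk_builder and litellm.completion_cost, assuming an OpenAI key is configured in the environment:

```python
import litellm

messages = [{"role": "user", "content": "Hey, how's it going?"}]
chunks = list(
    litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)
)

# rebuild the full response from the streamed chunks, then price it
complete_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(litellm.completion_cost(completion_response=complete_response))
```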
@@ -1487,14 +1500,27 @@ class Logging:
                             end_time=end_time,
                         )
                     if callable(callback):  # custom logger functions
-                        await customLogger.async_log_event(
-                            kwargs=self.model_call_details,
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            print_verbose=print_verbose,
-                            callback_func=callback,
-                        )
+                        if self.stream:
+                            if "complete_streaming_response" in self.model_call_details:
+                                await customLogger.async_log_event(
+                                    kwargs=self.model_call_details,
+                                    response_obj=self.model_call_details[
+                                        "complete_streaming_response"
+                                    ],
+                                    start_time=start_time,
+                                    end_time=end_time,
+                                    print_verbose=print_verbose,
+                                    callback_func=callback,
+                                )
+                        else:
+                            await customLogger.async_log_event(
+                                kwargs=self.model_call_details,
+                                response_obj=result,
+                                start_time=start_time,
+                                end_time=end_time,
+                                print_verbose=print_verbose,
+                                callback_func=callback,
+                            )
                     if callback == "dynamodb":
                         global dynamoLogger
                         if dynamoLogger is None:
@@ -2915,17 +2941,25 @@ def cost_per_token(
         )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model_with_provider in model_cost_ref:
-        print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
+        verbose_logger.debug(
+            f"Looking up model={model_with_provider} in model_cost_map"
+        )
+        verbose_logger.debug(
+            f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
+        )
         prompt_tokens_cost_usd_dollar = (
             model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
         )
+        verbose_logger.debug(
+            f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
+        )
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model_with_provider]["output_cost_per_token"]
             * completion_tokens
         )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif "ft:gpt-3.5-turbo" in model:
-        print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
+        verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
         # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
         prompt_tokens_cost_usd_dollar = (
             model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@@ -2936,17 +2970,23 @@ def cost_per_token(
         )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model in litellm.azure_llms:
-        print_verbose(f"Cost Tracking: {model} is an Azure LLM")
+        verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
         model = litellm.azure_llms[model]
+        verbose_logger.debug(
+            f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
+        )
         prompt_tokens_cost_usd_dollar = (
             model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
         )
+        verbose_logger.debug(
+            f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
+        )
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
         return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     elif model in litellm.azure_embedding_models:
-        print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
+        verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
         model = litellm.azure_embedding_models[model]
         prompt_tokens_cost_usd_dollar = (
             model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
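Note: these hunks only swap print_verbose for verbose_logger.debug; the pricing math is unchanged. For reference, this is the public entry point exercised by the tests later in this diff, called here with made-up token counts:

```python
import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="azure/gpt-35-turbo",
    prompt_tokens=100,       # example values, not from a real call
    completion_tokens=50,
)
print(prompt_cost + completion_cost)  # total response cost in USD
```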
@@ -7061,6 +7101,7 @@ class CustomStreamWrapper:
         self._hidden_params = {
             "model_id": (_model_info.get("id", None))
         }  # returned as x-litellm-model-id response header in proxy
+        self.response_id = None

     def __iter__(self):
         return self
@@ -7633,6 +7674,10 @@ class CustomStreamWrapper:

     def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
+        if self.response_id is not None:
+            model_response.id = self.response_id
+        else:
+            self.response_id = model_response.id
         model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
         model_response.choices = [StreamingChoices()]
         model_response.choices[0].finish_reason = None
@@ -7752,10 +7797,8 @@ class CustomStreamWrapper:
                     ]
                     self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING")
-                new_chunk = next(self.completion_stream)
-
-                completion_obj["content"] = new_chunk
+                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+                completion_obj["content"] = chunk
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
@@ -7874,7 +7917,7 @@ class CustomStreamWrapper:
                     completion_obj["role"] = "assistant"
                     self.sent_first_chunk = True
                 model_response.choices[0].delta = Delta(**completion_obj)
-                print_verbose(f"model_response: {model_response}")
+                print_verbose(f"returning model_response: {model_response}")
                 return model_response
             else:
                 return
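Note: with response_id carried across chunk_creator calls, every chunk of one stream now reports the same id, which is what the updated SageMaker and Bedrock streaming tests assert. A hedged sketch of the same check against any streamed completion, assuming an OPENAI_API_KEY in the environment:

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)

first_chunk_id = None
for idx, chunk in enumerate(response):
    if idx == 0:
        first_chunk_id = chunk.id
    else:
        # every chunk of a single stream should carry the same response id
        assert chunk.id == first_chunk_id
```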
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.18.10"
+version = "1.18.11"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"

 [tool.commitizen]
-version = "1.18.10"
+version = "1.18.11"
 version_files = [
     "pyproject.toml:^version"
 ]
@@ -2,15 +2,22 @@
 ## Tests /key endpoints.

 import pytest
-import asyncio
+import asyncio, time
 import aiohttp
+from openai import AsyncOpenAI
+import sys, os
+
+sys.path.insert(
+    0, os.path.abspath("../")
+)  # Adds the parent directory to the system path
+import litellm


 async def generate_key(session, i):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
     data = {
-        "models": ["azure-models"],
+        "models": ["azure-models", "gpt-4"],
         "aliases": {"mistral-7b": "gpt-3.5-turbo"},
         "duration": None,
     }
@@ -82,6 +89,35 @@ async def chat_completion(session, key, model="gpt-4"):
         if status != 200:
             raise Exception(f"Request did not return a 200 status code: {status}")

+        return await response.json()
+
+
+async def chat_completion_streaming(session, key, model="gpt-4"):
+    client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": f"Hello! {time.time()}"},
+    ]
+    prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
+    data = {
+        "model": model,
+        "messages": messages,
+        "stream": True,
+    }
+    response = await client.chat.completions.create(**data)
+
+    content = ""
+    async for chunk in response:
+        content += chunk.choices[0].delta.content or ""
+
+    print(f"content: {content}")
+
+    completion_tokens = litellm.token_counter(
+        model="gpt-35-turbo", text=content, count_response_tokens=True
+    )
+
+    return prompt_tokens, completion_tokens
+
+
 @pytest.mark.asyncio
 async def test_key_update():
@@ -181,3 +217,49 @@ async def test_key_info():
         random_key = key_gen["key"]
         status = await get_key_info(session=session, get_key=key, call_key=random_key)
         assert status == 403
+
+
+@pytest.mark.asyncio
+async def test_key_info_spend_values():
+    """
+    - create key
+    - make completion call
+    - assert cost is expected value
+    """
+    async with aiohttp.ClientSession() as session:
+        ## Test Spend Update ##
+        # completion
+        # response = await chat_completion(session=session, key=key)
+        # prompt_cost, completion_cost = litellm.cost_per_token(
+        #     model="azure/gpt-35-turbo",
+        #     prompt_tokens=response["usage"]["prompt_tokens"],
+        #     completion_tokens=response["usage"]["completion_tokens"],
+        # )
+        # response_cost = prompt_cost + completion_cost
+        # await asyncio.sleep(5)  # allow db log to be updated
+        # key_info = await get_key_info(session=session, get_key=key, call_key=key)
+        # print(
+        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+        # )
+        # assert response_cost == key_info["info"]["spend"]
+        ## streaming
+        key_gen = await generate_key(session=session, i=0)
+        new_key = key_gen["key"]
+        prompt_tokens, completion_tokens = await chat_completion_streaming(
+            session=session, key=new_key
+        )
+        print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+        prompt_cost, completion_cost = litellm.cost_per_token(
+            model="azure/gpt-35-turbo",
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+        response_cost = prompt_cost + completion_cost
+        await asyncio.sleep(5)  # allow db log to be updated
+        key_info = await get_key_info(
+            session=session, get_key=new_key, call_key=new_key
+        )
+        print(
+            f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+        )
+        assert response_cost == key_info["info"]["spend"]
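Note: the streaming test above reconstructs usage with litellm.token_counter, since a streamed response does not return a usage block. A short sketch of that counting step on its own, with example strings:

```python
import litellm

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello!"},
]
prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)

# count tokens in the streamed text that was accumulated client-side
completion_tokens = litellm.token_counter(
    model="gpt-35-turbo", text="Hi there!", count_response_tokens=True
)
print(prompt_tokens, completion_tokens)
```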
@@ -68,6 +68,7 @@ async def chat_completion(session, key):

         if status != 200:
             raise Exception(f"Request did not return a 200 status code: {status}")
+        return await response.json()


 @pytest.mark.asyncio
ui/admin.py (125 changed lines)
@@ -6,6 +6,9 @@ from dotenv import load_dotenv
 load_dotenv()
 import streamlit as st
 import base64, os, json, uuid, requests
+import pandas as pd
+import plotly.express as px
+import click

 # Replace your_base_url with the actual URL where the proxy auth app is hosted
 your_base_url = os.getenv("BASE_URL")  # Example base URL
@@ -75,7 +78,7 @@ def add_new_model():
         and st.session_state.get("proxy_key", None) is None
     ):
         st.warning(
-            "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
         )

     model_name = st.text_input(
@@ -174,10 +177,70 @@ def list_models():
         st.error(f"An error occurred while requesting models: {e}")
     else:
         st.warning(
-            "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
         )


+def spend_per_key():
+    import streamlit as st
+    import requests
+
+    # Check if the necessary configuration is available
+    if (
+        st.session_state.get("api_url", None) is not None
+        and st.session_state.get("proxy_key", None) is not None
+    ):
+        # Make the GET request
+        try:
+            complete_url = ""
+            if isinstance(st.session_state["api_url"], str) and st.session_state[
+                "api_url"
+            ].endswith("/"):
+                complete_url = f"{st.session_state['api_url']}/spend/keys"
+            else:
+                complete_url = f"{st.session_state['api_url']}/spend/keys"
+            response = requests.get(
+                complete_url,
+                headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
+            )
+            # Check if the request was successful
+            if response.status_code == 200:
+                spend_per_key = response.json()
+                # Create DataFrame
+                spend_df = pd.DataFrame(spend_per_key)
+
+                # Display the spend per key as a graph
+                st.header("Spend ($) per API Key:")
+                top_10_df = spend_df.nlargest(10, "spend")
+                fig = px.bar(
+                    top_10_df,
+                    x="token",
+                    y="spend",
+                    title="Top 10 Spend per Key",
+                    height=550,  # Adjust the height
+                    width=1200,  # Adjust the width)
+                    hover_data=["token", "spend", "user_id", "team_id"],
+                )
+                st.plotly_chart(fig)
+
+                # Display the spend per key as a table
+                st.write("Spend per Key - Full Table:")
+                st.table(spend_df)
+
+            else:
+                st.error(f"Failed to get models. Status code: {response.status_code}")
+        except Exception as e:
+            st.error(f"An error occurred while requesting models: {e}")
+    else:
+        st.warning(
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
+        )
+
+
+def spend_per_user():
+    pass
+
+
 def create_key():
     import streamlit as st
     import json, requests, uuid
@@ -187,7 +250,7 @@ def create_key():
         and st.session_state.get("proxy_key", None) is None
     ):
         st.warning(
-            "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
        )

     duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
@@ -235,7 +298,7 @@ def update_config():
         and st.session_state.get("proxy_key", None) is None
     ):
         st.warning(
-            "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+            f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
         )

     st.markdown("#### Alerting")
@@ -324,19 +387,25 @@ def update_config():
         raise e


-def admin_page(is_admin="NOT_GIVEN"):
+def admin_page(is_admin="NOT_GIVEN", input_api_url=None, input_proxy_key=None):
     # Display the form for the admin to set the proxy URL and allowed email subdomain
+    st.set_page_config(
+        layout="wide",  # Use "wide" layout for more space
+    )
     st.header("Admin Configuration")
     st.session_state.setdefault("is_admin", is_admin)
     # Add a navigation sidebar
     st.sidebar.title("Navigation")

     page = st.sidebar.radio(
         "Go to",
         (
             "Connect to Proxy",
+            "View Spend Per Key",
+            "View Spend Per User",
+            "List Models",
             "Update Config",
             "Add Models",
-            "List Models",
             "Create Key",
             "End-User Auth",
         ),
@@ -344,16 +413,23 @@ def admin_page(is_admin="NOT_GIVEN"):
     # Display different pages based on navigation selection
     if page == "Connect to Proxy":
         # Use text inputs with intermediary variables
-        input_api_url = st.text_input(
-            "Proxy Endpoint",
-            value=st.session_state.get("api_url", ""),
-            placeholder="http://0.0.0.0:8000",
-        )
-        input_proxy_key = st.text_input(
-            "Proxy Key",
-            value=st.session_state.get("proxy_key", ""),
-            placeholder="sk-...",
-        )
+        if input_api_url is None:
+            input_api_url = st.text_input(
+                "Proxy Endpoint",
+                value=st.session_state.get("api_url", ""),
+                placeholder="http://0.0.0.0:8000",
+            )
+        else:
+            st.session_state["api_url"] = input_api_url
+
+        if input_proxy_key is None:
+            input_proxy_key = st.text_input(
+                "Proxy Key",
+                value=st.session_state.get("proxy_key", ""),
+                placeholder="sk-...",
+            )
+        else:
+            st.session_state["proxy_key"] = input_proxy_key
         # When the "Save" button is clicked, update the session state
         if st.button("Save"):
             st.session_state["api_url"] = input_api_url
@@ -369,6 +445,21 @@ def admin_page(is_admin="NOT_GIVEN"):
         list_models()
     elif page == "Create Key":
         create_key()
+    elif page == "View Spend Per Key":
+        spend_per_key()
+    elif page == "View Spend Per User":
+        spend_per_user()


-admin_page()
+# admin_page()
+
+
+@click.command()
+@click.option("--proxy_endpoint", type=str, help="Proxy Endpoint")
+@click.option("--proxy_master_key", type=str, help="Proxy Master Key")
+def main(proxy_endpoint, proxy_master_key):
+    admin_page(input_api_url=proxy_endpoint, input_proxy_key=proxy_master_key)
+
+
+if __name__ == "__main__":
+    main()