Merge branch 'main' into litellm_reset_key_budget

Krish Dholakia 2024-01-23 18:10:32 -08:00 committed by GitHub
commit 9784d03d65
13 changed files with 492 additions and 78 deletions


@@ -598,9 +598,9 @@ async def track_cost_callback(
                 end_time=end_time,
             )
         else:
-            if (
-                kwargs["stream"] != True
-                or kwargs.get("complete_streaming_response", None) is not None
+            if kwargs["stream"] != True or (
+                kwargs["stream"] == True
+                and kwargs.get("complete_streaming_response") in kwargs
             ):
                 raise Exception(
                     f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
@@ -701,6 +701,7 @@ async def update_database(
             valid_token.spend = new_spend
             user_api_key_cache.set_cache(key=token, value=valid_token)
 
+    ### UPDATE SPEND LOGS ###
     async def _insert_spend_log_to_db():
         # Helper to generate payload to log
         verbose_proxy_logger.debug("inserting spend log to db")
@@ -1438,6 +1439,28 @@ async def async_data_generator(response, user_api_key_dict):
         yield f"data: {str(e)}\n\n"
 
 
+def select_data_generator(response, user_api_key_dict):
+    try:
+        # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
+        if (
+            hasattr(response, "custom_llm_provider")
+            and response.custom_llm_provider == "sagemaker"
+        ):
+            return data_generator(
+                response=response,
+            )
+        else:
+            # default to async_data_generator
+            return async_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
+            )
+    except:
+        # worst case - use async_data_generator
+        return async_data_generator(
+            response=response, user_api_key_dict=user_api_key_dict
+        )
+
+
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
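For orientation (not part of the diff): the helper above works because Starlette's `StreamingResponse` accepts both plain and async iterators, so the proxy can hand it whichever generator fits the provider. A minimal, self-contained sketch of that pattern, with simplified stand-ins for the two generators (the proxy's real `data_generator`/`async_data_generator` do more than this):
```
import json
from fastapi.responses import StreamingResponse

def sync_sse_generator(chunks):
    # stand-in for data_generator: blocking SDKs (e.g. boto3/SageMaker) stream synchronously
    for chunk in chunks:
        yield f"data: {json.dumps(chunk)}\n\n"

async def async_sse_generator(chunks):
    # stand-in for async_data_generator: default path for async-capable providers
    for chunk in chunks:
        yield f"data: {json.dumps(chunk)}\n\n"

def make_streaming_response(chunks, use_sync: bool) -> StreamingResponse:
    # StreamingResponse handles either iterator type, which is what lets a single
    # select_data_generator call site return one or the other
    generator = sync_sse_generator(chunks) if use_sync else async_sse_generator(chunks)
    return StreamingResponse(generator, media_type="text/event-stream")
```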
@@ -1679,11 +1702,12 @@ async def completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = {"x-litellm-model-id": model_id}
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
+            )
             return StreamingResponse(
-                async_data_generator(
-                    user_api_key_dict=user_api_key_dict,
-                    response=response,
-                ),
+                selected_data_generator,
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
@@ -1841,11 +1865,12 @@ async def chat_completion(
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
             custom_headers = {"x-litellm-model-id": model_id}
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
+            )
             return StreamingResponse(
-                async_data_generator(
-                    user_api_key_dict=user_api_key_dict,
-                    response=response,
-                ),
+                selected_data_generator,
                 media_type="text/event-stream",
                 headers=custom_headers,
             )
@@ -2305,6 +2330,94 @@ async def info_key_fn(
         )
 
 
+@router.get(
+    "/spend/keys",
+    tags=["Budget & Spend Tracking"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def spend_key_fn():
+    """
+    View all keys created, ordered by spend
+
+    Example Request:
+    ```
+    curl -X GET "http://0.0.0.0:8000/spend/keys" \
+        -H "Authorization: Bearer sk-1234"
+    ```
+    """
+    global prisma_client
+    try:
+        if prisma_client is None:
+            raise Exception(
+                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+            )
+
+        key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
+
+        return key_info
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": str(e)},
+        )
+
+
+@router.get(
+    "/spend/logs",
+    tags=["Budget & Spend Tracking"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def view_spend_logs(
+    request_id: Optional[str] = fastapi.Query(
+        default=None,
+        description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests",
+    ),
+):
+    """
+    View all spend logs, if request_id is provided, only logs for that request_id will be returned
+
+    Example Request for all logs
+    ```
+    curl -X GET "http://0.0.0.0:8000/spend/logs" \
+        -H "Authorization: Bearer sk-1234"
+    ```
+
+    Example Request for specific request_id
+    ```
+    curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \
+        -H "Authorization: Bearer sk-1234"
+    ```
+    """
+    global prisma_client
+    try:
+        if prisma_client is None:
+            raise Exception(
+                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+            )
+
+        spend_logs = []
+        if request_id is not None:
+            spend_log = await prisma_client.get_data(
+                table_name="spend",
+                query_type="find_unique",
+                request_id=request_id,
+            )
+            return [spend_log]
+        else:
+            spend_logs = await prisma_client.get_data(
+                table_name="spend", query_type="find_all"
+            )
+            return spend_logs
+
+        return None
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"error": str(e)},
+        )
+
+
 #### USER MANAGEMENT ####
 @router.post(
     "/user/new",