fix(proxy_server.py): update db with master key if set, and fix tracking cost for azure models

Krrish Dholakia 2023-12-02 15:57:53 -08:00
parent eb636d9429
commit 722c325503
3 changed files with 86 additions and 34 deletions

litellm/proxy/proxy_server.py

@@ -198,6 +198,17 @@ class ModelParams(BaseModel):
     litellm_params: dict
     model_info: Optional[dict]
 
+class GenerateKeyRequest(BaseModel):
+    duration: str = "1h"
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: int = 0
+
+class GenerateKeyResponse(BaseModel):
+    key: str
+    expires: str
+
 user_api_base = None
 user_model = None
 user_debug = False
@@ -300,7 +311,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
-        print(f"An exception occurred - {e}")
+        print(f"An exception occurred - {traceback.format_exc()}")
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail={"error": "invalid user key"},
@@ -378,8 +389,6 @@ def track_cost_callback(
     end_time = None,  # start/end time for completion
 ):
     try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
         # check if it has collected an entire stream response
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
@@ -393,18 +402,30 @@ def track_cost_callback(
             )
             print("streaming response_cost", response_cost)
         # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                input_text = kwargs.get("messages", "")
-                if isinstance(input_text, list):
-                    input_text = "".join(m["content"] for m in input_text)
-                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
-                print("regular response_cost", response_cost)
-        print(f"metadata in kwargs: {kwargs}")
+        elif kwargs["stream"] is False:  # regular response
+            input_text = kwargs.get("messages", "")
+            if isinstance(input_text, list):
+                input_text = "".join(m["content"] for m in input_text)
+            print(f"received completion response: {completion_response}")
+            response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+            print("regular response_cost", response_cost)
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
         if user_api_key:
-            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+            # asyncio.run(update_prisma_database(user_api_key, response_cost))
+            # Create new event loop for async function execution in the new thread
+            new_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(new_loop)
+            try:
+                # Run the async function using the newly created event loop
+                new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
+            except Exception as e:
+                print(f"error in tracking cost callback - {str(e)}")
+            finally:
+                # Close the event loop after the task is done
+                new_loop.close()
+                # Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
+                asyncio.set_event_loop(None)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")
@@ -419,6 +440,8 @@ async def update_prisma_database(token, response_cost):
             }
         )
         print(f"existing spend: {existing_spend}")
+        if existing_spend is None:
+            existing_spend = 0
         # Calculate the new cost by adding the existing cost and response_cost
         new_spend = existing_spend.spend + response_cost
@@ -546,8 +569,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         run_ollama_serve()
     return router, model_list, general_settings
 
-async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict, spend: float):
-    token = f"sk-{secrets.token_urlsafe(16)}"
+async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str] = None):
+    if token is None:
+        token = f"sk-{secrets.token_urlsafe(16)}"
     def _duration_in_seconds(duration: str):
         match = re.match(r"(\d+)([smhd]?)", duration)
         if not match:
@@ -566,9 +590,13 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             return value * 86400
         else:
             raise ValueError("Unsupported duration unit")
-    duration = _duration_in_seconds(duration=duration_str)
-    expires = datetime.utcnow() + timedelta(seconds=duration)
+    if duration_str is None:  # allow tokens that never expire
+        expires = None
+    else:
+        duration = _duration_in_seconds(duration=duration_str)
+        expires = datetime.utcnow() + timedelta(seconds=duration)
 
     aliases_json = json.dumps(aliases)
     config_json = json.dumps(config)
     try:
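Together, these two changes let duration_str=None produce expires=None, i.e. a key that never expires (the schema change further down makes the column nullable to match), while any other value still goes through _duration_in_seconds. A standalone mirror of that helper, assuming the elided "s"/"m"/"h" branches parallel the day branch shown in the hunk above:

import re

def duration_in_seconds(duration: str) -> int:
    # mirrors _duration_in_seconds above; the regex accepts 30s/30m/30h/30d
    match = re.match(r"(\d+)([smhd]?)", duration)
    if not match:
        raise ValueError("Invalid duration format")
    value, unit = int(match.group(1)), match.group(2)
    multiplier = {"s": 1, "m": 60, "h": 3600, "d": 86400}.get(unit)
    if multiplier is None:
        raise ValueError("Unsupported duration unit")
    return value * multiplier

assert duration_in_seconds("30s") == 30
assert duration_in_seconds("30m") == 1800
assert duration_in_seconds("30d") == 2592000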
@@ -582,9 +610,14 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             "config": config_json,
             "spend": spend
         }
-        print(f"verification_token_data: {verification_token_data}")
-        new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
-            {**verification_token_data} # type: ignore
+        new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore
+            where={
+                'token': token,
+            },
+            data={
+                "create": {**verification_token_data}, # type: ignore
+                "update": {} # don't do anything if it already exists
+            }
         )
     except Exception as e:
         traceback.print_exc()
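Switching from create to upsert with an empty "update" payload makes key insertion idempotent: writing a token that already exists leaves the stored row untouched, which is what lets the startup hook below re-register the master key on every boot. A plain-dict analogue of that semantics:

tokens: dict = {}

def upsert_token(token: str, create_data: dict) -> dict:
    # analogue of upsert(where={"token": token}, data={"create": ..., "update": {}})
    if token not in tokens:
        tokens[token] = create_data   # "create" branch: first insert wins
    # "update": {} -> an existing row is deliberately left unchanged
    return tokens[token]

upsert_token("sk-master", {"spend": 0, "expires": None})
upsert_token("sk-master", {"spend": 999, "expires": None})  # restart: no-op
assert tokens["sk-master"]["spend"] == 0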
@@ -744,12 +777,16 @@ def litellm_completion(*args, **kwargs):
 
 @app.on_event("startup")
 async def startup_event():
-    global prisma_client
+    global prisma_client, master_key
     import json
     worker_config = json.loads(os.getenv("WORKER_CONFIG"))
     initialize(**worker_config)
     if prisma_client:
         await prisma_client.connect()
+    if prisma_client is not None and master_key is not None:
+        # add master key to db
+        await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
 
 @app.on_event("shutdown")
 async def shutdown_event():
@@ -940,25 +977,39 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
 
 #### KEY MANAGEMENT ####
 
-@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
-async def generate_key_fn(request: Request):
+@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
+async def generate_key_fn(request: Request, data: GenerateKeyRequest):
+    """
+    Generate an API key based on the provided data.
+
+    Parameters:
+    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
+    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
+    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
+    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
+    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+
+    Returns:
+    - key: The generated api key
+    - expires: Datetime object for when key expires.
+    """
-    data = await request.json()
-    duration_str = data.get("duration", "1h")  # Default to 1 hour if duration is not provided
-    models = data.get("models", [])  # Default to an empty list (meaning allow token to call all models)
-    aliases = data.get("aliases", {})  # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
-    config = data.get("config", {})
-    spend = data.get("spend", 0)
+    duration_str = data.duration  # Default to 1 hour if duration is not provided
+    models = data.models  # Default to an empty list (meaning allow token to call all models)
+    aliases = data.aliases  # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
+    config = data.config
+    spend = data.spend
     if isinstance(models, list):
         response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
-        return {"key": response["token"], "expires": response["expires"]}
+        return GenerateKeyResponse(key=response["token"], expires=response["expires"])
     else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail={"error": "models param must be a list"},
         )
 
-@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
+@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def delete_key_fn(request: Request):
     try:
         data = await request.json()
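With the request body now validated by GenerateKeyRequest and the response shaped by GenerateKeyResponse, a client call against /key/generate might look like this (host, port, and key values are placeholders, not part of this commit):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",             # assumed proxy address
    headers={"Authorization": "Bearer sk-master"},  # placeholder master key
    json={"duration": "30d", "models": ["gpt-3.5-turbo"]},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"key": "sk-...", "expires": "2024-01-01T12:00:00"}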
@@ -980,7 +1031,7 @@ async def delete_key_fn(request: Request):
             detail={"error": str(e)},
         )
 
-@router.get("/key/info", dependencies=[Depends(user_api_key_auth)])
+@router.get("/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def info_key_fn(key: str = fastapi.Query(..., description="Key in the request parameters")):
     global prisma_client
     try:
@@ -1058,6 +1109,7 @@ async def model_info(request: Request):
         ],
         object="list",
     )
+#### EXPERIMENTAL QUEUING ####
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):

litellm/proxy/schema.prisma

@@ -11,7 +11,7 @@ generator client {
 model LiteLLM_VerificationToken {
   token String @unique
   spend Float @default(0.0)
-  expires DateTime
+  expires DateTime?
   models String[]
   aliases Json @default("{}")
   config Json @default("{}")

litellm/utils.py

@@ -1656,7 +1656,6 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
     prompt_tokens_cost_usd_dollar = 0
     completion_tokens_cost_usd_dollar = 0
     model_cost_ref = litellm.model_cost
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
     azure_llms = {
         "gpt-35-turbo": "azure/gpt-3.5-turbo",
@@ -1688,6 +1687,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
+        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     else:
         # calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows
         input_cost_sum = 0
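The added return appears to be the Azure cost fix named in the commit title: once a deployment name like "gpt-35-turbo" is remapped to its canonical litellm entry, the exact per-token cost is returned immediately rather than falling through to the averaged-cost branch. A self-contained sketch of that flow (prices are illustrative, not litellm's real table):

# illustrative price table; the real values live in litellm.model_cost
model_cost = {
    "azure/gpt-3.5-turbo": {
        "input_cost_per_token": 1.5e-06,
        "output_cost_per_token": 2.0e-06,
    },
}
azure_llms = {"gpt-35-turbo": "azure/gpt-3.5-turbo"}

def cost_per_token(model: str, prompt_tokens: int, completion_tokens: int):
    if model in azure_llms:
        model = azure_llms[model]  # remap the azure deployment name
    if model in model_cost:
        return (  # early return: no fall-through to the averaged-cost path
            model_cost[model]["input_cost_per_token"] * prompt_tokens,
            model_cost[model]["output_cost_per_token"] * completion_tokens,
        )
    raise ValueError(f"no cost entry for {model}")

print(cost_per_token("gpt-35-turbo", 100, 50))  # (0.00015, 0.0001)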