forked from phoenix/litellm-mirror
fix(proxy_server.py): update db with master key if set, and fix tracking cost for azure models
commit 722c325503
parent eb636d9429

3 changed files with 86 additions and 34 deletions
proxy_server.py

@@ -198,6 +198,17 @@ class ModelParams(BaseModel):
     litellm_params: dict
     model_info: Optional[dict]
 
+class GenerateKeyRequest(BaseModel):
+    duration: str = "1h"
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: int = 0
+
+class GenerateKeyResponse(BaseModel):
+    key: str
+    expires: str
+
 user_api_base = None
 user_model = None
 user_debug = False
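Note: with these Pydantic models in place, FastAPI can validate the /key/generate request body automatically and document it in the OpenAPI schema. A minimal sketch of how the declared defaults behave (the model definition mirrors the one added above):

from pydantic import BaseModel

class GenerateKeyRequest(BaseModel):
    duration: str = "1h"
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: int = 0

# An empty body falls back to the defaults: a 1-hour key with no model restrictions.
print(GenerateKeyRequest().duration, GenerateKeyRequest().models)  # -> 1h []

# Known fields are type-checked; a request for a 30-day key scoped to one model:
print(GenerateKeyRequest(duration="30d", models=["gpt-3.5-turbo"]).dict())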
@@ -300,7 +311,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
-        print(f"An exception occurred - {e}")
+        print(f"An exception occurred - {traceback.format_exc()}")
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail={"error": "invalid user key"},
@@ -378,8 +389,6 @@ def track_cost_callback(
     end_time = None,  # start/end time for completion
 ):
     try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
         # check if it has collected an entire stream response
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
@@ -393,18 +402,30 @@ def track_cost_callback(
             )
             print("streaming response_cost", response_cost)
         # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                input_text = kwargs.get("messages", "")
-                if isinstance(input_text, list):
-                    input_text = "".join(m["content"] for m in input_text)
-                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
-                print("regular response_cost", response_cost)
-        print(f"metadata in kwargs: {kwargs}")
+        elif kwargs["stream"] is False:  # regular response
+            input_text = kwargs.get("messages", "")
+            if isinstance(input_text, list):
+                input_text = "".join(m["content"] for m in input_text)
+            print(f"received completion response: {completion_response}")
+            response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+            print("regular response_cost", response_cost)
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
         if user_api_key:
-            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+            # asyncio.run(update_prisma_database(user_api_key, response_cost))
+            # Create new event loop for async function execution in the new thread
+            new_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(new_loop)
+
+            try:
+                # Run the async function using the newly created event loop
+                new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
+            except Exception as e:
+                print(f"error in tracking cost callback - {str(e)}")
+            finally:
+                # Close the event loop after the task is done
+                new_loop.close()
+                # Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
+                asyncio.set_event_loop(None)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")
 
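Note: per the comment in the hunk, this callback runs in a separate thread, and the commit swaps asyncio.run for an explicitly managed per-invocation loop. A self-contained sketch of the same pattern; update_db here stands in for update_prisma_database:

import asyncio
import threading

async def update_db(token: str, cost: float):
    await asyncio.sleep(0)  # stand-in for an awaited Prisma write
    print(f"charged {cost} to {token}")

def sync_callback(token: str, cost: float):
    # Runs in a plain worker thread with no event loop of its own:
    # create one, drive the coroutine to completion, then tear it down.
    new_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(new_loop)
    try:
        new_loop.run_until_complete(update_db(token, cost))
    finally:
        new_loop.close()
        asyncio.set_event_loop(None)  # leave the thread loop-free for later calls

threading.Thread(target=sync_callback, args=("sk-test", 0.002)).start()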
@@ -419,6 +440,8 @@ async def update_prisma_database(token, response_cost):
             }
         )
         print(f"existing spend: {existing_spend}")
+        if existing_spend is None:
+            existing_spend = 0
         # Calculate the new cost by adding the existing cost and response_cost
         new_spend = existing_spend.spend + response_cost
 
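Note: the new guard covers the "no row yet" case, but as written existing_spend becomes the integer 0 while the very next line still reads existing_spend.spend. A hypothetical variant that defaults the numeric spend rather than the row object (not part of this commit):

def compute_new_spend(existing_row, response_cost: float) -> float:
    # Default the number, not the record, so the addition works either way.
    existing_spend = 0.0 if existing_row is None else existing_row.spend
    return existing_spend + response_cost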
@@ -546,8 +569,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         run_ollama_serve()
     return router, model_list, general_settings
 
-async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict, spend: float):
-    token = f"sk-{secrets.token_urlsafe(16)}"
+async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]):
+    if token is None:
+        token = f"sk-{secrets.token_urlsafe(16)}"
     def _duration_in_seconds(duration: str):
         match = re.match(r"(\d+)([smhd]?)", duration)
         if not match:
@@ -566,9 +590,13 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             return value * 86400
         else:
             raise ValueError("Unsupported duration unit")
 
-    duration = _duration_in_seconds(duration=duration_str)
-    expires = datetime.utcnow() + timedelta(seconds=duration)
+    if duration_str is None:  # allow tokens that never expire
+        expires = None
+    else:
+        duration = _duration_in_seconds(duration=duration_str)
+        expires = datetime.utcnow() + timedelta(seconds=duration)
+
     aliases_json = json.dumps(aliases)
     config_json = json.dumps(config)
     try:
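Note: _duration_in_seconds accepts a count plus an s/m/h/d unit, and the new branch lets duration_str=None produce a key with no expiry (used below for the master key). A standalone sketch of the parsing, with the per-unit branches condensed into a lookup (an assumption; only the "d" branch and the error case are visible in this hunk):

import re
from datetime import datetime, timedelta
from typing import Optional

def duration_in_seconds(duration: str) -> int:
    match = re.match(r"(\d+)([smhd]?)", duration)
    if not match:
        raise ValueError("Invalid duration format")
    value, unit = int(match.group(1)), match.group(2)
    try:
        return value * {"s": 1, "m": 60, "h": 3600, "d": 86400}[unit]
    except KeyError:
        raise ValueError("Unsupported duration unit")

def compute_expiry(duration_str: Optional[str]) -> Optional[datetime]:
    if duration_str is None:  # never-expiring token
        return None
    return datetime.utcnow() + timedelta(seconds=duration_in_seconds(duration_str))

print(duration_in_seconds("30m"))  # 1800
print(compute_expiry(None))        # None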
@@ -582,9 +610,14 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             "config": config_json,
             "spend": spend
         }
-        print(f"verification_token_data: {verification_token_data}")
-        new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
-            {**verification_token_data} # type: ignore
+        new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore
+            where={
+                'token': token,
+            },
+            data={
+                "create": {**verification_token_data}, # type: ignore
+                "update": {} # don't do anything if it already exists
+            }
         )
     except Exception as e:
         traceback.print_exc()
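Note: switching from create to upsert makes the write idempotent, which matters now that startup inserts the master key on every boot: the empty "update" block leaves an existing row, including its accumulated spend, untouched. Roughly the find-or-create logic the upsert replaces, for intuition (hypothetical; upsert additionally avoids a race between the check and the insert):

async def insert_token_once(db, token: str, verification_token_data: dict):
    existing = await db.litellm_verificationtoken.find_unique(where={"token": token})
    if existing is None:
        await db.litellm_verificationtoken.create(data={**verification_token_data})
    # if the row already exists, do nothing; mirrors the empty "update" block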
@@ -744,12 +777,16 @@ def litellm_completion(*args, **kwargs):
 
 @app.on_event("startup")
 async def startup_event():
-    global prisma_client
+    global prisma_client, master_key
     import json
     worker_config = json.loads(os.getenv("WORKER_CONFIG"))
    initialize(**worker_config)
     if prisma_client:
         await prisma_client.connect()
 
+    if prisma_client is not None and master_key is not None:
+        # add master key to db
+        await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
+
 @app.on_event("shutdown")
 async def shutdown_event():
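Note: storing the master key as a regular verification token (duration_str=None, so it never expires) means requests authenticated with it go through the same token lookup and spend tracking as generated keys. A sketch of the row this startup call produces; the values follow the call above, and sk-1234 is a placeholder for the configured master key:

master_key = "sk-1234"  # placeholder for the configured master key

master_key_row = {
    "token": master_key,
    "expires": None,  # duration_str=None -> no expiry
    "models": [],     # no model restrictions
    "aliases": {},
    "config": {},
    "spend": 0,       # cost tracking starts from zero
}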
@@ -940,25 +977,39 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
 
 #### KEY MANAGEMENT ####
 
-@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
-async def generate_key_fn(request: Request):
+@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
+async def generate_key_fn(request: Request, data: GenerateKeyRequest):
+    """
+    Generate an API key based on the provided data.
+
+    Parameters:
+    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
+    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
+    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
+    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
+    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+
+    Returns:
+    - key: The generated api key
+    - expires: Datetime object for when key expires.
+    """
     data = await request.json()
 
-    duration_str = data.get("duration", "1h")  # Default to 1 hour if duration is not provided
-    models = data.get("models", [])  # Default to an empty list (meaning allow token to call all models)
-    aliases = data.get("aliases", {})  # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
-    config = data.get("config", {})
-    spend = data.get("spend", 0)
+    duration_str = data.duration  # Default to 1 hour if duration is not provided
+    models = data.models  # Default to an empty list (meaning allow token to call all models)
+    aliases = data.aliases  # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
+    config = data.config
+    spend = data.spend
     if isinstance(models, list):
         response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
-        return {"key": response["token"], "expires": response["expires"]}
+        return GenerateKeyResponse(key=response["token"], expires=response["expires"])
     else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail={"error": "models param must be a list"},
         )
 
-@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
+@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def delete_key_fn(request: Request):
     try:
         data = await request.json()
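Note: a request sketch against the updated endpoint, assuming a proxy listening on localhost:8000 and sk-1234 standing in for the master key (both placeholders):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"duration": "30d", "models": ["gpt-3.5-turbo"]},  # other fields use the model defaults
)
print(resp.json())  # -> {"key": "sk-...", "expires": "..."} per GenerateKeyResponse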
@@ -980,7 +1031,7 @@ async def delete_key_fn(request: Request):
             detail={"error": str(e)},
         )
 
-@router.get("/key/info", dependencies=[Depends(user_api_key_auth)])
+@router.get("/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def info_key_fn(key: str = fastapi.Query(..., description="Key in the request parameters")):
     global prisma_client
     try:
@@ -1058,6 +1109,7 @@ async def model_info(request: Request):
         ],
         object="list",
     )
+
 #### EXPERIMENTAL QUEUING ####
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):
schema.prisma

@@ -11,7 +11,7 @@ generator client {
 model LiteLLM_VerificationToken {
   token    String    @unique
   spend    Float     @default(0.0)
-  expires  DateTime
+  expires  DateTime?
   models   String[]
   aliases  Json      @default("{}")
   config   Json      @default("{}")
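Note: the trailing ? makes the expires column nullable, which is what allows the never-expiring master-key token to be stored at all. The auth path then needs a NULL-aware expiry check along these lines (a sketch, not the proxy's exact code):

from datetime import datetime
from typing import Optional

def token_is_live(expires: Optional[datetime]) -> bool:
    # A NULL expiry now means "never expires"
    return expires is None or expires >= datetime.utcnow()

print(token_is_live(None))                  # True: master-key style token
print(token_is_live(datetime(2020, 1, 1)))  # False: already expired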
utils.py

@@ -1656,7 +1656,6 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
     prompt_tokens_cost_usd_dollar = 0
     completion_tokens_cost_usd_dollar = 0
     model_cost_ref = litellm.model_cost
-
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
     azure_llms = {
         "gpt-35-turbo": "azure/gpt-3.5-turbo",
@@ -1688,6 +1687,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
+        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     else:
         # calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows
         input_cost_sum = 0
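Note: this return is the azure cost fix from the commit message. The branch that remaps azure deployment names (e.g. "gpt-35-turbo" to "azure/gpt-3.5-turbo") computed per-token costs but did not return them, so the function continued past the branch; the added return makes the remapped azure pricing the function's result. A condensed sketch of the remap-then-price flow, with illustrative prices rather than litellm's real cost table:

azure_llms = {"gpt-35-turbo": "azure/gpt-3.5-turbo"}

model_cost = {
    "azure/gpt-3.5-turbo": {"input_cost_per_token": 1.5e-06, "output_cost_per_token": 2e-06},
}

def cost_per_token_sketch(model: str, prompt_tokens: int, completion_tokens: int):
    if model in azure_llms:
        model = azure_llms[model]  # alias the deployment name to a known pricing entry
    ref = model_cost[model]
    return (ref["input_cost_per_token"] * prompt_tokens,
            ref["output_cost_per_token"] * completion_tokens)

print(cost_per_token_sketch("gpt-35-turbo", 1000, 500))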