forked from phoenix/litellm-mirror
fix(proxy_server.py): update db with master key if set, and fix tracking cost for azure models
This commit is contained in:
parent eb636d9429
commit 722c325503
3 changed files with 86 additions and 34 deletions
@@ -198,6 +198,17 @@ class ModelParams(BaseModel):
     litellm_params: dict
     model_info: Optional[dict]
 
+class GenerateKeyRequest(BaseModel):
+    duration: str = "1h"
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: int = 0
+
+class GenerateKeyResponse(BaseModel):
+    key: str
+    expires: str
+
 user_api_base = None
 user_model = None
 user_debug = False
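Note: the two Pydantic models added here give /key/generate a typed request body and a typed response. A minimal standalone sketch (not proxy code) of how validation behaves with these defaults:

    from pydantic import BaseModel

    class GenerateKeyRequest(BaseModel):
        duration: str = "1h"
        models: list = []
        aliases: dict = {}
        config: dict = {}
        spend: int = 0

    # Omitted fields fall back to their defaults, so a minimal body is valid
    req = GenerateKeyRequest(**{"duration": "30d"})
    print(req.duration, req.models)  # 30d []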
@@ -300,7 +311,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
-        print(f"An exception occurred - {e}")
+        print(f"An exception occurred - {traceback.format_exc()}")
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail={"error": "invalid user key"},
@@ -378,8 +389,6 @@ def track_cost_callback(
     end_time = None, # start/end time for completion
 ):
     try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
         # check if it has collected an entire stream response
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
@@ -393,18 +402,30 @@ def track_cost_callback(
             )
             print("streaming response_cost", response_cost)
         # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                input_text = kwargs.get("messages", "")
-                if isinstance(input_text, list):
-                    input_text = "".join(m["content"] for m in input_text)
-                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
-                print("regular response_cost", response_cost)
-        print(f"metadata in kwargs: {kwargs}")
+        elif kwargs["stream"] is False: # regular response
+            input_text = kwargs.get("messages", "")
+            if isinstance(input_text, list):
+                input_text = "".join(m["content"] for m in input_text)
+            print(f"received completion response: {completion_response}")
+            response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+            print("regular response_cost", response_cost)
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
         if user_api_key:
-            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+            # asyncio.run(update_prisma_database(user_api_key, response_cost))
+            # Create new event loop for async function execution in the new thread
+            new_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(new_loop)
+
+            try:
+                # Run the async function using the newly created event loop
+                new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
+            except Exception as e:
+                print(f"error in tracking cost callback - {str(e)}")
+            finally:
+                # Close the event loop after the task is done
+                new_loop.close()
+                # Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
+                asyncio.set_event_loop(None)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")

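Note: the success callback runs in a synchronous worker thread, so it cannot await update_prisma_database directly; the change gives that thread its own event loop, runs the coroutine to completion, then tears the loop down. A standalone sketch of the same pattern, with update_db as a stand-in for update_prisma_database:

    import asyncio

    async def update_db(token: str, cost: float):  # stand-in for update_prisma_database
        print(f"would add {cost} to {token}")

    def sync_callback(token: str, cost: float):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(update_db(token, cost))
        finally:
            loop.close()
            asyncio.set_event_loop(None)  # leave the thread with no lingering loop

    sync_callback("sk-example", 0.0021)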
@@ -419,6 +440,8 @@ async def update_prisma_database(token, response_cost):
             }
         )
         print(f"existing spend: {existing_spend}")
+        if existing_spend is None:
+            existing_spend = 0
         # Calculate the new cost by adding the existing cost and response_cost
         new_spend = existing_spend.spend + response_cost

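Note: the new None check covers tokens that have no recorded spend yet. A standalone sketch of the read-then-increment step, with a hypothetical TokenRow in place of the Prisma record:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TokenRow:  # hypothetical stand-in for the Prisma row
        spend: float = 0.0

    def compute_new_spend(existing: Optional[TokenRow], response_cost: float) -> float:
        # Missing row -> treat prior spend as zero, then accumulate
        existing_spend = existing.spend if existing is not None else 0.0
        return existing_spend + response_cost

    print(compute_new_spend(TokenRow(spend=0.5), 0.25))  # 0.75
    print(compute_new_spend(None, 0.25))                 # 0.25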
@@ -546,8 +569,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         run_ollama_serve()
     return router, model_list, general_settings

-async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict, spend: float):
-    token = f"sk-{secrets.token_urlsafe(16)}"
+async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]):
+    if token is None:
+        token = f"sk-{secrets.token_urlsafe(16)}"
     def _duration_in_seconds(duration: str):
         match = re.match(r"(\d+)([smhd]?)", duration)
         if not match:
@@ -566,9 +590,13 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             return value * 86400
         else:
             raise ValueError("Unsupported duration unit")

-    duration = _duration_in_seconds(duration=duration_str)
-    expires = datetime.utcnow() + timedelta(seconds=duration)
+    if duration_str is None: # allow tokens that never expire
+        expires = None
+    else:
+        duration = _duration_in_seconds(duration=duration_str)
+        expires = datetime.utcnow() + timedelta(seconds=duration)

     aliases_json = json.dumps(aliases)
     config_json = json.dumps(config)
     try:
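Note: with duration_str now Optional, None yields a key that never expires; otherwise the shorthand is converted to seconds. A runnable sketch mirroring the parsing and expiry logic in the diff:

    import re
    from datetime import datetime, timedelta
    from typing import Optional

    def _duration_in_seconds(duration: str) -> int:
        match = re.match(r"(\d+)([smhd]?)", duration)
        if not match:
            raise ValueError("Invalid duration format")
        value, unit = int(match.group(1)), match.group(2) or "s"  # bare numbers mean seconds
        return value * {"s": 1, "m": 60, "h": 3600, "d": 86400}[unit]

    def key_expiry(duration_str: Optional[str]) -> Optional[datetime]:
        if duration_str is None:  # allow tokens that never expire
            return None
        return datetime.utcnow() + timedelta(seconds=_duration_in_seconds(duration_str))

    print(key_expiry("1h"), key_expiry(None))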
@@ -582,9 +610,14 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             "config": config_json,
             "spend": spend
         }
         print(f"verification_token_data: {verification_token_data}")
-        new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
-            {**verification_token_data} # type: ignore
+        new_verification_token = await db.litellm_verificationtoken.upsert( # type: ignore
+            where={
+                'token': token,
+            },
+            data={
+                "create": {**verification_token_data}, #type: ignore
+                "update": {} # don't do anything if it already exists
+            }
         )
     except Exception as e:
         traceback.print_exc()
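Note: switching from create to upsert is what lets startup re-insert the master key safely: a new token is created, an existing one is left untouched because the update payload is empty. A plain-dict sketch of those semantics (not the Prisma client API):

    store = {}  # token -> row, standing in for the verification-token table

    def upsert(token, create, update):
        if token not in store:
            store[token] = dict(create)   # first sight: create the row
        else:
            store[token].update(update)   # empty update -> existing row unchanged
        return store[token]

    upsert("sk-master", {"token": "sk-master", "spend": 0}, {})
    upsert("sk-master", {"token": "sk-master", "spend": 99}, {})  # no-op on re-run
    print(store["sk-master"]["spend"])  # 0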
@@ -744,12 +777,16 @@ def litellm_completion(*args, **kwargs):

 @app.on_event("startup")
 async def startup_event():
-    global prisma_client
+    global prisma_client, master_key
     import json
     worker_config = json.loads(os.getenv("WORKER_CONFIG"))
     initialize(**worker_config)
     if prisma_client:
         await prisma_client.connect()

+    if prisma_client is not None and master_key is not None:
+        # add master key to db
+        await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
+
 @app.on_event("shutdown")
 async def shutdown_event():
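Note: seeding the master key at startup means requests authenticated with it hit the same DB path as generated keys, and the upsert above keeps re-seeding on every boot idempotent. A condensed sketch of the hook, with a stub in place of generate_key_helper_fn and a placeholder key:

    from fastapi import FastAPI

    app = FastAPI()
    master_key = "sk-1234"  # placeholder; the proxy reads this from its config

    async def seed_master_key(token: str):
        print(f"would upsert row for {token}")  # stand-in for generate_key_helper_fn

    @app.on_event("startup")
    async def startup_event():
        # duration_str=None in the diff -> the master key row never expires
        if master_key is not None:
            await seed_master_key(master_key)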
@@ -940,25 +977,39 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap

 #### KEY MANAGEMENT ####

-@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
-async def generate_key_fn(request: Request):
+@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
+async def generate_key_fn(request: Request, data: GenerateKeyRequest):
+    """
+    Generate an API key based on the provided data.
+
+    Parameters:
+    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)**
+    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
+    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
+    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
+    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+
+    Returns:
+    - key: The generated api key
+    - expires: Datetime object for when key expires.
+    """
-    data = await request.json()
-
-    duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
-    models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
-    aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
-    config = data.get("config", {})
-    spend = data.get("spend", 0)
+    duration_str = data.duration # Default to 1 hour if duration is not provided
+    models = data.models # Default to an empty list (meaning allow token to call all models)
+    aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
+    config = data.config
+    spend = data.spend
     if isinstance(models, list):
         response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
-        return {"key": response["token"], "expires": response["expires"]}
+        return GenerateKeyResponse(key=response["token"], expires=response["expires"])
     else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail={"error": "models param must be a list"},
         )

-@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
+@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def delete_key_fn(request: Request):
     try:
         data = await request.json()
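Note: with the endpoint now accepting a GenerateKeyRequest body and returning a GenerateKeyResponse, a client call looks like the following sketch (hypothetical host/port and master key; any omitted field takes the documented default):

    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/key/generate",  # placeholder proxy address
        headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
        json={"duration": "30d", "models": ["gpt-3.5-turbo"]},
    )
    print(resp.json())  # e.g. {"key": "sk-...", "expires": "..."}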
@@ -980,7 +1031,7 @@ async def delete_key_fn(request: Request):
             detail={"error": str(e)},
         )

-@router.get("/key/info", dependencies=[Depends(user_api_key_auth)])
+@router.get("/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def info_key_fn(key: str = fastapi.Query(..., description="Key in the request parameters")):
     global prisma_client
     try:
@@ -1058,6 +1109,7 @@ async def model_info(request: Request):
         ],
         object="list",
     )
+
 #### EXPERIMENTAL QUEUING ####
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):

@@ -11,7 +11,7 @@ generator client {
 model LiteLLM_VerificationToken {
   token String @unique
   spend Float @default(0.0)
-  expires DateTime
+  expires DateTime?
   models String[]
   aliases Json @default("{}")
   config Json @default("{}")

@@ -1656,7 +1656,6 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
     prompt_tokens_cost_usd_dollar = 0
     completion_tokens_cost_usd_dollar = 0
     model_cost_ref = litellm.model_cost
-
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
     azure_llms = {
         "gpt-35-turbo": "azure/gpt-3.5-turbo",
@@ -1688,6 +1687,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
+        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     else:
         # calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows
         input_cost_sum = 0
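Note: the added return is the azure fix from the commit message: once a deployment-style name like "gpt-35-turbo" is remapped via azure_llms, the computed costs are returned immediately instead of falling through to the average-cost branch. A condensed sketch with illustrative prices (not litellm's real cost table):

    model_cost = {
        "azure/gpt-3.5-turbo": {"input_cost_per_token": 0.0000015,
                                "output_cost_per_token": 0.000002},
    }
    azure_llms = {"gpt-35-turbo": "azure/gpt-3.5-turbo"}  # azure deployment name -> cost key

    def cost_per_token(model, prompt_tokens, completion_tokens):
        if model in azure_llms:
            model = azure_llms[model]  # remap before the lookup
        ref = model_cost[model]
        return (ref["input_cost_per_token"] * prompt_tokens,
                ref["output_cost_per_token"] * completion_tokens)

    print(cost_per_token("gpt-35-turbo", 100, 50))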