diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a852009bb..4e14b2013 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -198,6 +198,17 @@ class ModelParams(BaseModel):
     litellm_params: dict
     model_info: Optional[dict]
 
+class GenerateKeyRequest(BaseModel):
+    duration: Optional[str] = "1h"  # e.g. "30s", "30m", "30h", "30d"; None means the key never expires
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: int = 0
+
+class GenerateKeyResponse(BaseModel):
+    key: str
+    expires: Optional[datetime]  # None for keys that never expire
+
 user_api_base = None
 user_model = None
 user_debug = False
@@ -300,7 +311,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
-        print(f"An exception occurred - {e}")
+        print(f"An exception occurred - {traceback.format_exc()}")
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail={"error": "invalid user key"},
@@ -378,8 +389,6 @@ def track_cost_callback(
     end_time = None,  # start/end time for completion
 ):
     try:
-        # init logging config
-        print("in custom callback tracking cost", llm_model_list)
         # check if it has collected an entire stream response
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
@@ -393,18 +402,30 @@ def track_cost_callback(
             )
             print("streaming response_cost", response_cost)
         # for non streaming responses
-        else:
-            # we pass the completion_response obj
-            if kwargs["stream"] != True:
-                input_text = kwargs.get("messages", "")
-                if isinstance(input_text, list):
-                    input_text = "".join(m["content"] for m in input_text)
-                response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
-                print("regular response_cost", response_cost)
-        print(f"metadata in kwargs: {kwargs}")
+        elif kwargs["stream"] is False:  # regular response
+            input_text = kwargs.get("messages", "")
+            if isinstance(input_text, list):
+                input_text = "".join(m["content"] for m in input_text)
+            print(f"received completion response: {completion_response}")
+            response_cost = litellm.completion_cost(completion_response=completion_response, completion=input_text)
+            print("regular response_cost", response_cost)
         user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
         if user_api_key:
-            asyncio.run(update_prisma_database(token=user_api_key, response_cost=response_cost))
+            # asyncio.run(update_prisma_database(user_api_key, response_cost))
+            # Create a new event loop for async function execution in this (callback) thread
+            new_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(new_loop)
+
+            try:
+                # Run the async function using the newly created event loop
+                new_loop.run_until_complete(update_prisma_database(user_api_key, response_cost))
+            except Exception as e:
+                print(f"error in tracking cost callback - {str(e)}")
+            finally:
+                # Close the event loop after the task is done
+                new_loop.close()
+                # Ensure that there's no event loop set in this thread, which could interfere with future asyncio calls
+                asyncio.set_event_loop(None)
     except Exception as e:
         print(f"error in tracking cost callback - {str(e)}")
@@ -419,6 +440,10 @@ async def update_prisma_database(token, response_cost):
             }
         )
         print(f"existing spend: {existing_spend}")
+        if existing_spend is None:
+            existing_spend = 0
+        else:
+            existing_spend = existing_spend.spend
         # Calculate the new cost by adding the existing cost and response_cost
-        new_spend = existing_spend.spend + response_cost
+        new_spend = existing_spend + response_cost
@@ -546,8 +569,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 run_ollama_serve()
     return router, model_list, general_settings
 
-async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict, spend: float):
-    token = f"sk-{secrets.token_urlsafe(16)}"
+async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str] = None):
+    if token is None:
+        token = f"sk-{secrets.token_urlsafe(16)}"
     def _duration_in_seconds(duration: str):
         match = re.match(r"(\d+)([smhd]?)", duration)
         if not match:
@@ -566,9 +590,13 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             return value * 86400
         else:
             raise ValueError("Unsupported duration unit")
-
-    duration = _duration_in_seconds(duration=duration_str)
-    expires = datetime.utcnow() + timedelta(seconds=duration)
+
+    if duration_str is None:  # allow tokens that never expire
+        expires = None
+    else:
+        duration = _duration_in_seconds(duration=duration_str)
+        expires = datetime.utcnow() + timedelta(seconds=duration)
+
     aliases_json = json.dumps(aliases)
     config_json = json.dumps(config)
     try:
@@ -582,9 +610,14 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict,
             "config": config_json,
             "spend": spend
         }
-        print(f"verification_token_data: {verification_token_data}")
-        new_verification_token = await db.litellm_verificationtoken.create(  # type: ignore
-            {**verification_token_data}  # type: ignore
+        new_verification_token = await db.litellm_verificationtoken.upsert(  # type: ignore
+            where={
+                "token": token,
+            },
+            data={
+                "create": {**verification_token_data},  # type: ignore
+                "update": {}  # don't do anything if it already exists
+            }
         )
     except Exception as e:
         traceback.print_exc()
@@ -744,12 +777,16 @@ def litellm_completion(*args, **kwargs):
 
 @app.on_event("startup")
 async def startup_event():
-    global prisma_client
+    global prisma_client, master_key
    import json
    worker_config = json.loads(os.getenv("WORKER_CONFIG"))
    initialize(**worker_config)
    if prisma_client:
        await prisma_client.connect()
+
+    if prisma_client is not None and master_key is not None:
+        # add master key to db
+        await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
 
 @app.on_event("shutdown")
 async def shutdown_event():
@@ -940,25 +977,39 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
 
 #### KEY MANAGEMENT ####
 
-@router.post("/key/generate", dependencies=[Depends(user_api_key_auth)])
-async def generate_key_fn(request: Request):
+@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
+async def generate_key_fn(request: Request, data: GenerateKeyRequest):
+    """
+    Generate an API key based on the provided data.
+
+    Parameters:
+    - duration: Optional[str] - Specify how long the token is valid for: seconds ("30s"), minutes ("30m"), hours ("30h"), or days ("30d"). **(Default is set to 1 hour.)**
+    - models: Optional[list] - Model names a user is allowed to call. (if empty, key is allowed to call all models)
+    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
+    - config: Optional[dict] - Any key-specific configs; overrides config in config.yaml
+    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used.
+      https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+
+    Returns:
+    - key: The generated api key
+    - expires: Datetime object for when the key expires (None if the key never expires).
+    """
-    data = await request.json()
-    duration_str = data.get("duration", "1h")  # Default to 1 hour if duration is not provided
-    models = data.get("models", [])  # Default to an empty list (meaning allow token to call all models)
-    aliases = data.get("aliases", {})  # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
-    config = data.get("config", {})
-    spend = data.get("spend", 0)
+    duration_str = data.duration  # Default to 1 hour if duration is not provided
+    models = data.models  # Default to an empty list (meaning allow token to call all models)
+    aliases = data.aliases  # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
+    config = data.config
+    spend = data.spend
     if isinstance(models, list):
         response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend)
-        return {"key": response["token"], "expires": response["expires"]}
+        return GenerateKeyResponse(key=response["token"], expires=response["expires"])
     else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail={"error": "models param must be a list"},
         )
 
-@router.post("/key/delete", dependencies=[Depends(user_api_key_auth)])
+@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def delete_key_fn(request: Request):
     try:
         data = await request.json()
@@ -980,7 +1031,7 @@ async def delete_key_fn(request: Request):
             detail={"error": str(e)},
         )
 
-@router.get("/key/info", dependencies=[Depends(user_api_key_auth)])
+@router.get("/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def info_key_fn(key: str = fastapi.Query(..., description="Key in the request parameters")):
     global prisma_client
     try:
@@ -1058,6 +1109,7 @@ async def model_info(request: Request):
         ],
         object="list",
     )
+
 #### EXPERIMENTAL QUEUING ####
 @router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
 async def async_queue_request(request: Request):
diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma
index 414d23277..889f6ad78 100644
--- a/litellm/proxy/schema.prisma
+++ b/litellm/proxy/schema.prisma
@@ -11,7 +11,7 @@ generator client {
 model LiteLLM_VerificationToken {
   token    String    @unique
   spend    Float     @default(0.0)
-  expires  DateTime
+  expires  DateTime?
   models   String[]
   aliases  Json      @default("{}")
   config   Json      @default("{}")
diff --git a/litellm/utils.py b/litellm/utils.py
index 4ae862314..c3ff1d7f0 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1656,7 +1656,6 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
     prompt_tokens_cost_usd_dollar = 0
     completion_tokens_cost_usd_dollar = 0
     model_cost_ref = litellm.model_cost
-
     # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
     azure_llms = {
         "gpt-35-turbo": "azure/gpt-3.5-turbo",
@@ -1688,6 +1687,7 @@ def cost_per_token(model="", prompt_tokens=0, completion_tokens=0):
         completion_tokens_cost_usd_dollar = (
             model_cost_ref[model]["output_cost_per_token"] * completion_tokens
         )
+        return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
     else:
         # calculate average input cost, azure/gpt-deployments can potentially go here if users don't specify, gpt-4, gpt-3.5-turbo. LLMs litellm knows
         input_cost_sum = 0
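
For reviewers, a minimal usage sketch of the new /key/generate endpoint follows. This is not part of the diff: the proxy base URL and the "sk-1234" master key are placeholder assumptions, and any key accepted by user_api_key_auth would work.

import requests  # assumes the `requests` package is installed

# Placeholder values -- adjust to your deployment:
PROXY_BASE_URL = "http://0.0.0.0:8000"
MASTER_KEY = "sk-1234"

response = requests.post(
    f"{PROXY_BASE_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={
        "duration": "20m",            # parsed by _duration_in_seconds: "s", "m", "h", or "d"
        "models": ["gpt-3.5-turbo"],  # restrict the key to these model names (empty = all models)
    },
)
response.raise_for_status()
print(response.json())  # {"key": "sk-...", "expires": "<ISO 8601 timestamp or null>"}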