diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 383ee7cfb7..52f4febd42 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -181,6 +181,17 @@ async def user_api_key_auth(request: Request):
             )
             if valid_token:
                 litellm.model_alias_map = valid_token.aliases
+                config = valid_token.config
+                if config != {}:
+                    global llm_router
+                    model_list = config.get("model_list", [])
+                    if llm_router == None:
+                        llm_router = litellm.Router(
+                            model_list=model_list
+                        )
+                    else:
+                        llm_router.model_list = model_list
+                    print("\n new llm router model list", llm_router.model_list)
                 if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                     return
                 else:
@@ -204,6 +215,7 @@ def prisma_setup(database_url: Optional[str]):
     global prisma_client
     if database_url:
         import os
+        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         subprocess.run(['pip', 'install', 'prisma'])
         subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
@@ -277,13 +289,13 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         for model in model_list:
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
-            print(f"litellm_model_name: {litellm_model_name}")
+            # print(f"litellm_model_name: {litellm_model_name}")
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
     return router, model_list, server_settings
 
-async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict):
+async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
     token = f"sk-{secrets.token_urlsafe(16)}"
     def _duration_in_seconds(duration: str):
         match = re.match(r"(\d+)([smhd]?)", duration)
@@ -307,6 +319,7 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict)
     duration = _duration_in_seconds(duration=duration_str)
     expires = datetime.utcnow() + timedelta(seconds=duration)
     aliases_json = json.dumps(aliases)
+    config_json = json.dumps(config)
     try:
         db = prisma_client
         # Create a new verification token (you may want to enhance this logic based on your needs)
@@ -314,7 +327,8 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict)
             "token": token,
             "expires": expires,
             "models": models,
-            "aliases": aliases_json
+            "aliases": aliases_json,
+            "config": config_json
         }
         print(f"verification_token_data: {verification_token_data}")
         new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
@@ -571,8 +585,9 @@ async def generate_key_fn(request: Request):
     duration_str = data.get("duration", "1h")  # Default to 1 hour if duration is not provided
     models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
     aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
+    config = data.get("config", {})
     if isinstance(models, list):
-        response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases)
+        response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config)
         return {"key": response["token"], "expires": response["expires"]}
     else:
         raise HTTPException(
@@ -595,7 +610,7 @@ async def async_chat_completions(request: Request):
         or data["model"] # default passed in http request
     )
     data["call_type"] = "chat_completion"
-    data["llm_router"] = llm_router
+    data["llm_router"] = llm_router # this is dynamic - we should load the llm_router from the user_api_key_auth
     job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
     return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
     pass
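
Usage sketch (not part of the patch above): with the new "config" field, a key can be minted carrying its own model_list, which user_api_key_auth then uses to build or update llm_router on each authenticated request. The snippet below is only an illustrative sketch; it assumes the proxy is running locally on port 8000, that generate_key_fn is mounted at /key/generate (the route decorator is not shown in this diff), and that "sk-1234" is the proxy master key. The model names and credentials are placeholders.

# Illustrative only -- proxy URL, route path, master key, and model entries are assumptions.
import requests

payload = {
    "duration": "1h",            # parsed by _duration_in_seconds
    "models": [],                # empty list => key may call all models
    "aliases": {},
    "config": {                  # new field read by generate_key_fn in this patch
        "model_list": [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",        # placeholder deployment
                    "api_key": "<your-azure-api-key>",   # placeholder credential
                },
            }
        ]
    },
}

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json=payload,
)
print(resp.json())  # expected shape: {"key": "sk-...", "expires": "..."}

When a later request authenticates with the returned key, user_api_key_auth reads valid_token.config and either constructs a fresh litellm.Router from config["model_list"] or overwrites llm_router.model_list on the existing router, per the first hunk of the diff.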