From 33e47dae8e0053b4a5a80f58d5e15a99bd3a93ca Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 20 Nov 2023 16:50:45 -0800 Subject: [PATCH] feat(proxy_server.py): enable model aliases --- litellm/proxy/proxy_server.py | 17 ++++++++++------- litellm/proxy/schema.prisma | 1 + litellm/router.py | 7 +++++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 28af9238ee..74a47b91c5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -167,11 +167,14 @@ async def user_api_key_auth(request: Request): } ) if valid_token: + litellm.model_alias_map = valid_token.aliases if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called return else: data = await request.json() model = data.get("model", None) + if model in litellm.model_alias_map: + model = litellm.model_alias_map[model] if model and model not in valid_token.models: raise Exception(f"Token not allowed to access model") return @@ -233,9 +236,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): if litellm_settings: for key, value in litellm_settings.items(): setattr(litellm, key, value) - print(f"key: {key}; value: {value}") - print(f"success callbacks: {litellm.success_callback}") - ## MODEL LIST model_list = config.get('model_list', None) if model_list: @@ -246,7 +246,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): return router, model_list, server_settings -async def generate_key_helper_fn(duration_str: str, models: Optional[list]): +async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict): token = f"sk-{secrets.token_urlsafe(16)}" def _duration_in_seconds(duration: str): match = re.match(r"(\d+)([smhd]?)", duration) @@ -269,18 +269,20 @@ async def generate_key_helper_fn(duration_str: str, models: Optional[list]): duration = _duration_in_seconds(duration=duration_str) expires = datetime.utcnow() + timedelta(seconds=duration) + aliases_json = json.dumps(aliases) try: db = prisma_client # Create a new verification token (you may want to enhance this logic based on your needs) verification_token_data = { "token": token, "expires": expires, - "models": models + "models": models, + "aliases": aliases_json } + print(f"verification_token_data: {verification_token_data}") new_verification_token = await db.litellm_verificationtoken.create( # type: ignore {**verification_token_data} # type: ignore ) - print(f"new_verification_token: {new_verification_token}") except Exception as e: traceback.print_exc() raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -530,8 +532,9 @@ async def generate_key_fn(request: Request): duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided models = data.get("models", []) # Default to an empty list (meaning allow token to call all models) + aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list) if isinstance(models, list): - response = await generate_key_helper_fn(duration_str=duration_str, models=models) + response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases) return {"key": response["token"], "expires": response["expires"]} else: raise HTTPException( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 4d2837a8cd..9256a1d980 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -12,4 +12,5 @@ model LiteLLM_VerificationToken { token String @unique expires DateTime models String[] + aliases Json @default("{}") } \ No newline at end of file diff --git a/litellm/router.py b/litellm/router.py index 8ea76aa9c5..65c562c100 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -132,9 +132,12 @@ class Router: messages: Optional[List[Dict[str, str]]] = None, input: Optional[Union[str, List]] = None): """ - Returns the deployment with the shortest queue + Returns the deployment based on routing strategy """ - logging.debug(f"self.healthy_deployments: {self.healthy_deployments}") + if litellm.model_alias_map and model in litellm.model_alias_map: + model = litellm.model_alias_map[ + model + ] # update the model to the actual value if an alias has been passed in if self.routing_strategy == "least-busy": if len(self.healthy_deployments) > 0: for item in self.healthy_deployments: