feat(proxy_server.py): enable model aliases

This commit is contained in:
Krrish Dholakia 2023-11-20 16:50:45 -08:00
parent 356332d0da
commit 33e47dae8e
3 changed files with 16 additions and 9 deletions

View file

@ -167,11 +167,14 @@ async def user_api_key_auth(request: Request):
}
)
if valid_token:
litellm.model_alias_map = valid_token.aliases
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
return
else:
data = await request.json()
model = data.get("model", None)
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
return
@ -233,9 +236,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if litellm_settings:
for key, value in litellm_settings.items():
setattr(litellm, key, value)
print(f"key: {key}; value: {value}")
print(f"success callbacks: {litellm.success_callback}")
## MODEL LIST
model_list = config.get('model_list', None)
if model_list:
@ -246,7 +246,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
return router, model_list, server_settings
async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict):
token = f"sk-{secrets.token_urlsafe(16)}"
def _duration_in_seconds(duration: str):
match = re.match(r"(\d+)([smhd]?)", duration)
@ -269,18 +269,20 @@ async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
aliases_json = json.dumps(aliases)
try:
db = prisma_client
# Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = {
"token": token,
"expires": expires,
"models": models
"models": models,
"aliases": aliases_json
}
print(f"verification_token_data: {verification_token_data}")
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
{**verification_token_data} # type: ignore
)
print(f"new_verification_token: {new_verification_token}")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@ -530,8 +532,9 @@ async def generate_key_fn(request: Request):
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models)
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases)
return {"key": response["token"], "expires": response["expires"]}
else:
raise HTTPException(

View file

@ -12,4 +12,5 @@ model LiteLLM_VerificationToken {
token String @unique
expires DateTime
models String[]
aliases Json @default("{}")
}

View file

@ -132,9 +132,12 @@ class Router:
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None):
"""
Returns the deployment with the shortest queue
Returns the deployment based on routing strategy
"""
logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
if self.routing_strategy == "least-busy":
if len(self.healthy_deployments) > 0:
for item in self.healthy_deployments: