mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
feat(proxy_server.py): enable model aliases
This commit is contained in:
parent
356332d0da
commit
33e47dae8e
3 changed files with 16 additions and 9 deletions
|
@ -167,11 +167,14 @@ async def user_api_key_auth(request: Request):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if valid_token:
|
if valid_token:
|
||||||
|
litellm.model_alias_map = valid_token.aliases
|
||||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
model = data.get("model", None)
|
model = data.get("model", None)
|
||||||
|
if model in litellm.model_alias_map:
|
||||||
|
model = litellm.model_alias_map[model]
|
||||||
if model and model not in valid_token.models:
|
if model and model not in valid_token.models:
|
||||||
raise Exception(f"Token not allowed to access model")
|
raise Exception(f"Token not allowed to access model")
|
||||||
return
|
return
|
||||||
|
@ -233,9 +236,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
if litellm_settings:
|
if litellm_settings:
|
||||||
for key, value in litellm_settings.items():
|
for key, value in litellm_settings.items():
|
||||||
setattr(litellm, key, value)
|
setattr(litellm, key, value)
|
||||||
print(f"key: {key}; value: {value}")
|
|
||||||
print(f"success callbacks: {litellm.success_callback}")
|
|
||||||
|
|
||||||
## MODEL LIST
|
## MODEL LIST
|
||||||
model_list = config.get('model_list', None)
|
model_list = config.get('model_list', None)
|
||||||
if model_list:
|
if model_list:
|
||||||
|
@ -246,7 +246,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
|
|
||||||
return router, model_list, server_settings
|
return router, model_list, server_settings
|
||||||
|
|
||||||
async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
|
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict):
|
||||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||||
def _duration_in_seconds(duration: str):
|
def _duration_in_seconds(duration: str):
|
||||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||||
|
@ -269,18 +269,20 @@ async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
|
||||||
|
|
||||||
duration = _duration_in_seconds(duration=duration_str)
|
duration = _duration_in_seconds(duration=duration_str)
|
||||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
expires = datetime.utcnow() + timedelta(seconds=duration)
|
||||||
|
aliases_json = json.dumps(aliases)
|
||||||
try:
|
try:
|
||||||
db = prisma_client
|
db = prisma_client
|
||||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||||
verification_token_data = {
|
verification_token_data = {
|
||||||
"token": token,
|
"token": token,
|
||||||
"expires": expires,
|
"expires": expires,
|
||||||
"models": models
|
"models": models,
|
||||||
|
"aliases": aliases_json
|
||||||
}
|
}
|
||||||
|
print(f"verification_token_data: {verification_token_data}")
|
||||||
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
|
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
|
||||||
{**verification_token_data} # type: ignore
|
{**verification_token_data} # type: ignore
|
||||||
)
|
)
|
||||||
print(f"new_verification_token: {new_verification_token}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
@ -530,8 +532,9 @@ async def generate_key_fn(request: Request):
|
||||||
|
|
||||||
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
||||||
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
|
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
|
||||||
|
aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||||
if isinstance(models, list):
|
if isinstance(models, list):
|
||||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models)
|
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases)
|
||||||
return {"key": response["token"], "expires": response["expires"]}
|
return {"key": response["token"], "expires": response["expires"]}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
@ -12,4 +12,5 @@ model LiteLLM_VerificationToken {
|
||||||
token String @unique
|
token String @unique
|
||||||
expires DateTime
|
expires DateTime
|
||||||
models String[]
|
models String[]
|
||||||
|
aliases Json @default("{}")
|
||||||
}
|
}
|
|
@ -132,9 +132,12 @@ class Router:
|
||||||
messages: Optional[List[Dict[str, str]]] = None,
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
input: Optional[Union[str, List]] = None):
|
input: Optional[Union[str, List]] = None):
|
||||||
"""
|
"""
|
||||||
Returns the deployment with the shortest queue
|
Returns the deployment based on routing strategy
|
||||||
"""
|
"""
|
||||||
logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
|
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||||
|
model = litellm.model_alias_map[
|
||||||
|
model
|
||||||
|
] # update the model to the actual value if an alias has been passed in
|
||||||
if self.routing_strategy == "least-busy":
|
if self.routing_strategy == "least-busy":
|
||||||
if len(self.healthy_deployments) > 0:
|
if len(self.healthy_deployments) > 0:
|
||||||
for item in self.healthy_deployments:
|
for item in self.healthy_deployments:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue