mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
feat(proxy_server.py): enable model aliases
This commit is contained in:
parent
356332d0da
commit
33e47dae8e
3 changed files with 16 additions and 9 deletions
|
@ -167,11 +167,14 @@ async def user_api_key_auth(request: Request):
|
|||
}
|
||||
)
|
||||
if valid_token:
|
||||
litellm.model_alias_map = valid_token.aliases
|
||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||
return
|
||||
else:
|
||||
data = await request.json()
|
||||
model = data.get("model", None)
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
if model and model not in valid_token.models:
|
||||
raise Exception(f"Token not allowed to access model")
|
||||
return
|
||||
|
@ -233,9 +236,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
if litellm_settings:
|
||||
for key, value in litellm_settings.items():
|
||||
setattr(litellm, key, value)
|
||||
print(f"key: {key}; value: {value}")
|
||||
print(f"success callbacks: {litellm.success_callback}")
|
||||
|
||||
## MODEL LIST
|
||||
model_list = config.get('model_list', None)
|
||||
if model_list:
|
||||
|
@ -246,7 +246,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
|
||||
return router, model_list, server_settings
|
||||
|
||||
async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
|
||||
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict):
|
||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
|
@ -269,18 +269,20 @@ async def generate_key_helper_fn(duration_str: str, models: Optional[list]):
|
|||
|
||||
duration = _duration_in_seconds(duration=duration_str)
|
||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
||||
aliases_json = json.dumps(aliases)
|
||||
try:
|
||||
db = prisma_client
|
||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||
verification_token_data = {
|
||||
"token": token,
|
||||
"expires": expires,
|
||||
"models": models
|
||||
"models": models,
|
||||
"aliases": aliases_json
|
||||
}
|
||||
print(f"verification_token_data: {verification_token_data}")
|
||||
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
|
||||
{**verification_token_data} # type: ignore
|
||||
)
|
||||
print(f"new_verification_token: {new_verification_token}")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
@ -530,8 +532,9 @@ async def generate_key_fn(request: Request):
|
|||
|
||||
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
||||
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
|
||||
aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||
if isinstance(models, list):
|
||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models)
|
||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases)
|
||||
return {"key": response["token"], "expires": response["expires"]}
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
@ -12,4 +12,5 @@ model LiteLLM_VerificationToken {
|
|||
token String @unique
|
||||
expires DateTime
|
||||
models String[]
|
||||
aliases Json @default("{}")
|
||||
}
|
|
@ -132,9 +132,12 @@ class Router:
|
|||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None):
|
||||
"""
|
||||
Returns the deployment with the shortest queue
|
||||
Returns the deployment based on routing strategy
|
||||
"""
|
||||
logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
|
||||
if litellm.model_alias_map and model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[
|
||||
model
|
||||
] # update the model to the actual value if an alias has been passed in
|
||||
if self.routing_strategy == "least-busy":
|
||||
if len(self.healthy_deployments) > 0:
|
||||
for item in self.healthy_deployments:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue