(feat)proxy: read config per request

Author: ishaan-jaff
Date: 2023-11-21 16:26:05 -08:00
parent 6117bcb19f
commit 8c98a2c899


@@ -181,6 +181,17 @@ async def user_api_key_auth(request: Request):
         )
         if valid_token:
             litellm.model_alias_map = valid_token.aliases
+            config = valid_token.config
+            if config != {}:
+                global llm_router
+                model_list = config.get("model_list", [])
+                if llm_router == None:
+                    llm_router = litellm.Router(
+                        model_list=model_list
+                    )
+                else:
+                    llm_router.model_list = model_list
+                print("\n new llm router model list", llm_router.model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                 return
             else:
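
The shape this hunk expects in valid_token.config mirrors the model_list section of the proxy's config.yaml. A minimal sketch of such a per-key config and the router build/swap it triggers (the model names and litellm_params values are illustrative, not from this commit):

import litellm

# Hypothetical per-key config; the hunk above only reads its "model_list" key.
config = {
    "model_list": [
        {
            "model_name": "gpt-3.5-turbo",                 # alias callers send
            "litellm_params": {"model": "gpt-3.5-turbo"},  # model actually invoked
        }
    ]
}

llm_router = None  # stands in for the module-level global in proxy_server.py

model_list = config.get("model_list", [])
if llm_router is None:                                  # first authed request: build the router
    llm_router = litellm.Router(model_list=model_list)
else:                                                   # later requests: swap the list in place
    llm_router.model_list = model_list

Note that because llm_router is a process-wide global, the config attached to whichever key authenticated last wins for every request in flight.
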
@@ -204,6 +215,7 @@ def prisma_setup(database_url: Optional[str]):
     global prisma_client
     if database_url:
         import os
+        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         os.environ["DATABASE_URL"] = database_url
         subprocess.run(['pip', 'install', 'prisma'])
         subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
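
For orientation, prisma_setup is invoked roughly like this at startup (a sketch; only the function's signature appears in this diff, so where the URL comes from, environment variable or config.yaml, is an assumption):

import os

# prisma_setup as defined above; no-op when database_url is unset,
# per the `if database_url:` guard.
database_url = os.environ.get("DATABASE_URL")  # or a value parsed out of config.yaml
prisma_setup(database_url=database_url)
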
@@ -277,13 +289,13 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         for model in model_list:
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
-            print(f"litellm_model_name: {litellm_model_name}")
+            # print(f"litellm_model_name: {litellm_model_name}")
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
     return router, model_list, server_settings
 
 
-async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict):
+async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
     token = f"sk-{secrets.token_urlsafe(16)}"
     def _duration_in_seconds(duration: str):
         match = re.match(r"(\d+)([smhd]?)", duration)
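
Only the regex line of _duration_in_seconds shows up in this hunk's context. A plausible completion of the parser built around it (the unit table and the error handling are assumptions, not code from the commit):

import re

def _duration_in_seconds(duration: str) -> int:
    # "30s" -> 30, "15m" -> 900, "1h" -> 3600, "7d" -> 604800
    match = re.match(r"(\d+)([smhd]?)", duration)
    if match is None:
        raise ValueError(f"invalid duration string: {duration}")
    value, unit = int(match.group(1)), match.group(2)
    seconds_per_unit = {"": 1, "s": 1, "m": 60, "h": 3600, "d": 86400}  # assumed mapping
    return value * seconds_per_unit[unit]
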
@@ -307,6 +319,7 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict)
     duration = _duration_in_seconds(duration=duration_str)
     expires = datetime.utcnow() + timedelta(seconds=duration)
     aliases_json = json.dumps(aliases)
+    config_json = json.dumps(config)
     try:
         db = prisma_client
         # Create a new verification token (you may want to enhance this logic based on your needs)
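
Both dicts get flattened to JSON strings before they hit the database. A quick round-trip sketch (illustrative values) of what that implies for the read side:

import json

config = {"model_list": [{"model_name": "gpt-3.5-turbo",
                          "litellm_params": {"model": "gpt-3.5-turbo"}}]}
config_json = json.dumps(config)           # what generate_key_helper_fn persists
assert json.loads(config_json) == config   # what user_api_key_auth needs back as a dict

The `config != {}` check in the first hunk only works if whatever reads the column undoes this json.dumps, handing the auth path a dict rather than a string.
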
@@ -314,7 +327,8 @@
             "token": token,
             "expires": expires,
             "models": models,
-            "aliases": aliases_json
+            "aliases": aliases_json,
+            "config": config_json
         }
         print(f"verification_token_data: {verification_token_data}")
         new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
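
Concretely, the payload handed to db.litellm_verificationtoken.create now looks something like this (token and expiry values are made up; note that aliases and config arrive as strings, not dicts):

from datetime import datetime

verification_token_data = {
    "token": "sk-...",                             # secrets.token_urlsafe(16) output, redacted
    "expires": datetime(2023, 11, 21, 17, 26, 5),  # utcnow() plus the parsed duration
    "models": [],                                  # empty list = key may call all models
    "aliases": "{}",                               # aliases_json
    "config": '{"model_list": []}',                # config_json
}

Persisting the new field also presupposes a config column on the verification-token model; that schema change is not part of this diff.
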
@@ -571,8 +585,9 @@ async def generate_key_fn(request: Request):
         duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
         models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
         aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
+        config = data.get("config", {})
         if isinstance(models, list):
-            response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases)
+            response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config)
             return {"key": response["token"], "expires": response["expires"]}
         else:
             raise HTTPException(
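
End to end, a key carrying a per-request router config can now be minted in one call. A client-side sketch, assuming generate_key_fn is mounted at /key/generate on a locally running proxy (the route, address, and master-key header are assumptions; none appear in this diff):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",                  # assumed route and address
    headers={"Authorization": "Bearer sk-master-key"},   # placeholder master key
    json={
        "duration": "1h",   # parsed by _duration_in_seconds
        "models": [],       # empty list = allow all models
        "aliases": {},
        "config": {         # loaded into llm_router on each authenticated request
            "model_list": [
                {"model_name": "gpt-3.5-turbo",
                 "litellm_params": {"model": "gpt-3.5-turbo"}}
            ]
        },
    },
)
print(resp.json())  # {"key": "sk-...", "expires": "..."}
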
@@ -595,7 +610,7 @@ async def async_chat_completions(request: Request):
             or data["model"] # default passed in http request
         )
         data["call_type"] = "chat_completion"
-        data["llm_router"] = llm_router
+        data["llm_router"] = llm_router # this is dynamic - we should load the llm_router from the user_api_key_auth
         job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
         return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
     pass
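
The return shape above suggests a simple poll loop on the client. A sketch (the enqueue route and the poll endpoint's response body are assumptions; only the enqueue response shape comes from this diff):

import time
import requests

base = "http://0.0.0.0:8000"  # assumed proxy address
job = requests.post(
    f"{base}/queue/request",  # assumed route for async_chat_completions
    json={"model": "gpt-3.5-turbo",
          "messages": [{"role": "user", "content": "hi"}]},
).json()  # {"id": ..., "url": "/queue/response/<id>", "eta": 5, "status": "queued"}

while True:
    result = requests.get(base + job["url"]).json()
    if result.get("status") != "queued":   # assumed: status flips once the job finishes
        print(result)
        break
    time.sleep(job.get("eta", 5))          # back off by the server's eta hint
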