forked from phoenix/litellm-mirror
(feat) proxy: read config per request
This commit is contained in:
parent
6117bcb19f
commit
8c98a2c899
1 changed file with 20 additions and 5 deletions
|
@ -181,6 +181,17 @@ async def user_api_key_auth(request: Request):
|
||||||
)
|
)
|
||||||
if valid_token:
|
if valid_token:
|
||||||
litellm.model_alias_map = valid_token.aliases
|
litellm.model_alias_map = valid_token.aliases
|
||||||
|
config = valid_token.config
|
||||||
|
if config != {}:
|
||||||
|
global llm_router
|
||||||
|
model_list = config.get("model_list", [])
|
||||||
|
if llm_router == None:
|
||||||
|
llm_router = litellm.Router(
|
||||||
|
model_list=model_list
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
llm_router.model_list = model_list
|
||||||
|
print("\n new llm router model list", llm_router.model_list)
|
||||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
@ -204,6 +215,7 @@ def prisma_setup(database_url: Optional[str]):
|
||||||
global prisma_client
|
global prisma_client
|
||||||
if database_url:
|
if database_url:
|
||||||
import os
|
import os
|
||||||
|
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||||
os.environ["DATABASE_URL"] = database_url
|
os.environ["DATABASE_URL"] = database_url
|
||||||
subprocess.run(['pip', 'install', 'prisma'])
|
subprocess.run(['pip', 'install', 'prisma'])
|
||||||
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
|
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
|
||||||
|
@ -277,13 +289,13 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
for model in model_list:
|
for model in model_list:
|
||||||
print(f"\033[32m {model.get('model_name', '')}\033[0m")
|
print(f"\033[32m {model.get('model_name', '')}\033[0m")
|
||||||
litellm_model_name = model["litellm_params"]["model"]
|
litellm_model_name = model["litellm_params"]["model"]
|
||||||
print(f"litellm_model_name: {litellm_model_name}")
|
# print(f"litellm_model_name: {litellm_model_name}")
|
||||||
if "ollama" in litellm_model_name:
|
if "ollama" in litellm_model_name:
|
||||||
run_ollama_serve()
|
run_ollama_serve()
|
||||||
|
|
||||||
return router, model_list, server_settings
|
return router, model_list, server_settings
|
||||||
|
|
||||||
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict):
|
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
|
||||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||||
def _duration_in_seconds(duration: str):
|
def _duration_in_seconds(duration: str):
|
||||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||||
|
@ -307,6 +319,7 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict)
|
||||||
duration = _duration_in_seconds(duration=duration_str)
|
duration = _duration_in_seconds(duration=duration_str)
|
||||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
expires = datetime.utcnow() + timedelta(seconds=duration)
|
||||||
aliases_json = json.dumps(aliases)
|
aliases_json = json.dumps(aliases)
|
||||||
|
config_json = json.dumps(config)
|
||||||
try:
|
try:
|
||||||
db = prisma_client
|
db = prisma_client
|
||||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||||
|
@ -314,7 +327,8 @@ async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict)
|
||||||
"token": token,
|
"token": token,
|
||||||
"expires": expires,
|
"expires": expires,
|
||||||
"models": models,
|
"models": models,
|
||||||
"aliases": aliases_json
|
"aliases": aliases_json,
|
||||||
|
"config": config_json
|
||||||
}
|
}
|
||||||
print(f"verification_token_data: {verification_token_data}")
|
print(f"verification_token_data: {verification_token_data}")
|
||||||
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
|
new_verification_token = await db.litellm_verificationtoken.create( # type: ignore
|
||||||
|
@ -571,8 +585,9 @@ async def generate_key_fn(request: Request):
|
||||||
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
duration_str = data.get("duration", "1h") # Default to 1 hour if duration is not provided
|
||||||
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
|
models = data.get("models", []) # Default to an empty list (meaning allow token to call all models)
|
||||||
aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
aliases = data.get("aliases", {}) # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||||
|
config = data.get("config", {})
|
||||||
if isinstance(models, list):
|
if isinstance(models, list):
|
||||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases)
|
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config)
|
||||||
return {"key": response["token"], "expires": response["expires"]}
|
return {"key": response["token"], "expires": response["expires"]}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -595,7 +610,7 @@ async def async_chat_completions(request: Request):
|
||||||
or data["model"] # default passed in http request
|
or data["model"] # default passed in http request
|
||||||
)
|
)
|
||||||
data["call_type"] = "chat_completion"
|
data["call_type"] = "chat_completion"
|
||||||
data["llm_router"] = llm_router
|
data["llm_router"] = llm_router # this is dynamic - we should load the llm_router from the user_api_key_auth
|
||||||
job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
|
job = request_queue.enqueue(litellm.litellm_queue_completion, **data)
|
||||||
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
|
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue