fix(proxy_server.py): don't auto-create user when creating key

This commit is contained in:
Krrish Dholakia 2024-03-27 16:48:57 -07:00
parent 9b7383ac67
commit a408c46a67
3 changed files with 41 additions and 34 deletions

View file

@ -1,21 +1,22 @@
model_list:
- model_name: fake_openai
- model_name: fake-openai-endpoint
litellm_params:
model: openai/my-fake-model
api_key: my-fake-key
api_base: http://0.0.0.0:8080
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: gpt-3.5-turbo
litellm_params:
model: gpt-3.5-turbo-1106
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
cache: true
cache_params:
type: redis
callbacks: ["batch_redis_requests"]
# litellm_settings:
# cache: true
# cache_params:
# type: redis
# callbacks: ["batch_redis_requests"]
# success_callbacks: ["langfuse"]
general_settings:
master_key: sk-1234
# disable_spend_logs: true
database_url: "postgresql://neondb_owner:hz8tyUlJ5ivV@ep-cool-sunset-a5ywubeh.us-east-2.aws.neon.tech/neondb?sslmode=require"

View file

@ -305,7 +305,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
key: str
key_name: Optional[str] = None
expires: Optional[datetime]
user_id: str
user_id: Optional[str] = None
@root_validator(pre=True)
def set_model_info(cls, values):

View file

@ -2320,8 +2320,6 @@ async def generate_key_helper_fn(
permissions_json = json.dumps(permissions)
metadata_json = json.dumps(metadata)
model_max_budget_json = json.dumps(model_max_budget)
user_id = user_id or str(uuid.uuid4())
user_role = user_role or "app_user"
tpm_limit = tpm_limit
rpm_limit = rpm_limit
@ -2409,20 +2407,23 @@ async def generate_key_helper_fn(
):
saved_token["expires"] = saved_token["expires"].isoformat()
if prisma_client is not None:
## CREATE USER (If necessary)
if query_type == "insert_data":
user_row = await prisma_client.insert_data(
data=user_data, table_name="user"
)
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
elif query_type == "update_data":
user_row = await prisma_client.update_data(
data=user_data,
table_name="user",
update_key_values=update_key_values,
)
if (
table_name is None or table_name == "user"
): # do not auto-create users for `/key/generate`
## CREATE USER (If necessary)
if query_type == "insert_data":
user_row = await prisma_client.insert_data(
data=user_data, table_name="user"
)
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
elif query_type == "update_data":
user_row = await prisma_client.update_data(
data=user_data,
table_name="user",
update_key_values=update_key_values,
)
if user_id == litellm_proxy_budget_name or (
table_name is not None and table_name == "user"
):
@ -2435,16 +2436,19 @@ async def generate_key_helper_fn(
verbose_proxy_logger.debug("prisma_client: Creating Key= %s", key_data)
await prisma_client.insert_data(data=key_data, table_name="key")
elif custom_db_client is not None:
## CREATE USER (If necessary)
verbose_proxy_logger.debug("CustomDBClient: Creating User= %s", user_data)
user_row = await custom_db_client.insert_data(
value=user_data, table_name="user"
)
if user_row is None:
# GET USER ROW
user_row = await custom_db_client.get_data(
key=user_id, table_name="user"
if table_name is None or table_name == "user":
## CREATE USER (If necessary)
verbose_proxy_logger.debug(
"CustomDBClient: Creating User= %s", user_data
)
user_row = await custom_db_client.insert_data(
value=user_data, table_name="user"
)
if user_row is None:
# GET USER ROW
user_row = await custom_db_client.get_data(
key=user_id, table_name="user" # type: ignore
)
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
@ -4051,7 +4055,7 @@ async def generate_key_fn(
if "budget_duration" in data_json:
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
response = await generate_key_helper_fn(**data_json)
response = await generate_key_helper_fn(**data_json, table_name="key")
return GenerateKeyResponse(**response)
except Exception as e:
traceback.print_exc()
@ -5042,6 +5046,8 @@ async def new_user(data: NewUserRequest):
param="user_role",
code=status.HTTP_400_BAD_REQUEST,
)
if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4())
response = await generate_key_helper_fn(**data_json)
return NewUserResponse(
key=response["token"],