fix(proxy_server.py): Passing user IDs to OpenAI to identify abusive virtual keys

This commit is contained in:
Krrish Dholakia 2023-12-02 19:55:11 -08:00
parent 492c9043f6
commit 5200818af1

View file

@@ -302,9 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             llm_model_list = model_list
             print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                return {
-                    "api_key": valid_token.token
-                }
+                return_dict = {"api_key": valid_token.token}
+                if valid_token.user_id:
+                    return_dict["user_id"] = valid_token.user_id
+                return return_dict
             else:
                 data = await request.json()
                 model = data.get("model", None)
@@ -312,9 +313,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                     model = litellm.model_alias_map[model]
                 if model and model not in valid_token.models:
                     raise Exception(f"Token not allowed to access model")
-                return {
-                    "api_key": valid_token.token
-                }
+                return_dict = {"api_key": valid_token.token}
+                if valid_token.user_id:
+                    return_dict["user_id"] = valid_token.user_id
+                return return_dict
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
@@ -834,6 +836,8 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    data["user_id"] = user_api_key_dict.get("user_id", None)
     data["model"] = (
         general_settings.get("completion_model", None) # server default
         or user_model # model name passed via cli args
@@ -882,6 +886,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         or data["model"] # default passed in http request
     )
+    data["user_id"] = user_api_key_dict.get("user_id", None)
     if "metadata" in data:
         data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
     else:
@@ -943,6 +949,7 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
     body = await request.body()
     data = orjson.loads(body)
+    data["user_id"] = user_api_key_dict.get("user_id", None)
     data["model"] = (
         general_settings.get("embedding_model", None) # server default
         or user_model # model name passed via cli args