forked from phoenix/litellm-mirror
fix(proxy_server.py): Passing user IDs to OpenAI to identify abusive virtual keys
This commit is contained in:
parent
492c9043f6
commit
5200818af1
1 changed files with 13 additions and 6 deletions
|
@ -302,9 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
llm_model_list = model_list
|
||||
print("\n new llm router model list", llm_model_list)
|
||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||
return {
|
||||
"api_key": valid_token.token
|
||||
}
|
||||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return return_dict
|
||||
else:
|
||||
data = await request.json()
|
||||
model = data.get("model", None)
|
||||
|
@ -312,9 +313,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
model = litellm.model_alias_map[model]
|
||||
if model and model not in valid_token.models:
|
||||
raise Exception(f"Token not allowed to access model")
|
||||
return {
|
||||
"api_key": valid_token.token
|
||||
}
|
||||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return return_dict
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
except Exception as e:
|
||||
|
@ -834,6 +836,8 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
data = ast.literal_eval(body_str)
|
||||
except:
|
||||
data = json.loads(body_str)
|
||||
|
||||
data["user_id"] = user_api_key_dict.get("user_id", None)
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -882,6 +886,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
or data["model"] # default passed in http request
|
||||
)
|
||||
|
||||
data["user_id"] = user_api_key_dict.get("user_id", None)
|
||||
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
else:
|
||||
|
@ -943,6 +949,7 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
|
|||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
data["user_id"] = user_api_key_dict.get("user_id", None)
|
||||
data["model"] = (
|
||||
general_settings.get("embedding_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue