forked from phoenix/litellm-mirror
fix(proxy_server.py): Passing user IDs to OpenAI to identify abusive virtual keys
This commit is contained in:
parent
492c9043f6
commit
5200818af1
1 changed files with 13 additions and 6 deletions
|
@ -302,9 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
llm_model_list = model_list
|
llm_model_list = model_list
|
||||||
print("\n new llm router model list", llm_model_list)
|
print("\n new llm router model list", llm_model_list)
|
||||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||||
return {
|
return_dict = {"api_key": valid_token.token}
|
||||||
"api_key": valid_token.token
|
if valid_token.user_id:
|
||||||
}
|
return_dict["user_id"] = valid_token.user_id
|
||||||
|
return return_dict
|
||||||
else:
|
else:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
model = data.get("model", None)
|
model = data.get("model", None)
|
||||||
|
@ -312,9 +313,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
model = litellm.model_alias_map[model]
|
model = litellm.model_alias_map[model]
|
||||||
if model and model not in valid_token.models:
|
if model and model not in valid_token.models:
|
||||||
raise Exception(f"Token not allowed to access model")
|
raise Exception(f"Token not allowed to access model")
|
||||||
return {
|
return_dict = {"api_key": valid_token.token}
|
||||||
"api_key": valid_token.token
|
if valid_token.user_id:
|
||||||
}
|
return_dict["user_id"] = valid_token.user_id
|
||||||
|
return return_dict
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid token")
|
raise Exception(f"Invalid token")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -834,6 +836,8 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
||||||
data = ast.literal_eval(body_str)
|
data = ast.literal_eval(body_str)
|
||||||
except:
|
except:
|
||||||
data = json.loads(body_str)
|
data = json.loads(body_str)
|
||||||
|
|
||||||
|
data["user_id"] = user_api_key_dict.get("user_id", None)
|
||||||
data["model"] = (
|
data["model"] = (
|
||||||
general_settings.get("completion_model", None) # server default
|
general_settings.get("completion_model", None) # server default
|
||||||
or user_model # model name passed via cli args
|
or user_model # model name passed via cli args
|
||||||
|
@ -882,6 +886,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
||||||
or data["model"] # default passed in http request
|
or data["model"] # default passed in http request
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data["user_id"] = user_api_key_dict.get("user_id", None)
|
||||||
|
|
||||||
if "metadata" in data:
|
if "metadata" in data:
|
||||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||||
else:
|
else:
|
||||||
|
@ -943,6 +949,7 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
data = orjson.loads(body)
|
data = orjson.loads(body)
|
||||||
|
|
||||||
|
data["user_id"] = user_api_key_dict.get("user_id", None)
|
||||||
data["model"] = (
|
data["model"] = (
|
||||||
general_settings.get("embedding_model", None) # server default
|
general_settings.get("embedding_model", None) # server default
|
||||||
or user_model # model name passed via cli args
|
or user_model # model name passed via cli args
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue