diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 51dd4eaaf..2b8235464 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -302,9 +302,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap llm_model_list = model_list print("\n new llm router model list", llm_model_list) if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called - return { - "api_key": valid_token.token - } + return_dict = {"api_key": valid_token.token} + if valid_token.user_id: + return_dict["user_id"] = valid_token.user_id + return return_dict else: data = await request.json() model = data.get("model", None) @@ -312,9 +313,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap model = litellm.model_alias_map[model] if model and model not in valid_token.models: raise Exception(f"Token not allowed to access model") - return { - "api_key": valid_token.token - } + return_dict = {"api_key": valid_token.token} + if valid_token.user_id: + return_dict["user_id"] = valid_token.user_id + return return_dict else: raise Exception(f"Invalid token") except Exception as e: @@ -834,6 +836,8 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key data = ast.literal_eval(body_str) except: data = json.loads(body_str) + + data["user_id"] = user_api_key_dict.get("user_id", None) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -882,6 +886,8 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap or data["model"] # default passed in http request ) + data["user_id"] = user_api_key_dict.get("user_id", None) + if "metadata" in data: data["metadata"]["user_api_key"] = user_api_key_dict["api_key"] else: @@ -943,6 +949,7 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap body = await request.body() data = orjson.loads(body) + data["user_id"] = user_api_key_dict.get("user_id", None) data["model"] = ( general_settings.get("embedding_model", None) # server default or user_model # model name passed via cli args