forked from phoenix/litellm-mirror
feat(proxy_server.py): add new /key/update
endpoint
This commit is contained in:
parent
23a4ac724b
commit
693292a64c
3 changed files with 74 additions and 3 deletions
|
@ -121,11 +121,24 @@ class GenerateKeyRequest(LiteLLMBase):
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
|
|
||||||
|
class UpdateKeyRequest(LiteLLMBase):
|
||||||
|
key: str
|
||||||
|
duration: Optional[str] = None
|
||||||
|
models: Optional[list] = None
|
||||||
|
aliases: Optional[dict] = None
|
||||||
|
config: Optional[dict] = None
|
||||||
|
spend: Optional[float] = None
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
max_parallel_requests: Optional[int] = None
|
||||||
|
|
||||||
class GenerateKeyResponse(LiteLLMBase):
|
class GenerateKeyResponse(LiteLLMBase):
|
||||||
key: str
|
key: str
|
||||||
expires: datetime
|
expires: datetime
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class _DeleteKeyObject(LiteLLMBase):
|
class _DeleteKeyObject(LiteLLMBase):
|
||||||
key: str
|
key: str
|
||||||
|
|
||||||
|
|
|
@ -252,15 +252,15 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
if api_key is None: # only require api key if master key is set
|
if api_key is None: # only require api key if master key is set
|
||||||
raise Exception(f"No api key passed in.")
|
raise Exception(f"No api key passed in.")
|
||||||
|
|
||||||
route = request.url.path
|
route: str = request.url.path
|
||||||
|
|
||||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||||
is_master_key_valid = secrets.compare_digest(api_key, master_key)
|
is_master_key_valid = secrets.compare_digest(api_key, master_key)
|
||||||
if is_master_key_valid:
|
if is_master_key_valid:
|
||||||
return UserAPIKeyAuth(api_key=master_key)
|
return UserAPIKeyAuth(api_key=master_key)
|
||||||
|
|
||||||
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
|
if route.startswith("/key/") and not is_master_key_valid:
|
||||||
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
|
raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
|
||||||
|
|
||||||
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||||
raise Exception("No connected db.")
|
raise Exception("No connected db.")
|
||||||
|
@ -676,6 +676,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
|
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_verification_token(tokens: List):
|
async def delete_verification_token(tokens: List):
|
||||||
global prisma_client
|
global prisma_client
|
||||||
try:
|
try:
|
||||||
|
@ -1140,6 +1142,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
|
||||||
response = await generate_key_helper_fn(**data_json)
|
response = await generate_key_helper_fn(**data_json)
|
||||||
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
|
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
|
||||||
|
|
||||||
|
@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||||
|
async def update_key_fn(request: Request, data: UpdateKeyRequest):
|
||||||
|
"""
|
||||||
|
Update an existing key
|
||||||
|
"""
|
||||||
|
global prisma_client
|
||||||
|
try:
|
||||||
|
data_json: dict = data.json()
|
||||||
|
key = data_json.pop("key")
|
||||||
|
# get the row from db
|
||||||
|
if prisma_client is None:
|
||||||
|
raise Exception("Not connected to DB!")
|
||||||
|
|
||||||
|
non_default_values = {k: v for k, v in data_json.items() if v is not None}
|
||||||
|
print(f"non_default_values: {non_default_values}")
|
||||||
|
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
|
||||||
|
return {"key": key, **non_default_values}
|
||||||
|
# update based on remaining passed in values
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={"error": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -71,6 +71,38 @@ def test_add_new_key(client):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
|
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_new_key(client):
|
||||||
|
try:
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
|
||||||
|
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||||
|
"duration": "20m"
|
||||||
|
}
|
||||||
|
print("testing proxy server")
|
||||||
|
# Your bearer token
|
||||||
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}"
|
||||||
|
}
|
||||||
|
response = client.post("/key/generate", json=test_data, headers=headers)
|
||||||
|
print(f"response: {response.text}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = response.json()
|
||||||
|
assert result["key"].startswith("sk-")
|
||||||
|
def _post_data():
|
||||||
|
json_data = {'models': ['bedrock-models'], "key": result["key"]}
|
||||||
|
response = client.post("/key/update", json=json_data, headers=headers)
|
||||||
|
print(f"response text: {response.text}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
return response
|
||||||
|
_post_data()
|
||||||
|
print(f"Received response: {result}")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
|
||||||
|
|
||||||
# # Run the test - only runs via pytest
|
# # Run the test - only runs via pytest
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue