fix(proxy_server.py): fix key gen error

Krrish Dholakia 2023-12-09 22:04:59 -08:00
parent f10bb708c0
commit 07b4c72a98
3 changed files with 31 additions and 4 deletions


@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
@@ -65,15 +65,36 @@ class ModelInfo(BaseModel):
     class Config:
         extra = Extra.allow  # Allow extra fields
         protected_namespaces = ()
+    # @root_validator(pre=True)
+    # def set_model_info(cls, values):
+    #     if values.get("id") is None:
+    #         values.update({"id": str(uuid.uuid4())})
+    #     if values.get("mode") is None:
+    #         values.update({"mode": str(uuid.uuid4())})
+    #     return values
 class ModelParams(BaseModel):
     model_name: str
     litellm_params: dict
     model_info: Optional[ModelInfo]=None
+    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
+    #     self.model_name = model_name
+    #     self.litellm_params = litellm_params
+    #     self.model_info = model_info if model_info else ModelInfo()
+    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
     class Config:
         protected_namespaces = ()
+    # @root_validator(pre=True)
+    # def set_model_info(cls, values):
+    #     if values.get("model_info") is None:
+    #         values.update({"model_info": ModelInfo()})
+    #     return values
 class GenerateKeyRequest(BaseModel):
     duration: Optional[str] = "1h"
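
For reference, a minimal sketch of what the commented-out pre-validator above would do if it were ever enabled. This is illustrative only: it uses a Pydantic v1-style root_validator, and the stand-in model is reduced to the id and mode keys the validator touches, not the proxy's real ModelInfo.

import uuid
from typing import Optional

from pydantic import BaseModel, Extra, root_validator


class ModelInfoSketch(BaseModel):
    # illustrative stand-in for ModelInfo, reduced to the fields the validator fills
    id: Optional[str] = None
    mode: Optional[str] = None

    class Config:
        extra = Extra.allow  # Allow extra fields
        protected_namespaces = ()

    @root_validator(pre=True)
    def set_model_info(cls, values):
        # generate defaults for missing keys before field validation runs
        if values.get("id") is None:
            values.update({"id": str(uuid.uuid4())})
        if values.get("mode") is None:
            values.update({"mode": str(uuid.uuid4())})
        return values


print(ModelInfoSketch().id)  # prints a generated uuid when no id is supplied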


@@ -294,7 +294,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             api_key = valid_token.token
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
-            return UserAPIKeyAuth(api_key=api_key, **valid_token)
+            return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
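
The one-line change above appears to be the fix the commit title refers to: ** unpacking requires a mapping, so the Pydantic token object has to be converted to a plain dict before being splatted into UserAPIKeyAuth. A minimal, self-contained sketch of that failure mode, where VerificationToken and generate_auth are hypothetical stand-ins rather than the proxy's real classes:

from pydantic import BaseModel


class VerificationToken(BaseModel):
    # hypothetical stand-in for the stored token record
    token: str
    models: list = []


def generate_auth(api_key, **fields):
    # hypothetical stand-in for constructing UserAPIKeyAuth(...)
    return {"api_key": api_key, **fields}


valid_token = VerificationToken(token="sk-1234")
valid_token_dict = dict(valid_token)  # roughly what _get_pydantic_json_dict produces
valid_token_dict.pop("token", None)

generate_auth(api_key=valid_token.token, **valid_token_dict)  # works: a plain dict unpacks cleanly
# generate_auth(api_key=valid_token.token, **valid_token)     # TypeError: argument after ** must be a mapping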
@@ -1224,6 +1224,7 @@ async def add_new_model(model_params: ModelParams):
         return {"message": "Model added successfully"}
     except Exception as e:
+        traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
 #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info


@@ -44,7 +44,7 @@ def test_add_new_key(client):
     try:
         # Your test data
         test_data = {
-            "models": ["gpt-3.5-turbo", "gpt-4", "claude-2"],
+            "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
            "aliases": {"mistral-7b": "gpt-3.5-turbo"},
            "duration": "20m"
         }
@@ -60,6 +60,11 @@ def test_add_new_key(client):
         assert response.status_code == 200
         result = response.json()
         assert result["key"].startswith("sk-")
+        def _post_data():
+            json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
+            response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
+            return response
+        _post_data()
         print(f"Received response: {result}")
     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
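
The test depends on a client fixture that is not shown in this hunk. A hedged sketch of how such a fixture is commonly wired for the proxy, assuming the routes are exposed on a FastAPI router in litellm.proxy.proxy_server (the repo's actual fixture may be set up differently):

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from litellm.proxy.proxy_server import router  # assumption: proxy routes live on `router`


@pytest.fixture
def client():
    # mount the proxy routes on a fresh app and wrap it in Starlette's TestClient
    app = FastAPI()
    app.include_router(router)
    return TestClient(app)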