fix(proxy_server.py): fix key gen error

This commit is contained in:
Krrish Dholakia 2023-12-09 22:04:59 -08:00
parent f10bb708c0
commit 07b4c72a98
3 changed files with 31 additions and 4 deletions

View file

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, Extra, Field, root_validator
from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json
@@ -66,15 +66,36 @@ class ModelInfo(BaseModel):
extra = Extra.allow # Allow extra fields
protected_namespaces = ()
# @root_validator(pre=True)
# def set_model_info(cls, values):
# if values.get("id") is None:
# values.update({"id": str(uuid.uuid4())})
# if values.get("mode") is None:
# values.update({"mode": str(uuid.uuid4())})
# return values
class ModelParams(BaseModel):
    """Payload schema for registering a model with the proxy.

    Attributes:
        model_name: Public name clients use to address the model.
        litellm_params: Provider/call parameters forwarded to litellm
            (schema is provider-specific; presumably keys like
            ``model``, ``api_base`` — confirm against callers).
        model_info: Optional metadata about the model; defaults to None.
    """
    model_name: str
    litellm_params: dict
    model_info: Optional[ModelInfo] = None

    class Config:
        # Pydantic reserves the "model_" prefix for its own namespace;
        # clearing protected_namespaces silences warnings for the
        # "model_name" / "model_info" fields above.
        protected_namespaces = ()
class GenerateKeyRequest(BaseModel):
duration: Optional[str] = "1h"
models: Optional[list] = []

View file

@@ -294,7 +294,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
except Exception as e:
@@ -1224,6 +1224,7 @@ async def add_new_model(model_params: ModelParams):
return {"message": "Model added successfully"}
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info

View file

@@ -44,7 +44,7 @@ def test_add_new_key(client):
try:
# Your test data
test_data = {
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2"],
"models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": "20m"
}
@@ -60,6 +60,11 @@ def test_add_new_key(client):
assert response.status_code == 200
result = response.json()
assert result["key"].startswith("sk-")
def _post_data():
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
return response
_post_data()
print(f"Received response: {result}")
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")