diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 38bb894dd..68709f34d 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
@@ -65,15 +65,36 @@ class ModelInfo(BaseModel):
     class Config:
         extra = Extra.allow # Allow extra fields
         protected_namespaces = ()
+
+    # @root_validator(pre=True)
+    # def set_model_info(cls, values):
+    #     if values.get("id") is None:
+    #         values.update({"id": str(uuid.uuid4())})
+    #     if values.get("mode") is None:
+    #         values.update({"mode": str(uuid.uuid4())})
+    #     return values
+

 class ModelParams(BaseModel):
     model_name: str
     litellm_params: dict
     model_info: Optional[ModelInfo]=None
-    
+
+    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
+    #     self.model_name = model_name
+    #     self.litellm_params = litellm_params
+    #     self.model_info = model_info if model_info else ModelInfo()
+    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
+
     class Config:
         protected_namespaces = ()
+
+    # @root_validator(pre=True)
+    # def set_model_info(cls, values):
+    #     if values.get("model_info") is None:
+    #         values.update({"model_info": ModelInfo()})
+    #     return values


 class GenerateKeyRequest(BaseModel):
     duration: Optional[str] = "1h"
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 5037a3719..a82723051 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -294,7 +294,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             api_key = valid_token.token
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
-            return UserAPIKeyAuth(api_key=api_key, **valid_token)
+            return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
@@ -1224,6 +1224,7 @@ async def add_new_model(model_params: ModelParams):
         return {"message": "Model added successfully"}

     except Exception as e:
+        traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")

 #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
diff --git a/litellm/tests/test_proxy_server_keys.py b/litellm/tests/test_proxy_server_keys.py
index 239442b2c..db083c30c 100644
--- a/litellm/tests/test_proxy_server_keys.py
+++ b/litellm/tests/test_proxy_server_keys.py
@@ -44,7 +44,7 @@ def test_add_new_key(client):
     try:
         # Your test data
         test_data = {
-            "models": ["gpt-3.5-turbo", "gpt-4", "claude-2"],
+            "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
             "aliases": {"mistral-7b": "gpt-3.5-turbo"},
             "duration": "20m"
         }
@@ -60,6 +60,11 @@ def test_add_new_key(client):
         assert response.status_code == 200
         result = response.json()
         assert result["key"].startswith("sk-")
+        def _post_data():
+            json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
+            response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
+            return response
+        _post_data()
         print(f"Received response: {result}")
     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
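
Note on the `user_api_key_auth` fix: `UserAPIKeyAuth(api_key=api_key, **valid_token)` tries to `**`-unpack a Pydantic model, which raises a `TypeError` because `BaseModel` does not implement the mapping protocol; the function already builds `valid_token_dict` (with `token` popped) and should unpack that instead. The patch does not show `_get_pydantic_json_dict` or the token record's fields, so the names below are assumed stand-ins; a minimal sketch of the failure and the fix:

```python
import json
from typing import List, Optional
from pydantic import BaseModel


class StoredToken(BaseModel):
    # Hypothetical stand-in for the DB token record; field names assumed.
    token: Optional[str] = None
    models: List[str] = []


class UserAPIKeyAuth(BaseModel):
    api_key: Optional[str] = None
    models: List[str] = []


def _get_pydantic_json_dict(model: BaseModel) -> dict:
    # Assumed behavior of the helper used in proxy_server.py: round-trip
    # through JSON so every value is a plain JSON-safe type (pydantic v1 API).
    return json.loads(model.json())


valid_token = StoredToken(token="sk-1234", models=["azure-model"])
api_key = valid_token.token

valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)

# Before the fix: **valid_token unpacks the model itself; BaseModel is not
# a mapping, so this raised a TypeError at request time.
# UserAPIKeyAuth(api_key=api_key, **valid_token)

# After the fix: unpack the plain dict, with the already-consumed "token"
# key removed so it cannot collide with the api_key keyword.
auth = UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
print(auth.api_key, auth.models)  # sk-1234 ['azure-model']
```

The new `_post_data` call in `test_add_new_key` exercises exactly this path: it sends a `/chat/completions` request with the freshly generated key, which forces `user_api_key_auth` to look the token up and rebuild `UserAPIKeyAuth` from it.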
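The commented-out blocks in `_types.py` (including the `__init__` draft, whose `super.__init__` is missing the call parentheses) all sketch the same goal: fill `model_info`, and its `id`, with defaults before validation. If they were enabled, the pydantic v1 pattern would look roughly like the sketch below; the `"completion"` default for `mode` is an assumption, since the draft's `str(uuid.uuid4())` for `mode` reads like a copy-paste placeholder:

```python
# Sketch of the pre-validation default pattern the commented-out code is
# exploring (pydantic v1 root_validator, matching the import added above).
# Field sets are trimmed to the ones the draft touches.
import uuid
from typing import Optional
from pydantic import BaseModel, root_validator


class ModelInfo(BaseModel):
    id: Optional[str] = None
    mode: Optional[str] = None

    @root_validator(pre=True)
    def set_defaults(cls, values):
        # pre=True runs on the raw input dict, before field validation,
        # so absent or None fields can be filled here.
        if values.get("id") is None:
            values["id"] = str(uuid.uuid4())
        if values.get("mode") is None:
            values["mode"] = "completion"  # assumed; draft used a UUID here
        return values


class ModelParams(BaseModel):
    model_name: str
    litellm_params: dict
    model_info: Optional[ModelInfo] = None

    @root_validator(pre=True)
    def set_model_info(cls, values):
        if values.get("model_info") is None:
            values["model_info"] = ModelInfo()
        return values


params = ModelParams(model_name="azure-model", litellm_params={})
print(params.model_info.id)  # auto-generated UUID
```

Compared with overriding `__init__`, the `pre=True` validator keeps Pydantic's normal construction path and also replaces an explicit `model_info=None` in the request body, which a plain field default would not.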