From 5a4c054eef8beef468cf77481bc8f2deb32c653d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 2 Dec 2023 14:15:38 -0800
Subject: [PATCH] fix(proxy_server.py): add testing for model info being added
 to /model/new

---
 .gitignore                         |  1 +
 litellm/proxy/proxy_server.py      | 19 +++++++++----------
 litellm/tests/test_proxy_server.py | 27 +++++++++++++++++++++++++++
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/.gitignore b/.gitignore
index befb64508..088996ddd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,4 @@ litellm/proxy/_secret_config.yaml
 .aws-sam/
 litellm/tests/aiologs.log
 litellm/tests/exception_data.txt
+litellm/tests/config_*.yaml
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index a197346b9..a852009bb 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -196,7 +196,7 @@ class ProxyChatCompletionRequest(BaseModel):
 class ModelParams(BaseModel):
     model_name: str
     litellm_params: dict
-    model_info: dict
+    model_info: Optional[dict]
 
 user_api_base = None
 user_model = None
@@ -207,7 +207,7 @@ user_temperature = None
 user_telemetry = True
 user_config = None
 user_headers = None
-user_config_file_path = None
+user_config_file_path = f"config_{time.time()}.yaml"
 local_logging = True # writes logs to a local api_log.json file for debugging
 experimental = False
 #### GLOBAL VARIABLES ####
@@ -606,10 +606,6 @@ async def delete_verification_token(tokens: List[str]):
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return deleted_tokens
 
-async def generate_key_cli_task(duration_str):
-    task = asyncio.create_task(generate_key_helper_fn(duration_str=duration_str))
-    await task
-
 def save_worker_config(**data):
     import json
     os.environ["WORKER_CONFIG"] = json.dumps(data)
@@ -1011,13 +1007,16 @@ async def add_new_model(model_params: ModelParams):
     global llm_router, llm_model_list, general_settings, user_config_file_path
     try:
         # Load existing config
-        with open(f"{user_config_file_path}", "r") as config_file:
-            config = yaml.safe_load(config_file)
-
+        if os.path.exists(f"{user_config_file_path}"):
+            with open(f"{user_config_file_path}", "r") as config_file:
+                config = yaml.safe_load(config_file)
+        else:
+            config = {"model_list": []}
         # Add the new model to the config
         config['model_list'].append({
             'model_name': model_params.model_name,
-            'litellm_params': model_params.litellm_params
+            'litellm_params': model_params.litellm_params,
+            'model_info': model_params.model_info
         })
 
         # Save the updated config
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index 69cd8bba6..005de2762 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -93,3 +93,30 @@ def test_embedding():
 
 # Run the test
 # test_embedding()
+
+
+def test_add_new_model():
+    try:
+        test_data = {
+            "model_name": "test_openai_models",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+            },
+            "model_info": {
+                "description": "this is a test openai model"
+            }
+        }
+        client.post("/model/new", json=test_data)
+        response = client.get("/model/info")
+        assert response.status_code == 200
+        result = response.json()
+        print(f"response: {result}")
+        model_info = None
+        for m in result["data"]:
+            if m["id"]["model_name"] == "test_openai_models":
+                model_info = m["id"]["model_info"]
+        assert model_info["description"] == "this is a test openai model"
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
+
+test_add_new_model()
\ No newline at end of file