From ed50522863bf1001000fffc1237dd1e50108393f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 9 Dec 2023 12:09:49 -0800
Subject: [PATCH] fix(proxy_server.py): fix pydantic version errors

---
 .pre-commit-config.yaml           |  2 +-
 litellm/proxy/_types.py           |  8 ++++++--
 litellm/proxy/custom_callbacks.py | 28 ++++++++++++++++------------
 litellm/proxy/proxy_cli.py        |  8 ++++----
 litellm/proxy/proxy_server.py     | 11 +++++++++--
 5 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9fc7bb3d3..9352959fe 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
     rev: 3.8.4  # The version of flake8 to use
     hooks:
     -   id: flake8
-        exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
+        exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
         additional_dependencies: [flake8-print]
         files: litellm/.*\.py
 -   repo: local
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index b944df3a3..628ea2379 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -84,8 +84,12 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs) -> str:
-        return json.dumps(self.dict(), **kwargs)
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return json.dumps(self.dict(), **kwargs)
 
 class GenerateKeyResponse(BaseModel):
     key: str
diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py
index 08947a066..c04916344 100644
--- a/litellm/proxy/custom_callbacks.py
+++ b/litellm/proxy/custom_callbacks.py
@@ -4,44 +4,48 @@ import inspect
 # This file includes the custom callbacks for LiteLLM Proxy
 # Once defined, these can be passed in proxy_config.yaml
 
+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa
+
 class MyCustomHandler(CustomLogger):
     def __init__(self):
         blue_color_code = "\033[94m"
         reset_color_code = "\033[0m"
-        print(f"{blue_color_code}Initialized LiteLLM custom logger")
+        print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
         try:
-            print(f"Logger Initialized with following methods:")
+            print_verbose(f"Logger Initialized with following methods:")
             methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
 
-            # Pretty print the methods
+            # Pretty print_verbose the methods
             for method in methods:
-                print(f" - {method}")
-            print(f"{reset_color_code}")
+                print_verbose(f" - {method}")
+            print_verbose(f"{reset_color_code}")
         except:
             pass
 
     def log_pre_api_call(self, model, messages, kwargs):
-        print(f"Pre-API Call")
+        print_verbose(f"Pre-API Call")
 
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
-        print(f"Post-API Call")
+        print_verbose(f"Post-API Call")
 
     def log_stream_event(self, kwargs, response_obj, start_time, end_time):
-        print(f"On Stream")
+        print_verbose(f"On Stream")
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
-        print("On Success!")
+        print_verbose("On Success!")
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
-        print(f"On Async Success!")
+        print_verbose(f"On Async Success!")
        return
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
-            print(f"On Async Failure !")
+            print_verbose(f"On Async Failure !")
        except Exception as e:
-            print(f"Exception: {e}")
+            print_verbose(f"Exception: {e}")
 
 proxy_handler_instance = MyCustomHandler()
diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py
index 7dca11dd4..57908e59a 100644
--- a/litellm/proxy/proxy_cli.py
+++ b/litellm/proxy/proxy_cli.py
@@ -26,7 +26,7 @@ def run_ollama_serve():
     except Exception as e:
         print(f"""
             LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
-        """)
+        """) # noqa
 
 def clone_subfolder(repo_url, subfolder, destination):
   # Clone the full repo
@@ -109,9 +109,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
             # get n recent logs
             recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
 
-            print(json.dumps(recent_logs, indent=4))
+            print(json.dumps(recent_logs, indent=4)) # noqa
         except:
-            print("LiteLLM: No logs saved!")
+            raise Exception("LiteLLM: No logs saved!")
         return
     if model and "ollama" in model:
         run_ollama_serve()
@@ -140,7 +140,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
                     if status == "finished":
                         llm_response = polling_response["result"]
                         break
-                    print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
+                    print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
                     time.sleep(0.5)
                 except Exception as e:
                     print("got exception in polling", e)
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b4bcbedb6..c15f7ae54 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -227,6 +227,13 @@ def _get_bearer_token(api_key: str):
         api_key = api_key.replace("Bearer ", "") # extract the token
     return api_key
 
+def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
+    try:
+        return pydantic_obj.model_dump() # noqa
+    except:
+        # if using pydantic v1
+        return pydantic_obj.dict()
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
@@ -275,7 +282,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                 print("\n new llm router model list", llm_model_list)
             if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                 api_key = valid_token.token
-                valid_token_dict = valid_token.model_dump()
+                valid_token_dict = _get_pydantic_json_dict(valid_token)
                 valid_token_dict.pop("token", None)
                 return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
             else:
@@ -286,7 +293,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                     if model and model not in valid_token.models:
                         raise Exception(f"Token not allowed to access model")
                 api_key = valid_token.token
-                valid_token_dict = valid_token.model_dump()
+                valid_token_dict = _get_pydantic_json_dict(valid_token)
                 valid_token.pop("token", None)
                 return UserAPIKeyAuth(api_key=api_key, **valid_token)
             else: