fix(proxy_server.py): fix pydantic version errors

Krrish Dholakia 2023-12-09 12:09:49 -08:00
parent 0294e1119e
commit ed50522863
5 changed files with 36 additions and 21 deletions

View file

@@ -3,7 +3,7 @@ repos:
     rev: 3.8.4  # The version of flake8 to use
     hooks:
     - id: flake8
-      exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
+      exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
       additional_dependencies: [flake8-print]
       files: litellm/.*\.py
 - repo: local
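
Context for this hunk: flake8-print flags bare print() calls, so proxy_cli.py joins proxy_server.py in the exclude list; the rest of the commit also uses the line-level escape hatch. A minimal sketch of that marker pattern (hypothetical function, not code from this commit):

    # flake8-print flags print() calls; a trailing "# noqa" tells flake8
    # to skip that one line. Hypothetical example, not part of this commit.
    def report_status(status: str) -> None:
        print(f"LiteLLM proxy status: {status}")  # noqa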

View file

@@ -84,7 +84,11 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs) -> str:
-        return json.dumps(self.dict(), **kwargs)
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return json.dumps(self.dict(), **kwargs)
 
 class GenerateKeyResponse(BaseModel):
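
This hunk is the heart of the pydantic fix: v2 renamed `.dict()` to `.model_dump()`, so the override now tries the v2 method and falls back to the v1 one. A self-contained sketch of the same fallback, under hypothetical names (KeyRequest and to_dict are illustrations, not code from this commit):

    from typing import Optional

    from pydantic import BaseModel


    class KeyRequest(BaseModel):
        # Hypothetical stand-in for GenerateKeyRequest.
        user_id: Optional[str] = None
        max_parallel_requests: Optional[int] = None


    def to_dict(obj: BaseModel) -> dict:
        """Serialize a pydantic model on either major version."""
        try:
            return obj.model_dump()  # pydantic v2
        except AttributeError:
            return obj.dict()  # pydantic v1 fallback


    print(to_dict(KeyRequest(user_id="abc")))
    # -> {'user_id': 'abc', 'max_parallel_requests': None}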

View file

@@ -4,44 +4,48 @@ import inspect
 
 # This file includes the custom callbacks for LiteLLM Proxy
 # Once defined, these can be passed in proxy_config.yaml
 
+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement)  # noqa
+
 class MyCustomHandler(CustomLogger):
     def __init__(self):
         blue_color_code = "\033[94m"
         reset_color_code = "\033[0m"
-        print(f"{blue_color_code}Initialized LiteLLM custom logger")
+        print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
         try:
-            print(f"Logger Initialized with following methods:")
+            print_verbose(f"Logger Initialized with following methods:")
             methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
 
-            # Pretty print the methods
+            # Pretty print_verbose the methods
             for method in methods:
-                print(f" - {method}")
+                print_verbose(f" - {method}")
 
-            print(f"{reset_color_code}")
+            print_verbose(f"{reset_color_code}")
         except:
             pass
 
     def log_pre_api_call(self, model, messages, kwargs):
-        print(f"Pre-API Call")
+        print_verbose(f"Pre-API Call")
 
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
-        print(f"Post-API Call")
+        print_verbose(f"Post-API Call")
 
     def log_stream_event(self, kwargs, response_obj, start_time, end_time):
-        print(f"On Stream")
+        print_verbose(f"On Stream")
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
-        print("On Success!")
+        print_verbose("On Success!")
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
-        print(f"On Async Success!")
+        print_verbose(f"On Async Success!")
         return
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            print(f"On Async Failure !")
+            print_verbose(f"On Async Failure !")
         except Exception as e:
-            print(f"Exception: {e}")
+            print_verbose(f"Exception: {e}")
 
 proxy_handler_instance = MyCustomHandler()
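
All console output in the callback example now routes through `print_verbose`, which gates printing on `litellm.set_verbose`, so the proxy stays quiet unless verbose mode is enabled. A minimal sketch of the gating pattern, using a local flag in place of `litellm.set_verbose`:

    # Minimal sketch of the verbosity gate; the real file checks
    # litellm.set_verbose instead of this module-level flag.
    set_verbose = False


    def print_verbose(print_statement):
        # Emit output only when verbose mode is explicitly enabled.
        if set_verbose:
            print(print_statement)  # noqa


    print_verbose("hidden")    # no output while set_verbose is False
    set_verbose = True
    print_verbose("visible")   # printed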

View file

@@ -26,7 +26,7 @@ def run_ollama_serve():
     except Exception as e:
         print(f"""
             LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
-        """)
+        """)  # noqa
 
 def clone_subfolder(repo_url, subfolder, destination):
     # Clone the full repo
@@ -109,9 +109,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
             # get n recent logs
             recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
-            print(json.dumps(recent_logs, indent=4))
+            print(json.dumps(recent_logs, indent=4))  # noqa
         except:
-            print("LiteLLM: No logs saved!")
+            raise Exception("LiteLLM: No logs saved!")
         return
 
     if model and "ollama" in model:
         run_ollama_serve()
@@ -140,7 +140,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
                 if status == "finished":
                     llm_response = polling_response["result"]
                     break
-                print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
+                print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")  # noqa
                 time.sleep(0.5)
         except Exception as e:
             print("got exception in polling", e)

View file

@@ -227,6 +227,13 @@ def _get_bearer_token(api_key: str):
         api_key = api_key.replace("Bearer ", "")  # extract the token
     return api_key
 
+def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
+    try:
+        return pydantic_obj.model_dump()  # noqa
+    except:
+        # if using pydantic v1
+        return pydantic_obj.dict()
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
@@ -275,7 +282,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             print("\n new llm router model list", llm_model_list)
         if len(valid_token.models) == 0:  # assume an empty model list means all models are allowed to be called
             api_key = valid_token.token
-            valid_token_dict = valid_token.model_dump()
+            valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
             return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
         else:
@@ -286,7 +293,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
             if model and model not in valid_token.models:
                 raise Exception(f"Token not allowed to access model")
             api_key = valid_token.token
-            valid_token_dict = valid_token.model_dump()
+            valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token.pop("token", None)
             return UserAPIKeyAuth(api_key=api_key, **valid_token)
         else:
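
The new `_get_pydantic_json_dict` helper centralizes the same v1/v2 fallback for serializing `valid_token`. For reference, an alternative to try/except probing is to branch once on pydantic's public VERSION string; this is a hedged sketch of that variant, not what the commit does:

    import pydantic

    # Branch once at import time instead of probing with try/except.
    # pydantic.VERSION is a public string in both v1 and v2.
    PYDANTIC_V2 = pydantic.VERSION.startswith("2.")


    def get_pydantic_json_dict(obj: pydantic.BaseModel) -> dict:
        # Hypothetical variant of the commit's _get_pydantic_json_dict.
        return obj.model_dump() if PYDANTIC_V2 else obj.dict()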