fix(proxy_server.py): fix pydantic version errors

This commit is contained in:
Krrish Dholakia 2023-12-09 12:09:49 -08:00
parent 0294e1119e
commit ed50522863
5 changed files with 36 additions and 21 deletions

View file

@@ -3,7 +3,7 @@ repos:
rev: 3.8.4 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
additional_dependencies: [flake8-print]
files: litellm/.*\.py
- repo: local

View file

@@ -84,8 +84,12 @@ class GenerateKeyRequest(BaseModel):
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
def json(self, **kwargs) -> str:
return json.dumps(self.dict(), **kwargs)
def json(self, **kwargs):
    """Serialize this model to a JSON string.

    Supports both pydantic v2 (``model_dump``) and v1 (``dict``).
    Always returns a ``str``, preserving the original ``json() -> str``
    contract; the previous v2 branch returned a dict and dropped
    ``**kwargs``.
    """
    try:
        # pydantic v2: model_dump() replaces the removed .dict()
        return json.dumps(self.model_dump(), **kwargs)  # noqa
    except AttributeError:
        # pydantic v1 fallback
        return json.dumps(self.dict(), **kwargs)
class GenerateKeyResponse(BaseModel):
key: str

View file

@@ -4,44 +4,48 @@ import inspect
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
def print_verbose(print_statement):
    """Emit *print_statement* only when litellm verbose mode is enabled."""
    if not litellm.set_verbose:
        return
    print(print_statement)  # noqa
class MyCustomHandler(CustomLogger):
    """Example custom callback handler for LiteLLM Proxy.

    Each hook routes its message through print_verbose so output only
    appears when litellm verbose mode is on.
    """

    def __init__(self):
        blue_color_code = "\033[94m"
        reset_color_code = "\033[0m"
        print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
        try:
            print_verbose(f"Logger Initialized with following methods:")
            methods = [
                method
                for method in dir(self)
                if inspect.ismethod(getattr(self, method))
            ]
            # Pretty print_verbose the methods
            for method in methods:
                print_verbose(f" - {method}")
            print_verbose(f"{reset_color_code}")
        except Exception:
            # introspection is best-effort; never block handler creation
            pass

    def log_pre_api_call(self, model, messages, kwargs):
        print_verbose(f"Pre-API Call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        print_verbose(f"Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose("On Success!")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose(f"On Async Success!")
        return

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print_verbose(f"On Async Failure !")
        except Exception as e:
            print_verbose(f"Exception: {e}")
proxy_handler_instance = MyCustomHandler()

View file

@@ -26,7 +26,7 @@ def run_ollama_serve():
except Exception as e:
print(f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""")
""") # noqa
def clone_subfolder(repo_url, subfolder, destination):
# Clone the full repo
@@ -109,9 +109,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
# get n recent logs
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
print(json.dumps(recent_logs, indent=4))
print(json.dumps(recent_logs, indent=4)) # noqa
except:
print("LiteLLM: No logs saved!")
raise Exception("LiteLLM: No logs saved!")
return
if model and "ollama" in model:
run_ollama_serve()
@@ -140,7 +140,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
if status == "finished":
llm_response = polling_response["result"]
break
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
time.sleep(0.5)
except Exception as e:
print("got exception in polling", e)

View file

@@ -227,6 +227,13 @@ def _get_bearer_token(api_key: str):
api_key = api_key.replace("Bearer ", "") # extract the token
return api_key
def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
try:
return pydantic_obj.model_dump() # noqa
except:
# if using pydantic v1
return pydantic_obj.dict()
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth
try:
@@ -275,7 +282,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
api_key = valid_token.token
valid_token_dict = valid_token.model_dump()
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
@@ -286,7 +293,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
api_key = valid_token.token
valid_token_dict = valid_token.model_dump()
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token)
else: