forked from phoenix/litellm-mirror
fix(proxy_server.py): fix pydantic version errors
This commit is contained in:
parent
0294e1119e
commit
ed50522863
5 changed files with 36 additions and 21 deletions
|
@ -3,7 +3,7 @@ repos:
|
||||||
rev: 3.8.4 # The version of flake8 to use
|
rev: 3.8.4 # The version of flake8 to use
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
|
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/proxy/proxy_cli.py|^litellm/integrations/
|
||||||
additional_dependencies: [flake8-print]
|
additional_dependencies: [flake8-print]
|
||||||
files: litellm/.*\.py
|
files: litellm/.*\.py
|
||||||
- repo: local
|
- repo: local
|
||||||
|
|
|
@ -84,8 +84,12 @@ class GenerateKeyRequest(BaseModel):
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
|
|
||||||
def json(self, **kwargs) -> str:
|
def json(self, **kwargs):
|
||||||
return json.dumps(self.dict(), **kwargs)
|
try:
|
||||||
|
return self.model_dump() # noqa
|
||||||
|
except:
|
||||||
|
# if using pydantic v1
|
||||||
|
return json.dumps(self.dict(), **kwargs)
|
||||||
|
|
||||||
class GenerateKeyResponse(BaseModel):
|
class GenerateKeyResponse(BaseModel):
|
||||||
key: str
|
key: str
|
||||||
|
|
|
@ -4,44 +4,48 @@ import inspect
|
||||||
|
|
||||||
# This file includes the custom callbacks for LiteLLM Proxy
|
# This file includes the custom callbacks for LiteLLM Proxy
|
||||||
# Once defined, these can be passed in proxy_config.yaml
|
# Once defined, these can be passed in proxy_config.yaml
|
||||||
|
def print_verbose(print_statement):
    """Write ``print_statement`` to stdout when ``litellm.set_verbose`` is truthy."""
    # Guard clause: stay silent unless verbose output has been switched on.
    if not litellm.set_verbose:
        return
    print(print_statement)  # noqa
|
||||||
|
|
||||||
class MyCustomHandler(CustomLogger):
|
class MyCustomHandler(CustomLogger):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
blue_color_code = "\033[94m"
|
blue_color_code = "\033[94m"
|
||||||
reset_color_code = "\033[0m"
|
reset_color_code = "\033[0m"
|
||||||
print(f"{blue_color_code}Initialized LiteLLM custom logger")
|
print_verbose(f"{blue_color_code}Initialized LiteLLM custom logger")
|
||||||
try:
|
try:
|
||||||
print(f"Logger Initialized with following methods:")
|
print_verbose(f"Logger Initialized with following methods:")
|
||||||
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
|
methods = [method for method in dir(self) if inspect.ismethod(getattr(self, method))]
|
||||||
|
|
||||||
# Pretty print the methods
|
# Pretty print the methods
|
||||||
for method in methods:
|
for method in methods:
|
||||||
print(f" - {method}")
|
print_verbose(f" - {method}")
|
||||||
print(f"{reset_color_code}")
|
print_verbose(f"{reset_color_code}")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def log_pre_api_call(self, model, messages, kwargs):
|
def log_pre_api_call(self, model, messages, kwargs):
|
||||||
print(f"Pre-API Call")
|
print_verbose(f"Pre-API Call")
|
||||||
|
|
||||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||||
print(f"Post-API Call")
|
print_verbose(f"Post-API Call")
|
||||||
|
|
||||||
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
def log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
print(f"On Stream")
|
print_verbose(f"On Stream")
|
||||||
|
|
||||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
print("On Success!")
|
print_verbose("On Success!")
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
print(f"On Async Success!")
|
print_verbose(f"On Async Success!")
|
||||||
return
|
return
|
||||||
|
|
||||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
print(f"On Async Failure !")
|
print_verbose(f"On Async Failure !")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception: {e}")
|
print_verbose(f"Exception: {e}")
|
||||||
|
|
||||||
|
|
||||||
proxy_handler_instance = MyCustomHandler()
|
proxy_handler_instance = MyCustomHandler()
|
||||||
|
|
|
@ -26,7 +26,7 @@ def run_ollama_serve():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"""
|
print(f"""
|
||||||
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
|
||||||
""")
|
""") # noqa
|
||||||
|
|
||||||
def clone_subfolder(repo_url, subfolder, destination):
|
def clone_subfolder(repo_url, subfolder, destination):
|
||||||
# Clone the full repo
|
# Clone the full repo
|
||||||
|
@ -109,9 +109,9 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
||||||
# get n recent logs
|
# get n recent logs
|
||||||
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
|
recent_logs = {k.strftime("%Y%m%d%H%M%S%f"): v for k, v in sorted_times[:logs]}
|
||||||
|
|
||||||
print(json.dumps(recent_logs, indent=4))
|
print(json.dumps(recent_logs, indent=4)) # noqa
|
||||||
except:
|
except:
|
||||||
print("LiteLLM: No logs saved!")
|
raise Exception("LiteLLM: No logs saved!")
|
||||||
return
|
return
|
||||||
if model and "ollama" in model:
|
if model and "ollama" in model:
|
||||||
run_ollama_serve()
|
run_ollama_serve()
|
||||||
|
@ -140,7 +140,7 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
||||||
if status == "finished":
|
if status == "finished":
|
||||||
llm_response = polling_response["result"]
|
llm_response = polling_response["result"]
|
||||||
break
|
break
|
||||||
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}")
|
print(f"POLLING JOB{polling_url}\nSTATUS: {status}, \n Response {polling_response}") # noqa
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("got exception in polling", e)
|
print("got exception in polling", e)
|
||||||
|
|
|
@ -227,6 +227,13 @@ def _get_bearer_token(api_key: str):
|
||||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
    """Return the contents of ``pydantic_obj`` as a plain dict.

    Uses ``model_dump()`` when the object provides it and otherwise falls
    back to ``dict()`` (the pydantic v1 spelling).

    Args:
        pydantic_obj: The model instance to convert.

    Returns:
        Whatever ``model_dump()`` (or, on the fallback path, ``dict()``)
        returns for the object.

    Raises:
        Any exception raised by the selected dump method. Errors from
        ``model_dump()`` are deliberately NOT swallowed: only the absence of
        the method triggers the ``dict()`` fallback.
    """
    # Feature-detect instead of wrapping the call in a bare ``except:``, which
    # would also hide genuine failures (and KeyboardInterrupt/SystemExit)
    # raised while dumping.
    model_dump = getattr(pydantic_obj, "model_dump", None)
    if callable(model_dump):
        return model_dump()  # noqa
    # if using pydantic v1
    return pydantic_obj.dict()
|
||||||
|
|
||||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
|
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
|
||||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||||
try:
|
try:
|
||||||
|
@ -275,7 +282,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
print("\n new llm router model list", llm_model_list)
|
print("\n new llm router model list", llm_model_list)
|
||||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
valid_token_dict = valid_token.model_dump()
|
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||||
valid_token_dict.pop("token", None)
|
valid_token_dict.pop("token", None)
|
||||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||||
else:
|
else:
|
||||||
|
@ -286,7 +293,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
if model and model not in valid_token.models:
|
if model and model not in valid_token.models:
|
||||||
raise Exception(f"Token not allowed to access model")
|
raise Exception(f"Token not allowed to access model")
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
valid_token_dict = valid_token.model_dump()
|
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||||
valid_token.pop("token", None)
|
valid_token.pop("token", None)
|
||||||
return UserAPIKeyAuth(api_key=api_key, **valid_token)
|
return UserAPIKeyAuth(api_key=api_key, **valid_token)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue