Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
fix(main.py): fix key leak error when unknown provider given (#8556)
* fix(main.py): fix key leak error when unknown provider given; don't return passed-in args if unknown route on embedding
* fix(main.py): remove instances of {args} being passed in exceptions, to prevent potential key leaks
* test(code_coverage/prevent_key_leaks_in_codebase.py): ban usage of {args} in codebase
* fix: fix linting errors
* fix: remove unused variable
Parent: c6026ea6f9
Commit: a9276f27f9
8 changed files with 193 additions and 30 deletions
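In practice, passing an unknown provider now surfaces as the new LiteLLMUnknownProvider error (a BadRequestError subclass) rather than a ValueError that echoed the raw call args. A minimal caller-side sketch, assuming a litellm build containing this commit; the provider name and api_key below are placeholders, and depending on how the provider string resolves the error may also surface as another BadRequestError:

import litellm
from litellm.exceptions import BadRequestError

try:
    # Placeholder provider/model/key; provider routing fails before any network call.
    litellm.completion(
        model="my-model",
        custom_llm_provider="not-a-real-provider",
        messages=[{"role": "user", "content": "hi"}],
        api_key="sk-placeholder",
    )
except BadRequestError as e:
    # The message names only the model and provider plus a docs link; the raw
    # call args (which include api_key) are no longer echoed back.
    print(e)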
@@ -1058,6 +1058,7 @@ jobs:
       - run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
       - run: python ./tests/code_coverage_tests/enforce_llms_folder_style.py
       - run: python ./tests/documentation_tests/test_circular_imports.py
+      - run: python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
       - run: helm lint ./deploy/charts/litellm-helm

   db_migration_disable_update_check:
@@ -14,6 +14,8 @@ from typing import Optional
 import httpx
 import openai

+from litellm.types.utils import LiteLLMCommonStrings
+

 class AuthenticationError(openai.AuthenticationError):  # type: ignore
     def __init__(
@@ -790,3 +792,16 @@ class MockException(openai.APIError):
         if request is None:
             request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         super().__init__(self.message, request=request, body=None)  # type: ignore
+
+
+class LiteLLMUnknownProvider(BadRequestError):
+    def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
+        self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        super().__init__(
+            self.message, model=model, llm_provider=custom_llm_provider, response=None
+        )
+
+    def __str__(self):
+        return self.message
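A small usage sketch, not part of the diff: because LiteLLMUnknownProvider subclasses BadRequestError, existing handlers keep catching it, and its string form carries only the model and provider names plus a docs link. The model and provider values below are placeholders.

from litellm.exceptions import BadRequestError, LiteLLMUnknownProvider

try:
    raise LiteLLMUnknownProvider(model="some-model", custom_llm_provider="some-provider")
except BadRequestError as e:
    # str(e) is the formatted llm_provider_not_provided template; it never
    # contains the caller's kwargs, so API keys cannot leak through it.
    print(e)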
@@ -50,6 +50,7 @@ from litellm import (  # type: ignore
     get_litellm_params,
     get_optional_params,
 )
+from litellm.exceptions import LiteLLMUnknownProvider
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
 from litellm.litellm_core_utils.health_check_utils import (
@@ -3036,8 +3037,8 @@ def completion(  # type: ignore # noqa: PLR0915
             custom_handler = item["custom_handler"]

         if custom_handler is None:
-            raise ValueError(
-                f"Unable to map your input to a model. Check your input - {args}"
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
             )

         ## ROUTE LLM CALL ##
@@ -3075,8 +3076,8 @@ def completion(  # type: ignore # noqa: PLR0915
             )

         else:
-            raise ValueError(
-                f"Unable to map your input to a model. Check your input - {args}"
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
             )
         return response
     except Exception as e:
@@ -3263,17 +3264,10 @@ def embedding(  # noqa: PLR0915
     """
     azure = kwargs.get("azure", None)
     client = kwargs.pop("client", None)
-    rpm = kwargs.pop("rpm", None)
-    tpm = kwargs.pop("tpm", None)
     max_retries = kwargs.get("max_retries", None)
     litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
-    cooldown_time = kwargs.get("cooldown_time", None)
     mock_response: Optional[List[float]] = kwargs.get("mock_response", None)  # type: ignore
-    max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
-    model_info = kwargs.get("model_info", None)
-    metadata = kwargs.get("metadata", None)
-    proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
     extra_headers = kwargs.get("extra_headers", None)
     headers = kwargs.get("headers", None)
@@ -3366,7 +3360,6 @@ def embedding(  # noqa: PLR0915

         if azure is True or custom_llm_provider == "azure":
             # azure configs
-            api_type = get_secret_str("AZURE_API_TYPE") or "azure"

             api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

@@ -3439,7 +3432,6 @@ def embedding(  # noqa: PLR0915
         if extra_headers is not None:
             optional_params["extra_headers"] = extra_headers

-        api_type = "openai"
         api_version = None

         ## EMBEDDING CALL
@@ -3850,14 +3842,16 @@ def embedding(  # noqa: PLR0915
                 aembedding=aembedding,
             )
         else:
-            args = locals()
-            raise ValueError(f"No valid embedding model args passed in - {args}")
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
         if response is not None and hasattr(response, "_hidden_params"):
             response._hidden_params["custom_llm_provider"] = custom_llm_provider

         if response is None:
-            args = locals()
-            raise ValueError(f"No valid embedding model args passed in - {args}")
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
         return response
     except Exception as e:
         ## LOGGING
@@ -4667,8 +4661,8 @@ def image_generation(  # noqa: PLR0915
             custom_handler = item["custom_handler"]

         if custom_handler is None:
-            raise ValueError(
-                f"Unable to map your input to a model. Check your input - {args}"
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
             )

         ## ROUTE LLM CALL ##
@@ -1,7 +1,5 @@
 model_list:
   - model_name: azure-gpt-35-turbo
     litellm_params:
-      model: azure/chatgpt-v-2
+      model: topaz/chatgpt-v-2
       api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
-      timeout: 0.000000001
@@ -3057,7 +3057,7 @@ class Router:

                     if hasattr(original_exception, "message"):
                         # add the available fallbacks to the exception
-                        original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(  # type: ignore
+                        original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format(  # type: ignore
                             model_group,
                             fallback_model_group,
                         )
@@ -3122,9 +3122,7 @@ class Router:
         )

     async def async_function_with_retries(self, *args, **kwargs):  # noqa: PLR0915
-        verbose_router_logger.debug(
-            f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
-        )
+        verbose_router_logger.debug("Inside async function with retries.")
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
@@ -59,6 +59,7 @@ class LiteLLMPydanticObjectBase(BaseModel):

 class LiteLLMCommonStrings(Enum):
     redacted_by_litellm = "redacted by litellm. 'litellm.turn_off_message_logging=True'"
+    llm_provider_not_provided = "Unmapped LLM provider for this endpoint. You passed model={model}, custom_llm_provider={custom_llm_provider}. Check supported provider and route: https://docs.litellm.ai/docs/providers"


 SupportedCacheControls = ["ttl", "s-maxage", "no-cache", "no-store"]
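For illustration only (not in the diff): the new enum value is a plain str.format template, so with placeholder values it renders as follows.

from litellm.types.utils import LiteLLMCommonStrings

# Placeholder model/provider names, purely to show the rendered message.
message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
    model="some-model", custom_llm_provider="some-provider"
)
print(message)
# Unmapped LLM provider for this endpoint. You passed model=some-model,
# custom_llm_provider=some-provider. Check supported provider and route:
# https://docs.litellm.ai/docs/providers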
@@ -618,7 +618,7 @@ def function_setup(  # noqa: PLR0915
             details_to_log.pop("prompt", None)
             add_breadcrumb(
                 category="litellm.llm_call",
-                message=f"Positional Args: {args}, Keyword Args: {details_to_log}",
+                message=f"Keyword Args: {details_to_log}",
                 level="info",
             )
         if "logger_fn" in kwargs:
@@ -726,8 +726,8 @@ def function_setup(  # noqa: PLR0915
         )
         return logging_obj, kwargs
     except Exception as e:
-        verbose_logger.error(
-            f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
+        verbose_logger.exception(
+            "litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup"
         )
         raise e

tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py (new file, 156 lines)
@@ -0,0 +1,156 @@
+import ast
+import os
+import re
+
+
+def is_venv_directory(path):
+    """
+    Check if the path contains virtual environment directories.
+    Common virtual environment directory names: venv, env, .env, myenv, .venv
+    """
+    venv_indicators = [
+        "venv",
+        "env",
+        ".env",
+        "myenv",
+        ".venv",
+        "virtualenv",
+        "site-packages",
+    ]
+
+    path_parts = path.lower().split(os.sep)
+    return any(indicator in path_parts for indicator in venv_indicators)
+
+
+class ArgsStringVisitor(ast.NodeVisitor):
+    """
+    AST visitor that finds all instances of '{args}' string usage.
+    """
+
+    def __init__(self):
+        self.args_locations = []
+        self.current_file = None
+
+    def set_file(self, filename):
+        self.current_file = filename
+        self.args_locations = []
+
+    def visit_Str(self, node):
+        """Check string literals for {args}"""
+        if "{args}" in node.s:
+            self.args_locations.append(
+                {
+                    "line": node.lineno,
+                    "col": node.col_offset,
+                    "text": node.s,
+                    "file": self.current_file,
+                }
+            )
+
+    def visit_JoinedStr(self, node):
+        """Check f-strings for {args}"""
+        for value in node.values:
+            if isinstance(value, ast.FormattedValue):
+                # Check if the formatted value uses 'args'
+                if isinstance(value.value, ast.Name) and value.value.id == "args":
+                    self.args_locations.append(
+                        {
+                            "line": node.lineno,
+                            "col": node.col_offset,
+                            "text": "f-string with {args}",
+                            "file": self.current_file,
+                        }
+                    )
+
+
+def check_file_for_args_string(file_path):
+    """
+    Analyzes a Python file for any usage of '{args}'.
+
+    Args:
+        file_path (str): Path to the Python file to analyze
+
+    Returns:
+        list: List of dictionaries containing information about {args} usage
+    """
+    with open(file_path, "r", encoding="utf-8") as file:
+        try:
+            content = file.read()
+            tree = ast.parse(content)
+
+            # First check using AST for more accurate detection in strings
+            visitor = ArgsStringVisitor()
+            visitor.set_file(file_path)
+            visitor.visit(tree)
+            ast_locations = visitor.args_locations
+
+            # Also check using regex for any instances we might have missed
+            # (like in comments or docstrings)
+            line_number = 1
+            additional_locations = []
+
+            for line in content.split("\n"):
+                if "{args}" in line:
+                    # Only add if it's not already caught by the AST visitor
+                    if not any(loc["line"] == line_number for loc in ast_locations):
+                        additional_locations.append(
+                            {
+                                "line": line_number,
+                                "col": line.index("{args}"),
+                                "text": line.strip(),
+                                "file": file_path,
+                            }
+                        )
+                line_number += 1
+
+            return ast_locations + additional_locations
+
+        except SyntaxError as e:
+            print(f"Syntax error in {file_path}: {e}")
+            return []
+
+
+def check_directory_for_args_string(directory_path):
+    """
+    Recursively checks all Python files in a directory for '{args}' usage,
+    excluding virtual environment directories.
+
+    Args:
+        directory_path (str): Path to the directory to check
+    """
+    all_violations = []
+
+    for root, dirs, files in os.walk(directory_path):
+        # Skip virtual environment directories
+        if is_venv_directory(root):
+            continue
+
+        for file in files:
+            if file.endswith(".py"):
+                file_path = os.path.join(root, file)
+                violations = check_file_for_args_string(file_path)
+                all_violations.extend(violations)
+
+    return all_violations
+
+
+def main():
+    # Update this path to point to your codebase root directory
+    # codebase_path = "../../litellm"  # Adjust as needed
+    codebase_path = "./litellm"
+
+    violations = check_directory_for_args_string(codebase_path)
+
+    if violations:
+        print("Found '{args}' usage in the following locations:")
+        for violation in violations:
+            print(f"- {violation['file']}:{violation['line']} - {violation['text']}")
+        raise Exception(
+            f"Found {len(violations)} instances of '{{args}}' usage in the codebase"
+        )
+    else:
+        print("No '{args}' usage found in the codebase.")
+
+
+if __name__ == "__main__":
+    main()
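A rough before/after sketch (not part of the diff) of the pattern this checker is meant to flag versus the pattern the rest of the commit moves to; both function names are made up:

from litellm.exceptions import LiteLLMUnknownProvider


def leaky_route(model: str, custom_llm_provider: str, **kwargs):
    # Flagged: '{args}' in an f-string echoes every kwarg, including any
    # api_key, straight into the exception message.
    args = locals()
    raise ValueError(f"Unable to map your input to a model. Check your input - {args}")


def safe_route(model: str, custom_llm_provider: str, **kwargs):
    # Preferred after this commit: only the model and provider names surface.
    raise LiteLLMUnknownProvider(model=model, custom_llm_provider=custom_llm_provider)

The CI step added above runs this checker from the repo root; it prints each offending file and line and raises an Exception (failing the job) if any such usage exists under ./litellm.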