diff --git a/.circleci/config.yml b/.circleci/config.yml
index f422ab8381..d87f677d82 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1058,6 +1058,7 @@ jobs:
       - run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
       - run: python ./tests/code_coverage_tests/enforce_llms_folder_style.py
       - run: python ./tests/documentation_tests/test_circular_imports.py
+      - run: python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
       - run: helm lint ./deploy/charts/litellm-helm

   db_migration_disable_update_check:
diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index c26928a656..f4166a5837 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -14,6 +14,8 @@ from typing import Optional
 import httpx
 import openai

+from litellm.types.utils import LiteLLMCommonStrings
+

 class AuthenticationError(openai.AuthenticationError):  # type: ignore
     def __init__(
@@ -790,3 +792,16 @@ class MockException(openai.APIError):
         if request is None:
             request = httpx.Request(method="POST", url="https://api.openai.com/v1")
         super().__init__(self.message, request=request, body=None)  # type: ignore
+
+
+class LiteLLMUnknownProvider(BadRequestError):
+    def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
+        self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        super().__init__(
+            self.message, model=model, llm_provider=custom_llm_provider, response=None
+        )
+
+    def __str__(self):
+        return self.message
diff --git a/litellm/main.py b/litellm/main.py
index 14e9f45d1e..8535551646 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -50,6 +50,7 @@ from litellm import (  # type: ignore
     get_litellm_params,
     get_optional_params,
 )
+from litellm.exceptions import LiteLLMUnknownProvider
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
 from litellm.litellm_core_utils.health_check_utils import (
@@ -3036,8 +3037,8 @@ def completion(  # type: ignore # noqa: PLR0915
                     custom_handler = item["custom_handler"]

             if custom_handler is None:
-                raise ValueError(
-                    f"Unable to map your input to a model. Check your input - {args}"
+                raise LiteLLMUnknownProvider(
+                    model=model, custom_llm_provider=custom_llm_provider
                 )

             ## ROUTE LLM CALL ##
@@ -3075,8 +3076,8 @@

             )
         else:
-            raise ValueError(
-                f"Unable to map your input to a model. Check your input - {args}"
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
             )
         return response
     except Exception as e:
@@ -3263,17 +3264,10 @@ def embedding(  # noqa: PLR0915
     """
     azure = kwargs.get("azure", None)
     client = kwargs.pop("client", None)
-    rpm = kwargs.pop("rpm", None)
-    tpm = kwargs.pop("tpm", None)
     max_retries = kwargs.get("max_retries", None)
     litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
-    cooldown_time = kwargs.get("cooldown_time", None)
     mock_response: Optional[List[float]] = kwargs.get("mock_response", None)  # type: ignore
-    max_parallel_requests = kwargs.pop("max_parallel_requests", None)
     azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
-    model_info = kwargs.get("model_info", None)
-    metadata = kwargs.get("metadata", None)
-    proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.get("aembedding", None)
     extra_headers = kwargs.get("extra_headers", None)
     headers = kwargs.get("headers", None)
@@ -3366,7 +3360,6 @@ def embedding(  # noqa: PLR0915

         if azure is True or custom_llm_provider == "azure":
             # azure configs
-            api_type = get_secret_str("AZURE_API_TYPE") or "azure"

             api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")

@@ -3439,7 +3432,6 @@
             if extra_headers is not None:
                 optional_params["extra_headers"] = extra_headers

-            api_type = "openai"
             api_version = None

             ## EMBEDDING CALL
@@ -3850,14 +3842,16 @@
                 aembedding=aembedding,
             )
         else:
-            args = locals()
-            raise ValueError(f"No valid embedding model args passed in - {args}")
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
+            )

         if response is not None and hasattr(response, "_hidden_params"):
             response._hidden_params["custom_llm_provider"] = custom_llm_provider
         if response is None:
-            args = locals()
-            raise ValueError(f"No valid embedding model args passed in - {args}")
+            raise LiteLLMUnknownProvider(
+                model=model, custom_llm_provider=custom_llm_provider
+            )
         return response
     except Exception as e:
         ## LOGGING
@@ -4667,8 +4661,8 @@
                     custom_handler = item["custom_handler"]

             if custom_handler is None:
-                raise ValueError(
-                    f"Unable to map your input to a model. Check your input - {args}"
+                raise LiteLLMUnknownProvider(
+                    model=model, custom_llm_provider=custom_llm_provider
                 )

             ## ROUTE LLM CALL ##
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 675dd5bba6..b59a0a57b5 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,7 +1,5 @@
 model_list:
   - model_name: azure-gpt-35-turbo
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
-      timeout: 0.000000001
\ No newline at end of file
+      model: topaz/chatgpt-v-2
+      api_key: os.environ/AZURE_API_KEY
\ No newline at end of file
diff --git a/litellm/router.py b/litellm/router.py
index 758a94464e..c296a9e398 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -3057,7 +3057,7 @@ class Router:

                     if hasattr(original_exception, "message"):
                         # add the available fallbacks to the exception
-                        original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format(  # type: ignore
+                        original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format(  # type: ignore
                            model_group,
                             fallback_model_group,
                         )
@@ -3122,9 +3122,7 @@ class Router:
         )

     async def async_function_with_retries(self, *args, **kwargs):  # noqa: PLR0915
-        verbose_router_logger.debug(
-            f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
-        )
+        verbose_router_logger.debug("Inside async function with retries.")
         original_function = kwargs.pop("original_function")
         fallbacks = kwargs.pop("fallbacks", self.fallbacks)
         parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index d884b98448..822f55e6fa 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -59,6 +59,7 @@ class LiteLLMPydanticObjectBase(BaseModel):

 class LiteLLMCommonStrings(Enum):
     redacted_by_litellm = "redacted by litellm. 'litellm.turn_off_message_logging=True'"
+    llm_provider_not_provided = "Unmapped LLM provider for this endpoint. You passed model={model}, custom_llm_provider={custom_llm_provider}. Check supported provider and route: https://docs.litellm.ai/docs/providers"


 SupportedCacheControls = ["ttl", "s-maxage", "no-cache", "no-store"]
diff --git a/litellm/utils.py b/litellm/utils.py
index e795654420..e6ce54b1b2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -618,7 +618,7 @@ def function_setup(  # noqa: PLR0915
                 details_to_log.pop("prompt", None)
             add_breadcrumb(
                 category="litellm.llm_call",
-                message=f"Positional Args: {args}, Keyword Args: {details_to_log}",
+                message=f"Keyword Args: {details_to_log}",
                 level="info",
             )
         if "logger_fn" in kwargs:
@@ -726,8 +726,8 @@ def function_setup(  # noqa: PLR0915
         )
         return logging_obj, kwargs
     except Exception as e:
-        verbose_logger.error(
-            f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
+        verbose_logger.exception(
+            "litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup"
         )
         raise e

diff --git a/tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py b/tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
new file mode 100644
index 0000000000..e9cba9bb18
--- /dev/null
+++ b/tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
@@ -0,0 +1,156 @@
+import ast
+import os
+import re
+
+
+def is_venv_directory(path):
+    """
+    Check if the path contains virtual environment directories.
+    Common virtual environment directory names: venv, env, .env, myenv, .venv
+    """
+    venv_indicators = [
+        "venv",
+        "env",
+        ".env",
+        "myenv",
+        ".venv",
+        "virtualenv",
+        "site-packages",
+    ]
+
+    path_parts = path.lower().split(os.sep)
+    return any(indicator in path_parts for indicator in venv_indicators)
+
+
+class ArgsStringVisitor(ast.NodeVisitor):
+    """
+    AST visitor that finds all instances of '{args}' string usage.
+    """
+
+    def __init__(self):
+        self.args_locations = []
+        self.current_file = None
+
+    def set_file(self, filename):
+        self.current_file = filename
+        self.args_locations = []
+
+    def visit_Str(self, node):
+        """Check string literals for {args}"""
+        if "{args}" in node.s:
+            self.args_locations.append(
+                {
+                    "line": node.lineno,
+                    "col": node.col_offset,
+                    "text": node.s,
+                    "file": self.current_file,
+                }
+            )
+
+    def visit_JoinedStr(self, node):
+        """Check f-strings for {args}"""
+        for value in node.values:
+            if isinstance(value, ast.FormattedValue):
+                # Check if the formatted value uses 'args'
+                if isinstance(value.value, ast.Name) and value.value.id == "args":
+                    self.args_locations.append(
+                        {
+                            "line": node.lineno,
+                            "col": node.col_offset,
+                            "text": "f-string with {args}",
+                            "file": self.current_file,
+                        }
+                    )
+
+
+def check_file_for_args_string(file_path):
+    """
+    Analyzes a Python file for any usage of '{args}'.
+
+    Args:
+        file_path (str): Path to the Python file to analyze
+
+    Returns:
+        list: List of dictionaries containing information about {args} usage
+    """
+    with open(file_path, "r", encoding="utf-8") as file:
+        try:
+            content = file.read()
+            tree = ast.parse(content)
+
+            # First check using AST for more accurate detection in strings
+            visitor = ArgsStringVisitor()
+            visitor.set_file(file_path)
+            visitor.visit(tree)
+            ast_locations = visitor.args_locations
+
+            # Also check using regex for any instances we might have missed
+            # (like in comments or docstrings)
+            line_number = 1
+            additional_locations = []
+
+            for line in content.split("\n"):
+                if "{args}" in line:
+                    # Only add if it's not already caught by the AST visitor
+                    if not any(loc["line"] == line_number for loc in ast_locations):
+                        additional_locations.append(
+                            {
+                                "line": line_number,
+                                "col": line.index("{args}"),
+                                "text": line.strip(),
+                                "file": file_path,
+                            }
+                        )
+                line_number += 1
+
+            return ast_locations + additional_locations
+
+        except SyntaxError as e:
+            print(f"Syntax error in {file_path}: {e}")
+            return []
+
+
+def check_directory_for_args_string(directory_path):
+    """
+    Recursively checks all Python files in a directory for '{args}' usage,
+    excluding virtual environment directories.
+
+    Args:
+        directory_path (str): Path to the directory to check
+    """
+    all_violations = []
+
+    for root, dirs, files in os.walk(directory_path):
+        # Skip virtual environment directories
+        if is_venv_directory(root):
+            continue
+
+        for file in files:
+            if file.endswith(".py"):
+                file_path = os.path.join(root, file)
+                violations = check_file_for_args_string(file_path)
+                all_violations.extend(violations)
+
+    return all_violations
+
+
+def main():
+    # Update this path to point to your codebase root directory
+    # codebase_path = "../../litellm"  # Adjust as needed
+    codebase_path = "./litellm"
+
+    violations = check_directory_for_args_string(codebase_path)
+
+    if violations:
+        print("Found '{args}' usage in the following locations:")
+        for violation in violations:
+            print(f"- {violation['file']}:{violation['line']} - {violation['text']}")
+        raise Exception(
+            f"Found {len(violations)} instances of '{{args}}' usage in the codebase"
+        )
+    else:
+        print("No '{args}' usage found in the codebase.")
+
+
+if __name__ == "__main__":
+    main()
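
Usage sketch, not part of the patch: a minimal check of the invariant this diff targets. It assumes the patched litellm is importable; the model/provider pair is reused from the _new_secret_config.yaml change above.

    # LiteLLMUnknownProvider formats the fixed template in
    # LiteLLMCommonStrings.llm_provider_not_provided with only the model and
    # provider names; unlike the old ValueError, no local args/kwargs (and so
    # no API keys) are interpolated into the exception text.
    from litellm.exceptions import LiteLLMUnknownProvider

    err = LiteLLMUnknownProvider(model="chatgpt-v-2", custom_llm_provider="topaz")
    print(str(err))
    # Unmapped LLM provider for this endpoint. You passed model=chatgpt-v-2,
    # custom_llm_provider=topaz. Check supported provider and route:
    # https://docs.litellm.ai/docs/providers

The new CI step enforces the same invariant statically: running python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py from the repo root raises (and fails the job) if any string or f-string under ./litellm still embeds '{args}'.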