fix(main.py): fix key leak error when unknown provider given (#8556)

* fix(main.py): fix key leak error when unknown provider given

don't return the passed-in args when an unknown route is hit on embedding

* fix(main.py): remove instances of {args} being passed in exception

prevent potential key leaks

* test(code_coverage_tests/prevent_key_leaks_in_exceptions.py): ban usage of {args} in codebase

* fix: fix linting errors

* fix: remove unused variable
This commit is contained in:
Krish Dholakia 2025-02-15 14:02:55 -08:00 committed by GitHub
parent c6026ea6f9
commit a9276f27f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 193 additions and 30 deletions

View file

@ -1058,6 +1058,7 @@ jobs:
- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
- run: python ./tests/code_coverage_tests/enforce_llms_folder_style.py
- run: python ./tests/documentation_tests/test_circular_imports.py
- run: python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
- run: helm lint ./deploy/charts/litellm-helm
db_migration_disable_update_check:

View file

@ -14,6 +14,8 @@ from typing import Optional
import httpx
import openai
from litellm.types.utils import LiteLLMCommonStrings
class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(
@ -790,3 +792,16 @@ class MockException(openai.APIError):
if request is None:
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(self.message, request=request, body=None) # type: ignore
class LiteLLMUnknownProvider(BadRequestError):
    """Raised when a model cannot be mapped to any known LLM provider.

    Replaces the old ``ValueError(f"... {args}")`` pattern so raw call
    arguments (which may contain API keys) never leak into exception text.
    """

    def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
        # Build the user-facing message from the shared template string.
        template = LiteLLMCommonStrings.llm_provider_not_provided.value
        self.message = template.format(
            model=model, custom_llm_provider=custom_llm_provider
        )
        super().__init__(
            self.message, model=model, llm_provider=custom_llm_provider, response=None
        )

    def __str__(self):
        return self.message

View file

@ -50,6 +50,7 @@ from litellm import ( # type: ignore
get_litellm_params,
get_optional_params,
)
from litellm.exceptions import LiteLLMUnknownProvider
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
from litellm.litellm_core_utils.health_check_utils import (
@ -3036,8 +3037,8 @@ def completion( # type: ignore # noqa: PLR0915
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
## ROUTE LLM CALL ##
@ -3075,8 +3076,8 @@ def completion( # type: ignore # noqa: PLR0915
)
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
return response
except Exception as e:
@ -3263,17 +3264,10 @@ def embedding( # noqa: PLR0915
"""
azure = kwargs.get("azure", None)
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
max_retries = kwargs.get("max_retries", None)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None)
extra_headers = kwargs.get("extra_headers", None)
headers = kwargs.get("headers", None)
@ -3366,7 +3360,6 @@ def embedding( # noqa: PLR0915
if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
@ -3439,7 +3432,6 @@ def embedding( # noqa: PLR0915
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
api_type = "openai"
api_version = None
## EMBEDDING CALL
@ -3850,14 +3842,16 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
if response is not None and hasattr(response, "_hidden_params"):
response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
return response
except Exception as e:
## LOGGING
@ -4667,8 +4661,8 @@ def image_generation( # noqa: PLR0915
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
## ROUTE LLM CALL ##

View file

@ -1,7 +1,5 @@
model_list:
- model_name: azure-gpt-35-turbo
litellm_params:
model: azure/chatgpt-v-2
model: topaz/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
timeout: 0.000000001

View file

@ -3057,7 +3057,7 @@ class Router:
if hasattr(original_exception, "message"):
# add the available fallbacks to the exception
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
model_group,
fallback_model_group,
)
@ -3122,9 +3122,7 @@ class Router:
)
async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
verbose_router_logger.debug(
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
)
verbose_router_logger.debug("Inside async function with retries.")
original_function = kwargs.pop("original_function")
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)

View file

@ -59,6 +59,7 @@ class LiteLLMPydanticObjectBase(BaseModel):
class LiteLLMCommonStrings(Enum):
redacted_by_litellm = "redacted by litellm. 'litellm.turn_off_message_logging=True'"
llm_provider_not_provided = "Unmapped LLM provider for this endpoint. You passed model={model}, custom_llm_provider={custom_llm_provider}. Check supported provider and route: https://docs.litellm.ai/docs/providers"
SupportedCacheControls = ["ttl", "s-maxage", "no-cache", "no-store"]

View file

@ -618,7 +618,7 @@ def function_setup( # noqa: PLR0915
details_to_log.pop("prompt", None)
add_breadcrumb(
category="litellm.llm_call",
message=f"Positional Args: {args}, Keyword Args: {details_to_log}",
message=f"Keyword Args: {details_to_log}",
level="info",
)
if "logger_fn" in kwargs:
@ -726,8 +726,8 @@ def function_setup( # noqa: PLR0915
)
return logging_obj, kwargs
except Exception as e:
verbose_logger.error(
f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
verbose_logger.exception(
"litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup"
)
raise e

View file

@ -0,0 +1,156 @@
import ast
import os
import re
def is_venv_directory(path):
    """Return True when *path* contains a virtual-environment directory.

    Matches common venv folder names (venv, env, .env, myenv, .venv,
    virtualenv, site-packages) as whole path components, case-insensitively.
    """
    markers = {
        "venv",
        "env",
        ".env",
        "myenv",
        ".venv",
        "virtualenv",
        "site-packages",
    }
    # Compare whole path components so e.g. "environment" does not match "env".
    components = path.lower().split(os.sep)
    return not markers.isdisjoint(components)
class ArgsStringVisitor(ast.NodeVisitor):
    """
    AST visitor that finds all instances of '{args}' string usage.

    Findings accumulate in ``self.args_locations`` as dicts with
    ``line``/``col``/``text``/``file`` keys; call :meth:`set_file`
    before visiting each new file to reset state.
    """

    def __init__(self):
        self.args_locations = []
        self.current_file = None

    def set_file(self, filename):
        """Point the visitor at a new file and clear previous findings."""
        self.current_file = filename
        self.args_locations = []

    def visit_Constant(self, node):
        """Check string literals for {args}.

        Uses ``ast.Constant`` rather than the deprecated ``visit_Str`` hook
        (``ast.Str`` was removed in Python 3.12, and the NodeVisitor
        compatibility shim that forwarded Constant nodes to ``visit_Str``
        is deprecated).
        """
        if isinstance(node.value, str) and "{args}" in node.value:
            self.args_locations.append(
                {
                    "line": node.lineno,
                    "col": node.col_offset,
                    "text": node.value,
                    "file": self.current_file,
                }
            )

    def visit_JoinedStr(self, node):
        """Check f-strings for {args}"""
        for value in node.values:
            if isinstance(value, ast.FormattedValue):
                # Flag only f-strings that interpolate the bare name `args`.
                if isinstance(value.value, ast.Name) and value.value.id == "args":
                    self.args_locations.append(
                        {
                            "line": node.lineno,
                            "col": node.col_offset,
                            "text": "f-string with {args}",
                            "file": self.current_file,
                        }
                    )
def check_file_for_args_string(file_path):
    """
    Analyzes a Python file for any usage of '{args}'.

    Args:
        file_path (str): Path to the Python file to analyze

    Returns:
        list: List of dictionaries containing information about {args} usage.
              Returns an empty list when the file fails to parse.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        content = file.read()

    # Keep the try narrow: only ast.parse can raise SyntaxError here.
    try:
        tree = ast.parse(content)
    except SyntaxError as e:
        print(f"Syntax error in {file_path}: {e}")
        return []

    # First check using AST for accurate detection inside string literals.
    visitor = ArgsStringVisitor()
    visitor.set_file(file_path)
    visitor.visit(tree)
    ast_locations = visitor.args_locations

    # Line-by-line scan catches instances the AST misses (comments,
    # docstrings); lines already reported by the visitor are skipped.
    reported_lines = {loc["line"] for loc in ast_locations}
    additional_locations = [
        {
            "line": line_number,
            "col": line.index("{args}"),
            "text": line.strip(),
            "file": file_path,
        }
        for line_number, line in enumerate(content.split("\n"), start=1)
        if "{args}" in line and line_number not in reported_lines
    ]

    return ast_locations + additional_locations
def check_directory_for_args_string(directory_path):
    """
    Recursively checks all Python files in a directory for '{args}' usage,
    excluding virtual environment directories.

    Args:
        directory_path (str): Path to the directory to check
    """
    all_violations = []
    for root, _dirs, files in os.walk(directory_path):
        # Don't scan third-party code installed inside virtualenvs.
        if is_venv_directory(root):
            continue
        for name in files:
            if not name.endswith(".py"):
                continue
            full_path = os.path.join(root, name)
            all_violations.extend(check_file_for_args_string(full_path))
    return all_violations
def main():
    """Scan the litellm package and fail when any '{args}' usage is found."""
    # Update this path to point to your codebase root directory
    # codebase_path = "../../litellm"  # Adjust as needed
    codebase_path = "./litellm"

    violations = check_directory_for_args_string(codebase_path)
    if not violations:
        print("No '{args}' usage found in the codebase.")
        return

    print("Found '{args}' usage in the following locations:")
    for violation in violations:
        print(f"- {violation['file']}:{violation['line']} - {violation['text']}")
    raise Exception(
        f"Found {len(violations)} instances of '{{args}}' usage in the codebase"
    )


if __name__ == "__main__":
    main()