mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(main.py): fix key leak error when unknown provider given (#8556)
* fix(main.py): fix key leak error when unknown provider given don't return passed in args if unknown route on embedding * fix(main.py): remove instances of {args} being passed in exception prevent potential key leaks * test(code_coverage/prevent_key_leaks_in_codebase.py): ban usage of {args} in codebase * fix: fix linting errors * fix: remove unused variable
This commit is contained in:
parent
c6026ea6f9
commit
a9276f27f9
8 changed files with 193 additions and 30 deletions
|
@ -1058,6 +1058,7 @@ jobs:
|
|||
- run: python ./tests/code_coverage_tests/ensure_async_clients_test.py
|
||||
- run: python ./tests/code_coverage_tests/enforce_llms_folder_style.py
|
||||
- run: python ./tests/documentation_tests/test_circular_imports.py
|
||||
- run: python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
|
||||
- run: helm lint ./deploy/charts/litellm-helm
|
||||
|
||||
db_migration_disable_update_check:
|
||||
|
|
|
@ -14,6 +14,8 @@ from typing import Optional
|
|||
import httpx
|
||||
import openai
|
||||
|
||||
from litellm.types.utils import LiteLLMCommonStrings
|
||||
|
||||
|
||||
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
||||
def __init__(
|
||||
|
@ -790,3 +792,16 @@ class MockException(openai.APIError):
|
|||
if request is None:
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
|
||||
class LiteLLMUnknownProvider(BadRequestError):
|
||||
def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
|
||||
self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
super().__init__(
|
||||
self.message, model=model, llm_provider=custom_llm_provider, response=None
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
|
|
@ -50,6 +50,7 @@ from litellm import ( # type: ignore
|
|||
get_litellm_params,
|
||||
get_optional_params,
|
||||
)
|
||||
from litellm.exceptions import LiteLLMUnknownProvider
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_for_health_check
|
||||
from litellm.litellm_core_utils.health_check_utils import (
|
||||
|
@ -3036,8 +3037,8 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
custom_handler = item["custom_handler"]
|
||||
|
||||
if custom_handler is None:
|
||||
raise ValueError(
|
||||
f"Unable to map your input to a model. Check your input - {args}"
|
||||
raise LiteLLMUnknownProvider(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
## ROUTE LLM CALL ##
|
||||
|
@ -3075,8 +3076,8 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to map your input to a model. Check your input - {args}"
|
||||
raise LiteLLMUnknownProvider(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -3263,17 +3264,10 @@ def embedding( # noqa: PLR0915
|
|||
"""
|
||||
azure = kwargs.get("azure", None)
|
||||
client = kwargs.pop("client", None)
|
||||
rpm = kwargs.pop("rpm", None)
|
||||
tpm = kwargs.pop("tpm", None)
|
||||
max_retries = kwargs.get("max_retries", None)
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
cooldown_time = kwargs.get("cooldown_time", None)
|
||||
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
|
||||
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
|
||||
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.get("aembedding", None)
|
||||
extra_headers = kwargs.get("extra_headers", None)
|
||||
headers = kwargs.get("headers", None)
|
||||
|
@ -3366,7 +3360,6 @@ def embedding( # noqa: PLR0915
|
|||
|
||||
if azure is True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
|
||||
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
|
||||
|
@ -3439,7 +3432,6 @@ def embedding( # noqa: PLR0915
|
|||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
api_type = "openai"
|
||||
api_version = None
|
||||
|
||||
## EMBEDDING CALL
|
||||
|
@ -3850,14 +3842,16 @@ def embedding( # noqa: PLR0915
|
|||
aembedding=aembedding,
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
raise LiteLLMUnknownProvider(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
if response is not None and hasattr(response, "_hidden_params"):
|
||||
response._hidden_params["custom_llm_provider"] = custom_llm_provider
|
||||
|
||||
if response is None:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
raise LiteLLMUnknownProvider(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
|
@ -4667,8 +4661,8 @@ def image_generation( # noqa: PLR0915
|
|||
custom_handler = item["custom_handler"]
|
||||
|
||||
if custom_handler is None:
|
||||
raise ValueError(
|
||||
f"Unable to map your input to a model. Check your input - {args}"
|
||||
raise LiteLLMUnknownProvider(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
## ROUTE LLM CALL ##
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
model_list:
|
||||
- model_name: azure-gpt-35-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
timeout: 0.000000001
|
||||
model: topaz/chatgpt-v-2
|
||||
api_key: os.environ/AZURE_API_KEY
|
|
@ -3057,7 +3057,7 @@ class Router:
|
|||
|
||||
if hasattr(original_exception, "message"):
|
||||
# add the available fallbacks to the exception
|
||||
original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
|
||||
original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore
|
||||
model_group,
|
||||
fallback_model_group,
|
||||
)
|
||||
|
@ -3122,9 +3122,7 @@ class Router:
|
|||
)
|
||||
|
||||
async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
|
||||
verbose_router_logger.debug(
|
||||
f"Inside async function with retries: args - {args}; kwargs - {kwargs}"
|
||||
)
|
||||
verbose_router_logger.debug("Inside async function with retries.")
|
||||
original_function = kwargs.pop("original_function")
|
||||
fallbacks = kwargs.pop("fallbacks", self.fallbacks)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
|
|
|
@ -59,6 +59,7 @@ class LiteLLMPydanticObjectBase(BaseModel):
|
|||
|
||||
class LiteLLMCommonStrings(Enum):
|
||||
redacted_by_litellm = "redacted by litellm. 'litellm.turn_off_message_logging=True'"
|
||||
llm_provider_not_provided = "Unmapped LLM provider for this endpoint. You passed model={model}, custom_llm_provider={custom_llm_provider}. Check supported provider and route: https://docs.litellm.ai/docs/providers"
|
||||
|
||||
|
||||
SupportedCacheControls = ["ttl", "s-maxage", "no-cache", "no-store"]
|
||||
|
|
|
@ -618,7 +618,7 @@ def function_setup( # noqa: PLR0915
|
|||
details_to_log.pop("prompt", None)
|
||||
add_breadcrumb(
|
||||
category="litellm.llm_call",
|
||||
message=f"Positional Args: {args}, Keyword Args: {details_to_log}",
|
||||
message=f"Keyword Args: {details_to_log}",
|
||||
level="info",
|
||||
)
|
||||
if "logger_fn" in kwargs:
|
||||
|
@ -726,8 +726,8 @@ def function_setup( # noqa: PLR0915
|
|||
)
|
||||
return logging_obj, kwargs
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"litellm.utils.py::function_setup() - [Non-Blocking] {traceback.format_exc()}; args - {args}; kwargs - {kwargs}"
|
||||
verbose_logger.exception(
|
||||
"litellm.utils.py::function_setup() - [Non-Blocking] Error in function_setup"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
|
156
tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
Normal file
156
tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
import ast
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
def is_venv_directory(path):
|
||||
"""
|
||||
Check if the path contains virtual environment directories.
|
||||
Common virtual environment directory names: venv, env, .env, myenv, .venv
|
||||
"""
|
||||
venv_indicators = [
|
||||
"venv",
|
||||
"env",
|
||||
".env",
|
||||
"myenv",
|
||||
".venv",
|
||||
"virtualenv",
|
||||
"site-packages",
|
||||
]
|
||||
|
||||
path_parts = path.lower().split(os.sep)
|
||||
return any(indicator in path_parts for indicator in venv_indicators)
|
||||
|
||||
|
||||
class ArgsStringVisitor(ast.NodeVisitor):
|
||||
"""
|
||||
AST visitor that finds all instances of '{args}' string usage.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.args_locations = []
|
||||
self.current_file = None
|
||||
|
||||
def set_file(self, filename):
|
||||
self.current_file = filename
|
||||
self.args_locations = []
|
||||
|
||||
def visit_Str(self, node):
|
||||
"""Check string literals for {args}"""
|
||||
if "{args}" in node.s:
|
||||
self.args_locations.append(
|
||||
{
|
||||
"line": node.lineno,
|
||||
"col": node.col_offset,
|
||||
"text": node.s,
|
||||
"file": self.current_file,
|
||||
}
|
||||
)
|
||||
|
||||
def visit_JoinedStr(self, node):
|
||||
"""Check f-strings for {args}"""
|
||||
for value in node.values:
|
||||
if isinstance(value, ast.FormattedValue):
|
||||
# Check if the formatted value uses 'args'
|
||||
if isinstance(value.value, ast.Name) and value.value.id == "args":
|
||||
self.args_locations.append(
|
||||
{
|
||||
"line": node.lineno,
|
||||
"col": node.col_offset,
|
||||
"text": "f-string with {args}",
|
||||
"file": self.current_file,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def check_file_for_args_string(file_path):
|
||||
"""
|
||||
Analyzes a Python file for any usage of '{args}'.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the Python file to analyze
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing information about {args} usage
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
try:
|
||||
content = file.read()
|
||||
tree = ast.parse(content)
|
||||
|
||||
# First check using AST for more accurate detection in strings
|
||||
visitor = ArgsStringVisitor()
|
||||
visitor.set_file(file_path)
|
||||
visitor.visit(tree)
|
||||
ast_locations = visitor.args_locations
|
||||
|
||||
# Also check using regex for any instances we might have missed
|
||||
# (like in comments or docstrings)
|
||||
line_number = 1
|
||||
additional_locations = []
|
||||
|
||||
for line in content.split("\n"):
|
||||
if "{args}" in line:
|
||||
# Only add if it's not already caught by the AST visitor
|
||||
if not any(loc["line"] == line_number for loc in ast_locations):
|
||||
additional_locations.append(
|
||||
{
|
||||
"line": line_number,
|
||||
"col": line.index("{args}"),
|
||||
"text": line.strip(),
|
||||
"file": file_path,
|
||||
}
|
||||
)
|
||||
line_number += 1
|
||||
|
||||
return ast_locations + additional_locations
|
||||
|
||||
except SyntaxError as e:
|
||||
print(f"Syntax error in {file_path}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def check_directory_for_args_string(directory_path):
|
||||
"""
|
||||
Recursively checks all Python files in a directory for '{args}' usage,
|
||||
excluding virtual environment directories.
|
||||
|
||||
Args:
|
||||
directory_path (str): Path to the directory to check
|
||||
"""
|
||||
all_violations = []
|
||||
|
||||
for root, dirs, files in os.walk(directory_path):
|
||||
# Skip virtual environment directories
|
||||
if is_venv_directory(root):
|
||||
continue
|
||||
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
violations = check_file_for_args_string(file_path)
|
||||
all_violations.extend(violations)
|
||||
|
||||
return all_violations
|
||||
|
||||
|
||||
def main():
|
||||
# Update this path to point to your codebase root directory
|
||||
# codebase_path = "../../litellm" # Adjust as needed
|
||||
codebase_path = "./litellm"
|
||||
|
||||
violations = check_directory_for_args_string(codebase_path)
|
||||
|
||||
if violations:
|
||||
print("Found '{args}' usage in the following locations:")
|
||||
for violation in violations:
|
||||
print(f"- {violation['file']}:{violation['line']} - {violation['text']}")
|
||||
raise Exception(
|
||||
f"Found {len(violations)} instances of '{{args}}' usage in the codebase"
|
||||
)
|
||||
else:
|
||||
print("No '{args}' usage found in the codebase.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Add a link
Reference in a new issue