mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
* refactor(utils.py): migrate amazon titan config to base config * refactor(utils.py): refactor bedrock meta invoke model translation to use base config * refactor(utils.py): move bedrock ai21 to base config * refactor(utils.py): move bedrock cohere to base config * refactor(utils.py): move bedrock mistral to use base config * refactor(utils.py): move all provider optional param translations to using a config * docs(clientside_auth.md): clarify how to pass vertex region to litellm proxy * fix(utils.py): handle scenario where custom llm provider is none / empty * fix: fix get config * test(test_otel_load_tests.py): widen perf margin * fix(utils.py): fix get provider config check to handle custom llm's * fix(utils.py): fix check
151 lines
5.8 KiB
Python
151 lines
5.8 KiB
Python
import ast
|
|
from typing import List, Set, Dict, Optional
|
|
import sys
|
|
|
|
|
|
class ConfigChecker(ast.NodeVisitor):
|
|
def __init__(self):
|
|
self.errors: List[str] = []
|
|
self.current_provider_block: Optional[str] = None
|
|
self.param_assignments: Dict[str, Set[str]] = {}
|
|
self.map_openai_calls: Set[str] = set()
|
|
self.class_inheritance: Dict[str, List[str]] = {}
|
|
|
|
def get_full_name(self, node):
|
|
"""Recursively extract the full name from a node."""
|
|
if isinstance(node, ast.Name):
|
|
return node.id
|
|
elif isinstance(node, ast.Attribute):
|
|
base = self.get_full_name(node.value)
|
|
if base:
|
|
return f"{base}.{node.attr}"
|
|
return None
|
|
|
|
def visit_ClassDef(self, node: ast.ClassDef):
|
|
# Record class inheritance
|
|
bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
|
|
print(f"Found class {node.name} with bases {bases}")
|
|
self.class_inheritance[node.name] = bases
|
|
self.generic_visit(node)
|
|
|
|
def visit_Call(self, node: ast.Call):
|
|
# Check for map_openai_params calls
|
|
if (
|
|
isinstance(node.func, ast.Attribute)
|
|
and node.func.attr == "map_openai_params"
|
|
):
|
|
if isinstance(node.func.value, ast.Name):
|
|
config_name = node.func.value.id
|
|
self.map_openai_calls.add(config_name)
|
|
self.generic_visit(node)
|
|
|
|
def visit_If(self, node: ast.If):
|
|
# Detect custom_llm_provider blocks
|
|
provider = self._extract_provider_from_if(node)
|
|
if provider:
|
|
old_provider = self.current_provider_block
|
|
self.current_provider_block = provider
|
|
self.generic_visit(node)
|
|
self.current_provider_block = old_provider
|
|
else:
|
|
self.generic_visit(node)
|
|
|
|
def visit_Assign(self, node: ast.Assign):
|
|
# Track assignments to optional_params
|
|
if self.current_provider_block and len(node.targets) == 1:
|
|
target = node.targets[0]
|
|
if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
|
|
if target.value.id == "optional_params":
|
|
if isinstance(target.slice, ast.Constant):
|
|
key = target.slice.value
|
|
if self.current_provider_block not in self.param_assignments:
|
|
self.param_assignments[self.current_provider_block] = set()
|
|
self.param_assignments[self.current_provider_block].add(key)
|
|
self.generic_visit(node)
|
|
|
|
def _extract_provider_from_if(self, node: ast.If) -> Optional[str]:
|
|
"""Extract the provider name from an if condition checking custom_llm_provider"""
|
|
if isinstance(node.test, ast.Compare):
|
|
if len(node.test.ops) == 1 and isinstance(node.test.ops[0], ast.Eq):
|
|
if (
|
|
isinstance(node.test.left, ast.Name)
|
|
and node.test.left.id == "custom_llm_provider"
|
|
):
|
|
if isinstance(node.test.comparators[0], ast.Constant):
|
|
return node.test.comparators[0].value
|
|
return None
|
|
|
|
def check_patterns(self) -> List[str]:
|
|
# Check if all configs using map_openai_params inherit from BaseConfig
|
|
for config_name in self.map_openai_calls:
|
|
print(f"Checking config: {config_name}")
|
|
if (
|
|
config_name not in self.class_inheritance
|
|
or "BaseConfig" not in self.class_inheritance[config_name]
|
|
):
|
|
# Retrieve the associated class name, if any
|
|
class_name = next(
|
|
(
|
|
cls
|
|
for cls, bases in self.class_inheritance.items()
|
|
if config_name in bases
|
|
),
|
|
"Unknown Class",
|
|
)
|
|
self.errors.append(
|
|
f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. "
|
|
f"It is used in the class: {class_name}"
|
|
)
|
|
|
|
# Check for parameter assignments in provider blocks
|
|
for provider, params in self.param_assignments.items():
|
|
# You can customize which parameters should raise warnings for each provider
|
|
for param in params:
|
|
if param not in self._get_allowed_params(provider):
|
|
self.errors.append(
|
|
f"Warning: Parameter '{param}' is directly assigned in {provider} block. "
|
|
f"Consider using a config class instead."
|
|
)
|
|
|
|
return self.errors
|
|
|
|
def _get_allowed_params(self, provider: str) -> Set[str]:
|
|
"""Define allowed direct parameter assignments for each provider"""
|
|
# You can customize this based on your requirements
|
|
common_allowed = {"stream", "api_key", "api_base"}
|
|
provider_specific = {
|
|
"anthropic": {"api_version"},
|
|
"openai": {"organization"},
|
|
# Add more providers and their allowed params here
|
|
}
|
|
return common_allowed.union(provider_specific.get(provider, set()))
|
|
|
|
|
|
def check_file(file_path: str) -> List[str]:
    """Run ``ConfigChecker`` over ``get_optional_params`` in ``file_path``.

    Only the first top-level function named ``get_optional_params`` is
    visited. If no such function exists, a single error is returned instead of
    an empty list. That way a rename or move of the function cannot make the
    check pass vacuously.

    Args:
        file_path: Path to the Python source to analyse (read as UTF-8).

    Returns:
        The error/warning strings produced by ``ConfigChecker.check_patterns``,
        or a one-element list reporting that the function was not found.

    Raises:
        OSError: if the file cannot be read.
        SyntaxError: if the file is not valid Python.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        # filename= makes SyntaxError messages point at the checked file.
        tree = ast.parse(file.read(), filename=file_path)

    target = next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params"
        ),
        None,
    )
    if target is None:
        return [f"Error: function 'get_optional_params' not found in {file_path}"]

    checker = ConfigChecker()
    checker.visit(target)
    return checker.check_patterns()
|
|
|
|
|
|
def main(file_path: str = "../../litellm/utils.py"):
    """Check ``file_path`` and exit with status 1 if any issue is found, else 0.

    Args:
        file_path: Source file to check. The default is resolved against the
            current working directory, so the script is expected to be run
            from two directories below the repository root.
    """
    errors = check_file(file_path)

    if not errors:
        print("No issues found!")
        sys.exit(0)

    print("\nFound the following issues:")
    for error in errors:
        print(f"- {error}")
    sys.exit(1)
|
|
|
|
|
|
# Run the check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|