import ast from typing import List, Set, Dict, Optional import sys class ConfigChecker(ast.NodeVisitor): def __init__(self): self.errors: List[str] = [] self.current_provider_block: Optional[str] = None self.param_assignments: Dict[str, Set[str]] = {} self.map_openai_calls: Set[str] = set() self.class_inheritance: Dict[str, List[str]] = {} def get_full_name(self, node): """Recursively extract the full name from a node.""" if isinstance(node, ast.Name): return node.id elif isinstance(node, ast.Attribute): base = self.get_full_name(node.value) if base: return f"{base}.{node.attr}" return None def visit_ClassDef(self, node: ast.ClassDef): # Record class inheritance bases = [base.id for base in node.bases if isinstance(base, ast.Name)] print(f"Found class {node.name} with bases {bases}") self.class_inheritance[node.name] = bases self.generic_visit(node) def visit_Call(self, node: ast.Call): # Check for map_openai_params calls if ( isinstance(node.func, ast.Attribute) and node.func.attr == "map_openai_params" ): if isinstance(node.func.value, ast.Name): config_name = node.func.value.id self.map_openai_calls.add(config_name) self.generic_visit(node) def visit_If(self, node: ast.If): # Detect custom_llm_provider blocks provider = self._extract_provider_from_if(node) if provider: old_provider = self.current_provider_block self.current_provider_block = provider self.generic_visit(node) self.current_provider_block = old_provider else: self.generic_visit(node) def visit_Assign(self, node: ast.Assign): # Track assignments to optional_params if self.current_provider_block and len(node.targets) == 1: target = node.targets[0] if isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name): if target.value.id == "optional_params": if isinstance(target.slice, ast.Constant): key = target.slice.value if self.current_provider_block not in self.param_assignments: self.param_assignments[self.current_provider_block] = set() self.param_assignments[self.current_provider_block].add(key) self.generic_visit(node) def _extract_provider_from_if(self, node: ast.If) -> Optional[str]: """Extract the provider name from an if condition checking custom_llm_provider""" if isinstance(node.test, ast.Compare): if len(node.test.ops) == 1 and isinstance(node.test.ops[0], ast.Eq): if ( isinstance(node.test.left, ast.Name) and node.test.left.id == "custom_llm_provider" ): if isinstance(node.test.comparators[0], ast.Constant): return node.test.comparators[0].value return None def check_patterns(self) -> List[str]: # Check if all configs using map_openai_params inherit from BaseConfig for config_name in self.map_openai_calls: print(f"Checking config: {config_name}") if ( config_name not in self.class_inheritance or "BaseConfig" not in self.class_inheritance[config_name] ): # Retrieve the associated class name, if any class_name = next( ( cls for cls, bases in self.class_inheritance.items() if config_name in bases ), "Unknown Class", ) self.errors.append( f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. " f"It is used in the class: {class_name}" ) # Check for parameter assignments in provider blocks for provider, params in self.param_assignments.items(): # You can customize which parameters should raise warnings for each provider for param in params: if param not in self._get_allowed_params(provider): self.errors.append( f"Warning: Parameter '{param}' is directly assigned in {provider} block. " f"Consider using a config class instead." ) return self.errors def _get_allowed_params(self, provider: str) -> Set[str]: """Define allowed direct parameter assignments for each provider""" # You can customize this based on your requirements common_allowed = {"stream", "api_key", "api_base"} provider_specific = { "anthropic": {"api_version"}, "openai": {"organization"}, # Add more providers and their allowed params here } return common_allowed.union(provider_specific.get(provider, set())) def check_file(file_path: str) -> List[str]: with open(file_path, "r") as file: tree = ast.parse(file.read()) checker = ConfigChecker() for node in tree.body: if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params": checker.visit(node) break # No need to visit other functions return checker.check_patterns() def main(): file_path = "../../litellm/utils.py" errors = check_file(file_path) if errors: print("\nFound the following issues:") for error in errors: print(f"- {error}") sys.exit(1) else: print("No issues found!") sys.exit(0) if __name__ == "__main__": main()