litellm-mirror/tests/documentation_tests/test_optional_params.py
Krish Dholakia 31ace870a2
Litellm dev 12 28 2024 p1 (#7463)
* refactor(utils.py): migrate amazon titan config to base config

* refactor(utils.py): refactor bedrock meta invoke model translation to use base config

* refactor(utils.py): move bedrock ai21 to base config

* refactor(utils.py): move bedrock cohere to base config

* refactor(utils.py): move bedrock mistral to use base config

* refactor(utils.py): move all provider optional param translations to using a config

* docs(clientside_auth.md): clarify how to pass vertex region to litellm proxy

* fix(utils.py): handle scenario where custom llm provider is none / empty

* fix: fix get config

* test(test_otel_load_tests.py): widen perf margin

* fix(utils.py): fix get provider config check to handle custom llm's

* fix(utils.py): fix check
2024-12-28 20:26:00 -08:00

151 lines
5.8 KiB
Python

import ast
from typing import List, Set, Dict, Optional
import sys
class ConfigChecker(ast.NodeVisitor):
    """AST visitor that collects provider-config usage patterns.

    While walking a tree it records:

    * every class definition together with the names of its bases,
    * every ``<name>.map_openai_params(...)`` call, keyed by ``<name>``,
    * every ``optional_params["<key>"] = ...`` assignment found in the body of
      an ``if custom_llm_provider == "<provider>":`` statement.

    ``check_patterns`` turns the collected data into a list of messages.
    """

    def __init__(self):
        # Messages produced by the most recent ``check_patterns`` call.
        self.errors: List[str] = []
        # Provider whose ``if`` body is currently being visited, if any.
        self.current_provider_block: Optional[str] = None
        # provider -> keys assigned directly on ``optional_params``.
        self.param_assignments: Dict[str, Set[str]] = {}
        # Names on which ``.map_openai_params(...)`` was called.
        self.map_openai_calls: Set[str] = set()
        # class name -> names of its bases (dotted bases kept in full).
        self.class_inheritance: Dict[str, List[str]] = {}

    def get_full_name(self, node):
        """Recursively extract the full dotted name from a node.

        Returns ``None`` when the node is not a plain name or an attribute
        chain rooted in a plain name.
        """
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            base = self.get_full_name(node.value)
            if base:
                return f"{base}.{node.attr}"
        return None

    @staticmethod
    def _is_base_config(base_name: str) -> bool:
        """Return True if ``base_name`` is ``BaseConfig``, bare or dotted."""
        return base_name.rpartition(".")[2] == "BaseConfig"

    def visit_ClassDef(self, node: ast.ClassDef):
        """Record the base names of ``node`` and continue into its body."""
        # Resolve bases through get_full_name so that dotted bases such as
        # ``module.BaseConfig`` are recorded instead of being dropped.
        bases = []
        for base in node.bases:
            base_name = self.get_full_name(base)
            if base_name is not None:
                bases.append(base_name)
        print(f"Found class {node.name} with bases {bases}")
        self.class_inheritance[node.name] = bases
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Record ``<name>.map_openai_params(...)`` calls by receiver name."""
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "map_openai_params"
            and isinstance(func.value, ast.Name)
        ):
            self.map_openai_calls.add(func.value.id)
        self.generic_visit(node)

    def visit_If(self, node: ast.If):
        """Track which provider block is active while visiting ``node``.

        Only the ``if`` body runs under the detected provider. The condition
        and the ``else``/``elif`` part are visited under the enclosing
        context, so assignments in an ``else:`` branch are not attributed to
        the provider tested by the ``if``.
        """
        provider = self._extract_provider_from_if(node)
        if not provider:
            self.generic_visit(node)
            return

        enclosing_provider = self.current_provider_block
        self.visit(node.test)
        self.current_provider_block = provider
        try:
            for statement in node.body:
                self.visit(statement)
        finally:
            # Always restore, even if visiting a child raises.
            self.current_provider_block = enclosing_provider
        # ``elif`` arrives here as a nested ``If`` and sets its own provider.
        for statement in node.orelse:
            self.visit(statement)

    def visit_Assign(self, node: ast.Assign):
        """Record ``optional_params[<constant>] = ...`` in a provider block."""
        provider = self.current_provider_block
        if provider and len(node.targets) == 1:
            target = node.targets[0]
            if (
                isinstance(target, ast.Subscript)
                and isinstance(target.value, ast.Name)
                and target.value.id == "optional_params"
                and isinstance(target.slice, ast.Constant)
            ):
                self.param_assignments.setdefault(provider, set()).add(
                    target.slice.value
                )
        self.generic_visit(node)

    def _extract_provider_from_if(self, node: ast.If) -> Optional[str]:
        """Extract the provider name from an if condition checking custom_llm_provider.

        Only the exact shape ``custom_llm_provider == "<literal>"`` is
        recognised; anything else yields ``None``.
        """
        test = node.test
        if not isinstance(test, ast.Compare):
            return None
        if len(test.ops) != 1 or not isinstance(test.ops[0], ast.Eq):
            return None
        if not (
            isinstance(test.left, ast.Name)
            and test.left.id == "custom_llm_provider"
        ):
            return None
        comparator = test.comparators[0]
        # Require a string literal so the result matches the declared type.
        if isinstance(comparator, ast.Constant) and isinstance(
            comparator.value, str
        ):
            return comparator.value
        return None

    def check_patterns(self) -> List[str]:
        """Build the list of messages from everything collected so far.

        The list is rebuilt on each call, so calling this method repeatedly
        does not duplicate messages. Names are processed in sorted order to
        keep the output deterministic.

        Returns:
            The messages, also stored on ``self.errors``.
        """
        errors: List[str] = []

        # Check if all configs using map_openai_params inherit from BaseConfig
        for config_name in sorted(self.map_openai_calls):
            print(f"Checking config: {config_name}")
            bases = self.class_inheritance.get(config_name, [])
            if any(self._is_base_config(base) for base in bases):
                continue
            # Retrieve the associated class name, if any
            class_name = next(
                (
                    cls
                    for cls, cls_bases in self.class_inheritance.items()
                    if config_name in cls_bases
                ),
                "Unknown Class",
            )
            errors.append(
                f"Error: {config_name} calls map_openai_params but doesn't inherit from BaseConfig. "
                f"It is used in the class: {class_name}"
            )

        # Check for parameter assignments in provider blocks
        for provider, params in self.param_assignments.items():
            # The allowed set depends only on the provider: compute it once.
            allowed = self._get_allowed_params(provider)
            for param in sorted(params, key=str):
                if param not in allowed:
                    errors.append(
                        f"Warning: Parameter '{param}' is directly assigned in {provider} block. "
                        f"Consider using a config class instead."
                    )

        self.errors = errors
        return self.errors

    def _get_allowed_params(self, provider: str) -> Set[str]:
        """Define allowed direct parameter assignments for each provider."""
        # You can customize this based on your requirements
        common_allowed = {"stream", "api_key", "api_base"}
        provider_specific = {
            "anthropic": {"api_version"},
            "openai": {"organization"},
            # Add more providers and their allowed params here
        }
        return common_allowed.union(provider_specific.get(provider, set()))
def check_file(file_path: str) -> List[str]:
    """Run ``ConfigChecker`` over ``get_optional_params`` in a source file.

    Only the first top-level function named ``get_optional_params`` is
    visited. If the file has none, nothing is collected and the checker
    reports on an empty state.

    Args:
        file_path: Path of the Python source file to analyse.

    Returns:
        The messages produced by ``ConfigChecker.check_patterns``.
    """
    # Read as UTF-8 explicitly: the default for open() follows the locale,
    # which can fail to decode a UTF-8 source file on some platforms.
    with open(file_path, "r", encoding="utf-8") as file:
        # Passing the filename makes any SyntaxError point at the real file.
        tree = ast.parse(file.read(), filename=file_path)

    checker = ConfigChecker()
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == "get_optional_params":
            checker.visit(node)
            break  # No need to visit other functions
    return checker.check_patterns()
def main():
    """Check ``../../litellm/utils.py`` and exit with a status code.

    Prints each reported issue and exits with status 1 when any are found;
    otherwise prints a success message and exits with status 0. The path is
    relative, so it is resolved against the current working directory.
    """
    issues = check_file("../../litellm/utils.py")

    # Clean run: report success and stop here.
    if not issues:
        print("No issues found!")
        sys.exit(0)

    print("\nFound the following issues:")
    for issue in issues:
        print(f"- {issue}")
    sys.exit(1)


# Run the check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()