feat(router.py): support mock testing content policy + context window fallbacks

Krrish Dholakia 2024-06-25 10:57:32 -07:00
parent e215dc86c1
commit c46b229202
2 changed files with 76 additions and 20 deletions


@@ -1,24 +1,54 @@
-model_list:
-  - model_name: my-fake-model
-    litellm_params:
-      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
-      api_key: my-fake-key
-      aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
-      mock_response: "Hello world 1"
-    model_info:
-      max_input_tokens: 0 # trigger context window fallback
-  - model_name: my-fake-model
-    litellm_params:
-      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
-      api_key: my-fake-key
-      aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
-      mock_response: "Hello world 2"
-    model_info:
-      max_input_tokens: 0
+# model_list:
+#   - model_name: my-fake-model
+#     litellm_params:
+#       model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
+#       api_key: my-fake-key
+#       aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
+#       mock_response: "Hello world 1"
+#     model_info:
+#       max_input_tokens: 0 # trigger context window fallback
+#   - model_name: my-fake-model
+#     litellm_params:
+#       model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
+#       api_key: my-fake-key
+#       aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
+#       mock_response: "Hello world 2"
+#     model_info:
+#       max_input_tokens: 0
-router_settings:
-  enable_pre_call_checks: True
+# router_settings:
+#   enable_pre_call_checks: True
+# litellm_settings:
+#   failure_callback: ["langfuse"]
+model_list:
+  - model_name: summarize
+    litellm_params:
+      model: openai/gpt-4o
+      rpm: 10000
+      tpm: 12000000
+      api_key: os.environ/OPENAI_API_KEY
+      mock_response: Hello world 1
+  - model_name: summarize-l
+    litellm_params:
+      model: claude-3-5-sonnet-20240620
+      rpm: 4000
+      tpm: 400000
+      api_key: os.environ/ANTHROPIC_API_KEY
+      mock_response: Hello world 2
+litellm_settings:
+  failure_callback: ["langfuse"]
+  num_retries: 3
+  request_timeout: 120
+  allowed_fails: 3
+  # fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
+  context_window_fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
+router_settings:
+  routing_strategy: simple-shuffle
+  enable_pre_call_checks: true
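
For reference, a minimal sketch of exercising this config end to end, assuming the proxy is running on its default port 4000 with the config above and that the mock_testing_* flags are forwarded from the request body through to the router (the flag name matches the kwarg popped in router.py below; the api_key is a placeholder):

# Sketch: trigger the mocked context-window fallback through the proxy.
# "summarize" should raise litellm.ContextWindowExceededError, and the
# router should fall back to "summarize-l", whose mock_response is
# "Hello world 2".
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="summarize",
    messages=[{"role": "user", "content": "summarize this article"}],
    extra_body={"mock_testing_context_fallbacks": True},
)
print(response.choices[0].message.content)  # expected: "Hello world 2"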


@@ -2117,6 +2117,12 @@ class Router:
         If it fails after num_retries, fall back to another model group
         """
         mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
+        mock_testing_context_fallbacks = kwargs.pop(
+            "mock_testing_context_fallbacks", None
+        )
+        mock_testing_content_policy_fallbacks = kwargs.pop(
+            "mock_testing_content_policy_fallbacks", None
+        )
         model_group = kwargs.get("model")
         fallbacks = kwargs.get("fallbacks", self.fallbacks)
         context_window_fallbacks = kwargs.get(
@@ -2130,6 +2136,26 @@
                 raise Exception(
                     f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
                 )
+            elif (
+                mock_testing_context_fallbacks is not None
+                and mock_testing_context_fallbacks is True
+            ):
+                raise litellm.ContextWindowExceededError(
+                    model=model_group,
+                    llm_provider="",
+                    message=f"This is a mock exception for model={model_group}, to trigger a fallback. \
+                        Context_Window_Fallbacks={context_window_fallbacks}",
+                )
+            elif (
+                mock_testing_content_policy_fallbacks is not None
+                and mock_testing_content_policy_fallbacks is True
+            ):
+                raise litellm.ContentPolicyViolationError(
+                    model=model_group,
+                    llm_provider="",
+                    message=f"This is a mock exception for model={model_group}, to trigger a fallback. \
+                        Content_Policy_Fallbacks={content_policy_fallbacks}",
+                )
             response = await self.async_function_with_retries(*args, **kwargs)
             verbose_router_logger.debug(f"Async Response: {response}")
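
Taken together, a minimal sketch of driving the new flags from a test. The model names are reused from the config above; mock_testing_content_policy_fallbacks is the kwarg popped in the diff, while passing content_policy_fallbacks to the Router constructor is an assumption here (the branch reads it via kwargs.get with a self.* default):

import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "summarize",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "api_key": "fake-key",
                "mock_response": "Hello world 1",
            },
        },
        {
            "model_name": "summarize-l",
            "litellm_params": {
                "model": "claude-3-5-sonnet-20240620",
                "api_key": "fake-key",
                "mock_response": "Hello world 2",
            },
        },
    ],
    context_window_fallbacks=[{"summarize": ["summarize-l"]}],
    content_policy_fallbacks=[{"summarize": ["summarize-l"]}],  # assumed init param
)

async def main():
    # Forces the ContentPolicyViolationError branch above: the call to
    # "summarize" raises the mock error, and the router retries on
    # "summarize-l", which answers with its mock_response.
    response = await router.acompletion(
        model="summarize",
        messages=[{"role": "user", "content": "hi"}],
        mock_testing_content_policy_fallbacks=True,
    )
    print(response.choices[0].message.content)  # expected: "Hello world 2"

asyncio.run(main())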