feat(router.py): support mock testing content policy + context window fallbacks

This commit is contained in:
Krrish Dholakia 2024-06-25 10:57:32 -07:00
parent 6f51da4e78
commit 0396d484fb
2 changed files with 76 additions and 20 deletions

View file

@ -1,24 +1,54 @@
model_list: # model_list:
- model_name: my-fake-model # - model_name: my-fake-model
litellm_params: # litellm_params:
model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 # model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
api_key: my-fake-key # api_key: my-fake-key
aws_bedrock_runtime_endpoint: http://127.0.0.1:8000 # aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
mock_response: "Hello world 1" # mock_response: "Hello world 1"
model_info: # model_info:
max_input_tokens: 0 # trigger context window fallback # max_input_tokens: 0 # trigger context window fallback
- model_name: my-fake-model # - model_name: my-fake-model
litellm_params: # litellm_params:
model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 # model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
api_key: my-fake-key # api_key: my-fake-key
aws_bedrock_runtime_endpoint: http://127.0.0.1:8000 # aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
mock_response: "Hello world 2" # mock_response: "Hello world 2"
model_info: # model_info:
max_input_tokens: 0 # max_input_tokens: 0
router_settings: # router_settings:
enable_pre_call_checks: True # enable_pre_call_checks: True
# litellm_settings:
# failure_callback: ["langfuse"]
model_list:
- model_name: summarize
litellm_params:
model: openai/gpt-4o
rpm: 10000
tpm: 12000000
api_key: os.environ/OPENAI_API_KEY
mock_response: Hello world 1
- model_name: summarize-l
litellm_params:
model: claude-3-5-sonnet-20240620
rpm: 4000
tpm: 400000
api_key: os.environ/ANTHROPIC_API_KEY
mock_response: Hello world 2
litellm_settings: litellm_settings:
failure_callback: ["langfuse"] num_retries: 3
request_timeout: 120
allowed_fails: 3
# fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
context_window_fallbacks: [{"summarize": ["summarize-l", "summarize-xl"]}, {"summarize-l": ["summarize-xl"]}]
router_settings:
routing_strategy: simple-shuffle
enable_pre_call_checks: true

View file

@ -2117,6 +2117,12 @@ class Router:
If it fails after num_retries, fall back to another model group If it fails after num_retries, fall back to another model group
""" """
mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None) mock_testing_fallbacks = kwargs.pop("mock_testing_fallbacks", None)
mock_testing_context_fallbacks = kwargs.pop(
"mock_testing_context_fallbacks", None
)
mock_testing_content_policy_fallbacks = kwargs.pop(
"mock_testing_content_policy_fallbacks", None
)
model_group = kwargs.get("model") model_group = kwargs.get("model")
fallbacks = kwargs.get("fallbacks", self.fallbacks) fallbacks = kwargs.get("fallbacks", self.fallbacks)
context_window_fallbacks = kwargs.get( context_window_fallbacks = kwargs.get(
@ -2130,6 +2136,26 @@ class Router:
raise Exception( raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}" f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
) )
elif (
mock_testing_context_fallbacks is not None
and mock_testing_context_fallbacks is True
):
raise litellm.ContextWindowExceededError(
model=model_group,
llm_provider="",
message=f"This is a mock exception for model={model_group}, to trigger a fallback. \
Context_Window_Fallbacks={context_window_fallbacks}",
)
elif (
mock_testing_content_policy_fallbacks is not None
and mock_testing_content_policy_fallbacks is True
):
raise litellm.ContentPolicyViolationError(
model=model_group,
llm_provider="",
message=f"This is a mock exception for model={model_group}, to trigger a fallback. \
Content_Policy_Fallbacks={content_policy_fallbacks}",
)
response = await self.async_function_with_retries(*args, **kwargs) response = await self.async_function_with_retries(*args, **kwargs)
verbose_router_logger.debug(f"Async Response: {response}") verbose_router_logger.debug(f"Async Response: {response}")