Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
fix(router.py): use user-defined model_input_tokens for pre-call filter checks
This commit is contained in:
parent 123477b55a
commit f5fbdf0fee

3 changed files with 58 additions and 5 deletions
@@ -4,7 +4,17 @@ model_list:
       model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
       api_key: my-fake-key
       aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
+      mock_response: "Hello world 1"
+    model_info:
+      max_input_tokens: 0 # trigger context window fallback
+  - model_name: my-fake-model
+    litellm_params:
+      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
+      api_key: my-fake-key
+      aws_bedrock_runtime_endpoint: http://127.0.0.1:8000
+      mock_response: "Hello world 2"
+    model_info:
+      max_input_tokens: 0
 
-litellm_settings:
-  success_callback: ["langfuse"]
-  failure_callback: ["langfuse"]
+router_settings:
+  enable_pre_call_checks: True
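For orientation, the config above corresponds to building a litellm Router directly in Python. Below is a minimal sketch of the same two-deployment setup; the endpoint, key, and mock responses are the placeholder values from the example config, and the model-group name of the first deployment (not visible in the hunk) is assumed to be "my-fake-model" as well.

# Sketch: a Router whose deployments carry user-defined model_info, so the
# pre-call check can read a user-set max_input_tokens.
from litellm import Router

model_list = [
    {
        "model_name": "my-fake-model",  # assumed; the first entry's name is outside the hunk
        "litellm_params": {
            "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            "api_key": "my-fake-key",
            "aws_bedrock_runtime_endpoint": "http://127.0.0.1:8000",
            "mock_response": "Hello world 1",
        },
        "model_info": {"max_input_tokens": 0},  # 0 tokens -> always trips the context window check
    },
    {
        "model_name": "my-fake-model",
        "litellm_params": {
            "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            "api_key": "my-fake-key",
            "aws_bedrock_runtime_endpoint": "http://127.0.0.1:8000",
            "mock_response": "Hello world 2",
        },
        "model_info": {"max_input_tokens": 0},
    },
]

router = Router(model_list=model_list, enable_pre_call_checks=True)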
@@ -404,6 +404,7 @@ class Router:
             litellm.failure_callback = [self.deployment_callback_on_failure]
             print(  # noqa
                 f"Intialized router with Routing strategy: {self.routing_strategy}\n\n"
+                f"Routing enable_pre_call_checks: {self.enable_pre_call_checks}\n\n"
                 f"Routing fallbacks: {self.fallbacks}\n\n"
                 f"Routing content fallbacks: {self.content_policy_fallbacks}\n\n"
                 f"Routing context window fallbacks: {self.context_window_fallbacks}\n\n"
@@ -3915,9 +3916,38 @@ class Router:
             raise Exception("Model invalid format - {}".format(type(model)))
         return None
 
+    def get_router_model_info(self, deployment: dict) -> ModelMapInfo:
+        """
+        For a given model id, return the model info (max tokens, input cost, output cost, etc.).
+
+        Augment litellm info with additional params set in `model_info`.
+
+        Returns
+        - ModelInfo - If found -> typed dict with max tokens, input cost, etc.
+        """
+        ## SET MODEL NAME
+        base_model = deployment.get("model_info", {}).get("base_model", None)
+        if base_model is None:
+            base_model = deployment.get("litellm_params", {}).get("base_model", None)
+        model = base_model or deployment.get("litellm_params", {}).get("model", None)
+
+        ## GET LITELLM MODEL INFO
+        model_info = litellm.get_model_info(model=model)
+
+        ## CHECK USER SET MODEL INFO
+        user_model_info = deployment.get("model_info", {})
+
+        model_info.update(user_model_info)
+
+        return model_info
+
     def get_model_info(self, id: str) -> Optional[dict]:
         """
         For a given model id, return the model info
+
+        Returns
+        - dict: the model in list with 'model_name', 'litellm_params', Optional['model_info']
+        - None: could not find deployment in list
         """
         for model in self.model_list:
             if "model_info" in model and "id" in model["model_info"]:
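The behavior that matters for this fix is the final `model_info.update(user_model_info)`: any key the user sets in a deployment's `model_info` (such as `max_input_tokens`) overrides the value litellm looks up for the underlying model. A tiny illustrative sketch of that merge, with made-up default values:

# Illustrative only: how get_router_model_info() lets user-set keys win.
looked_up_info = {"max_input_tokens": 200000, "max_output_tokens": 4096}  # hypothetical litellm defaults
user_model_info = {"max_input_tokens": 0}  # from the deployment's model_info block

looked_up_info.update(user_model_info)
print(looked_up_info["max_input_tokens"])  # -> 0, so the pre-call check filters this deployment out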
@@ -4307,6 +4337,7 @@ class Router:
             return _returned_deployments
 
         _context_window_error = False
+        _potential_error_str = ""
         _rate_limit_error = False
 
         ## get model group RPM ##
@@ -4327,7 +4358,7 @@ class Router:
                 model = base_model or deployment.get("litellm_params", {}).get(
                     "model", None
                 )
-                model_info = litellm.get_model_info(model=model)
+                model_info = self.get_router_model_info(deployment=deployment)
 
                 if (
                     isinstance(model_info, dict)
@@ -4339,6 +4370,11 @@ class Router:
                 ):
                     invalid_model_indices.append(idx)
                     _context_window_error = True
+                    _potential_error_str += (
+                        "Model={}, Max Input Tokens={}, Got={}".format(
+                            model, model_info["max_input_tokens"], input_tokens
+                        )
+                    )
                     continue
             except Exception as e:
                 verbose_router_logger.debug("An error occurs - {}".format(str(e)))
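Together with the previous hunk, the pre-call filter now compares the request's input token count against the (possibly user-overridden) max_input_tokens of each deployment and records a reason string for every deployment it drops. A rough sketch of that check, using hypothetical variable names in place of the surrounding router code:

# Rough sketch of the pre-call context window filter (hypothetical names,
# not the actual router internals).
invalid_model_indices = []
_potential_error_str = ""
_context_window_error = False

for idx, deployment in enumerate(healthy_deployments):  # assumed list of deployment dicts
    model_info = router.get_router_model_info(deployment=deployment)
    max_input = model_info.get("max_input_tokens")
    if max_input is not None and input_tokens > max_input:  # input_tokens: size of this request
        invalid_model_indices.append(idx)
        _context_window_error = True
        _potential_error_str += "Model={}, Max Input Tokens={}, Got={}".format(
            deployment["litellm_params"]["model"], max_input, input_tokens
        )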
@@ -4440,7 +4476,9 @@ class Router:
             )
         elif _context_window_error == True:
             raise litellm.ContextWindowExceededError(
-                message="Context Window exceeded for given call",
+                message="litellm._pre_call_checks: Context Window exceeded for given call. No models have context window large enough for this call.\n{}".format(
+                    _potential_error_str
+                ),
                 model=model,
                 llm_provider="",
                 response=httpx.Response(
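For callers, the visible difference is that the raised ContextWindowExceededError now lists each rejected deployment with its limit and the requested size. A usage sketch, assuming a `router` built with enable_pre_call_checks=True as in the example config above:

# Sketch: surfacing the richer pre-call error. `router` is assumed to be the
# Router built from the example config earlier in this commit.
import litellm

try:
    response = router.completion(
        model="my-fake-model",
        messages=[{"role": "user", "content": "Who was Alexander?" * 5000}],
    )
except litellm.ContextWindowExceededError as e:
    # The message now includes "Model=..., Max Input Tokens=..., Got=..." per rejected deployment.
    print(str(e))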
@@ -755,6 +755,7 @@ def test_router_context_window_check_pre_call_check_in_group():
                 "api_version": os.getenv("AZURE_API_VERSION"),
                 "api_base": os.getenv("AZURE_API_BASE"),
                 "base_model": "azure/gpt-35-turbo",
+                "mock_response": "Hello world 1!",
             },
         },
         {
@@ -762,6 +763,7 @@ def test_router_context_window_check_pre_call_check_in_group():
             "litellm_params": {  # params for litellm completion/embedding call
                 "model": "gpt-3.5-turbo-1106",
                 "api_key": os.getenv("OPENAI_API_KEY"),
+                "mock_response": "Hello world 2!",
             },
         },
     ]
@@ -777,6 +779,9 @@ def test_router_context_window_check_pre_call_check_in_group():
         )
 
         print(f"response: {response}")
+
+        assert response.choices[0].message.content == "Hello world 2!"
+        assert False
     except Exception as e:
         pytest.fail(f"Got unexpected exception on router! - {str(e)}")
 
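Setting mock_response on both deployments lets the test observe which deployment the router actually selected: with pre-call checks enabled, the azure deployment (base_model azure/gpt-35-turbo, the smaller context window) should be filtered out for a long prompt, so the response is expected to be "Hello world 2!" from the gpt-3.5-turbo-1106 deployment. A hedged sketch of exercising the same path outside pytest; the model-group name and prompt length are assumptions, since neither appears in the hunks above:

# Sketch only: "gpt-3.5-turbo" as the model group name and the repeated prompt
# are assumptions; model_list is the two-deployment list from the test.
from litellm import Router, token_counter

text = "What is the meaning of life?" * 3000  # intended to exceed a ~4k-token window
messages = [{"role": "user", "content": text}]
print(token_counter(model="gpt-3.5-turbo", messages=messages))  # rough request size

router = Router(model_list=model_list, enable_pre_call_checks=True)
response = router.completion(model="gpt-3.5-turbo", messages=messages)
print(response.choices[0].message.content)  # expected: "Hello world 2!"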