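fix(router.py): fix pre-call check logic (look up model info via litellm_params base_model, falling back to litellm_params model, instead of model_name)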

This commit is contained in:
Krrish Dholakia 2024-03-23 18:56:08 -07:00
parent eb3ca85d7e
commit b7321ae4ee
3 changed files with 11 additions and 4 deletions

View file

@ -572,6 +572,7 @@ def completion(
"ttl", "ttl",
"cache", "cache",
"no-log", "no-log",
"base_model",
] ]
default_params = openai_params + litellm_params default_params = openai_params + litellm_params
non_default_params = { non_default_params = {

View file

@ -2175,13 +2175,19 @@ class Router:
try: try:
input_tokens = litellm.token_counter(messages=messages) input_tokens = litellm.token_counter(messages=messages)
except: except Exception as e:
return _returned_deployments return _returned_deployments
for idx, deployment in enumerate(_returned_deployments): for idx, deployment in enumerate(_returned_deployments):
# see if we have the info for this model # see if we have the info for this model
try: try:
model_info = litellm.get_model_info(model=deployment["model_name"]) base_model = deployment.get("litellm_params", {}).get(
"base_model", None
)
model = base_model or deployment.get("litellm_params", {}).get(
"model", None
)
model_info = litellm.get_model_info(model=model)
except: except:
continue continue

View file

@ -319,6 +319,7 @@ def test_router_context_window_check():
"api_key": os.getenv("AZURE_API_KEY"), "api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"), "api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE"), "api_base": os.getenv("AZURE_API_BASE"),
"base_model": "azure/gpt-35-turbo",
}, },
}, },
{ {
@ -330,7 +331,7 @@ def test_router_context_window_check():
}, },
] ]
router = Router(model_list=model_list, set_verbose=True, enable_pre_call_checks=True) # type: ignore router = Router(model_list=model_list, set_verbose=True, enable_pre_call_checks=True, num_retries=0) # type: ignore
response = router.completion( response = router.completion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
@ -341,7 +342,6 @@ def test_router_context_window_check():
) )
print(f"response: {response}") print(f"response: {response}")
raise Exception("it worked!")
except Exception as e: except Exception as e:
pytest.fail(f"Got unexpected exception on router! - {str(e)}") pytest.fail(f"Got unexpected exception on router! - {str(e)}")