fix(router.py): remove wrapping of router.completion(); let clients handle this

Krrish Dholakia 2024-01-30 21:11:55 -08:00
parent 4219fe02d7
commit a07f3ec2d4
2 changed files with 88 additions and 5 deletions

litellm/router.py:

@@ -289,11 +289,7 @@ class Router:
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
-                # Submit the function to the executor with a timeout
-                future = executor.submit(self.function_with_fallbacks, **kwargs)
-                response = future.result(timeout=timeout)  # type: ignore
+            response = self.function_with_fallbacks(**kwargs)
             return response
         except Exception as e:
             raise e
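With the thread-pool wrapper gone, the router no longer enforces a hard wall-clock cutoff around the call; a client that still wants one can reapply the same pattern on its side. A minimal sketch, mirroring the removed code (the wrapper name `completion_with_hard_timeout` is hypothetical, not part of litellm):

```python
import concurrent.futures

# Hypothetical client-side wrapper (not part of this commit): enforce a hard
# wall-clock timeout around router.completion(), since Router no longer does
# this internally.
def completion_with_hard_timeout(router, timeout_s, **kwargs):
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(router.completion, **kwargs)
        # Raises concurrent.futures.TimeoutError if the call overruns.
        # Note the worker thread itself is not forcibly cancelled.
        return future.result(timeout=timeout_s)
```

This trades the old behavior (every request paying for a dedicated executor thread inside the router) for an opt-in wrapper that only callers who need it pay for.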

New test file:

@@ -0,0 +1,87 @@
#### What this tests ####
# This tests the router's timeout error handling during fallbacks

import sys, os, time
import traceback, asyncio
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm import Router
from dotenv import load_dotenv

load_dotenv()


def test_router_timeouts():
    # Model list for OpenAI and Anthropic models
    model_list = [
        {
            "model_name": "openai-gpt-4",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_key": "os.environ/AZURE_API_KEY",
                "api_base": "os.environ/AZURE_API_BASE",
                "api_version": "os.environ/AZURE_API_VERSION",
            },
            "tpm": 80000,
        },
        {
            "model_name": "anthropic-claude-instant-1.2",
            "litellm_params": {
                "model": "claude-instant-1",
                "api_key": "os.environ/ANTHROPIC_API_KEY",
            },
            "tpm": 20000,
        },
    ]

    fallbacks_list = [
        {"openai-gpt-4": ["anthropic-claude-instant-1.2"]},
    ]

    # Configure router
    router = Router(
        model_list=model_list,
        fallbacks=fallbacks_list,
        routing_strategy="usage-based-routing",
        debug_level="INFO",
        set_verbose=True,
        redis_host=os.getenv("REDIS_HOST"),
        redis_password=os.getenv("REDIS_PASSWORD"),
        redis_port=int(os.getenv("REDIS_PORT")),
        timeout=10,
    )

    print("***** TPM SETTINGS *****")
    for model_object in model_list:
        print(f"{model_object['model_name']}: {model_object['tpm']} TPM")

    # Sample list of questions
    questions_list = [
        {"content": "Tell me a very long joke.", "modality": "voice"},
    ]

    total_tokens_used = 0

    # Process each question
    for question in questions_list:
        messages = [{"content": question["content"], "role": "user"}]
        prompt_tokens = litellm.token_counter(text=question["content"], model="gpt-4")
        print("prompt_tokens = ", prompt_tokens)

        response = router.completion(
            model="openai-gpt-4", messages=messages, timeout=5, num_retries=0
        )

        total_tokens_used += response.usage.total_tokens

        print("Response:", response)
        print("********** TOKENS USED SO FAR = ", total_tokens_used)