(fix) router - timeout exception mapping

This commit is contained in:
ishaan-jaff 2024-01-19 20:30:31 -08:00
parent 7cf0bb475f
commit 84684c50fa
2 changed files with 19 additions and 11 deletions

View file

@@ -285,7 +285,7 @@ class Router:
"messages": messages, "messages": messages,
"functions": functions, "functions": functions,
"function_call": function_call, "function_call": function_call,
"timeout": timeout, "timeout": timeout or self.timeout,
"temperature": temperature, "temperature": temperature,
"top_p": top_p, "top_p": top_p,
"n": n, "n": n,
@@ -316,7 +316,7 @@ class Router:
future = executor.submit( future = executor.submit(
self.function_with_fallbacks, **kwargs, **completion_kwargs self.function_with_fallbacks, **kwargs, **completion_kwargs
) )
response = future.result(timeout=timeout) # type: ignore response = future.result() # type: ignore
return response return response
except Exception as e: except Exception as e:
@@ -417,7 +417,7 @@ class Router:
"messages": messages, "messages": messages,
"functions": functions, "functions": functions,
"function_call": function_call, "function_call": function_call,
"timeout": timeout, "timeout": timeout or self.timeout,
"temperature": temperature, "temperature": temperature,
"top_p": top_p, "top_p": top_p,
"n": n, "n": n,
@@ -442,7 +442,6 @@ class Router:
"original_function": self._acompletion, "original_function": self._acompletion,
} }
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries) kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model}) kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks( response = await self.async_function_with_fallbacks(

View file

@@ -39,6 +39,8 @@ def test_timeout():
def test_hanging_request_azure(): def test_hanging_request_azure():
litellm.set_verbose = True litellm.set_verbose = True
import asyncio
try: try:
router = litellm.Router( router = litellm.Router(
model_list=[ model_list=[
@@ -58,13 +60,20 @@ def test_hanging_request_azure():
) )
encoded = litellm.utils.encode(model="gpt-3.5-turbo", text="blue")[0] encoded = litellm.utils.encode(model="gpt-3.5-turbo", text="blue")[0]
response = router.completion(
async def _test():
response = await router.acompletion(
model="azure-gpt", model="azure-gpt",
messages=[{"role": "user", "content": f"what color is red {uuid.uuid4()}"}], messages=[
{"role": "user", "content": f"what color is red {uuid.uuid4()}"}
],
logit_bias={encoded: 100}, logit_bias={encoded: 100},
timeout=0.01, timeout=0.01,
) )
print(response) print(response)
return response
response = asyncio.run(_test())
if response.choices[0].message.content is not None: if response.choices[0].message.content is not None:
pytest.fail("Got a response, expected a timeout") pytest.fail("Got a response, expected a timeout")