fix(router.py): support batch completions fastest response streaming

Krrish Dholakia 2024-05-28 21:51:09 -07:00
parent f168e35629
commit e3000504f9
2 changed files with 39 additions and 3 deletions

@@ -771,13 +771,13 @@ class Router:
         models = [m.strip() for m in model.split(",")]

         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs: Any
+            model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
         ) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
             """
             Wrapper around self.acompletion that catches exceptions and returns them as a result
             """
             try:
-                return await self.acompletion(model=model, messages=messages, **kwargs)
+                return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs)  # type: ignore
             except asyncio.CancelledError:
                 verbose_router_logger.debug(
                     "Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
@@ -813,7 +813,7 @@ class Router:
         for model in models:
             task = asyncio.create_task(
                 _async_completion_no_exceptions(
-                    model=model, messages=messages, **kwargs
+                    model=model, messages=messages, stream=stream, **kwargs
                 )
             )
             pending_tasks.append(task)
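For context, this is roughly the race those tasks feed into: each model's completion runs as its own asyncio task and the first healthy result wins, with the slower calls cancelled. Below is a minimal, self-contained sketch of that pattern; `fake_completion` and `fastest_response` are illustrative stand-ins, not litellm APIs, and it checks `task.exception()` where the router instead wraps each call (via `_async_completion_no_exceptions`) so exceptions come back as values.

import asyncio
from typing import Any, Coroutine, List

async def fake_completion(model: str, delay: float, fail: bool = False) -> str:
    # Stand-in for Router.acompletion: responds after `delay`, optionally failing.
    await asyncio.sleep(delay)
    if fail:
        raise RuntimeError(f"{model} errored")
    return f"response from {model}"

async def fastest_response(coros: List[Coroutine[Any, Any, str]]) -> str:
    # Race all calls; return the first non-exception result, cancel the rest.
    pending = {asyncio.create_task(c) for c in coros}
    try:
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                if task.exception() is None:
                    return task.result()  # first healthy response wins
        raise RuntimeError("all models failed")
    finally:
        for task in pending:  # cancel the slower, still-running calls
            task.cancel()

async def main() -> None:
    winner = await fastest_response(
        [
            fake_completion("groq-llama", delay=0.1, fail=True),
            fake_completion("gpt-3.5-turbo", delay=0.3),
        ]
    )
    print(winner)  # -> "response from gpt-3.5-turbo"

asyncio.run(main())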

@@ -114,3 +114,39 @@ async def test_batch_completion_fastest_response_unit_test():
     assert response._hidden_params["model_id"] == "2"
     assert response.choices[0].message.content == "This is a fake response"
     print(f"response: {response}")
+
+
+@pytest.mark.asyncio
+async def test_batch_completion_fastest_response_streaming():
+    litellm.set_verbose = True
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "groq-llama",
+                "litellm_params": {
+                    "model": "groq/llama3-8b-8192",
+                },
+            },
+        ]
+    )
+
+    from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+    response = await router.abatch_completion_fastest_response(
+        model="gpt-3.5-turbo, groq-llama",
+        messages=[
+            {"role": "user", "content": "is litellm becoming a better product ?"}
+        ],
+        max_tokens=15,
+        stream=True,
+    )
+
+    async for chunk in response:
+        ChatCompletionChunk.model_validate(chunk.model_dump(), strict=True)
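Validating every chunk via `ChatCompletionChunk.model_validate(..., strict=True)` asserts that, whichever deployment answers first, the streamed chunks still conform to the OpenAI chunk schema.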