forked from phoenix/litellm-mirror
fix(router.py): support batch completions fastest response streaming
parent f168e35629
commit e3000504f9

2 changed files with 39 additions and 3 deletions
router.py

@@ -771,13 +771,13 @@ class Router:
         models = [m.strip() for m in model.split(",")]

         async def _async_completion_no_exceptions(
-            model: str, messages: List[Dict[str, str]], **kwargs: Any
+            model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
         ) -> Union[ModelResponse, CustomStreamWrapper, Exception]:
             """
             Wrapper around self.acompletion that catches exceptions and returns them as a result
             """
             try:
-                return await self.acompletion(model=model, messages=messages, **kwargs)
+                return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs)  # type: ignore
             except asyncio.CancelledError:
                 verbose_router_logger.debug(
                     "Received 'task.cancel'. Cancelling call w/ model={}.".format(model)
@@ -813,7 +813,7 @@ class Router:
             for model in models:
                 task = asyncio.create_task(
                     _async_completion_no_exceptions(
-                        model=model, messages=messages, **kwargs
+                        model=model, messages=messages, stream=stream, **kwargs
                     )
                 )
                 pending_tasks.append(task)
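The hunks above only touch the task-creation side; the asyncio.wait loop that picks the winner sits outside this diff. As rough orientation, here is a minimal, self-contained sketch of the fastest-response pattern those hunks feed into. Everything in it (_fake_completion, fastest_response, the artificial delays) is illustrative and not litellm's actual implementation; only the exception-swallowing wrapper shape and the stream pass-through mirror the diff.

import asyncio
import random
from typing import Any, Dict, List, Union


async def _fake_completion(
    model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
) -> str:
    # Hypothetical stand-in for Router.acompletion: each "deployment" answers
    # after a different delay.
    await asyncio.sleep(random.uniform(0.05, 0.3))
    return f"response from {model} (stream={stream})"


async def _completion_no_exceptions(
    model: str, messages: List[Dict[str, str]], stream: bool, **kwargs: Any
) -> Union[str, Exception]:
    # Mirrors the wrapper in the hunk above: exceptions are returned, not raised,
    # so one failing deployment cannot take down the whole batch.
    try:
        return await _fake_completion(model=model, messages=messages, stream=stream, **kwargs)
    except asyncio.CancelledError:
        raise
    except Exception as e:
        return e


async def fastest_response(
    model: str, messages: List[Dict[str, str]], stream: bool = False
) -> str:
    # Comma-separated model string fans out to one task per deployment.
    models = [m.strip() for m in model.split(",")]
    pending_tasks = [
        asyncio.create_task(_completion_no_exceptions(model=m, messages=messages, stream=stream))
        for m in models
    ]
    try:
        while pending_tasks:
            done, pending = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED)
            pending_tasks = list(pending)
            for task in done:
                result = task.result()
                if not isinstance(result, Exception):
                    return result  # first healthy answer wins
        raise RuntimeError("all deployments failed")
    finally:
        for task in pending_tasks:
            task.cancel()  # losers receive the CancelledError handled in the wrapper


if __name__ == "__main__":
    print(asyncio.run(fastest_response("fast-model, slow-model", [{"role": "user", "content": "hi"}])))

The second changed file adds a streaming regression test: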
@@ -114,3 +114,39 @@ async def test_batch_completion_fastest_response_unit_test():
     assert response._hidden_params["model_id"] == "2"
     assert response.choices[0].message.content == "This is a fake response"
     print(f"response: {response}")
+
+
+@pytest.mark.asyncio
+async def test_batch_completion_fastest_response_streaming():
+    litellm.set_verbose = True
+
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                },
+            },
+            {
+                "model_name": "groq-llama",
+                "litellm_params": {
+                    "model": "groq/llama3-8b-8192",
+                },
+            },
+        ]
+    )
+
+    from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+
+    response = await router.abatch_completion_fastest_response(
+        model="gpt-3.5-turbo, groq-llama",
+        messages=[
+            {"role": "user", "content": "is litellm becoming a better product ?"}
+        ],
+        max_tokens=15,
+        stream=True,
+    )
+
+    async for chunk in response:
+        ChatCompletionChunk.model_validate(chunk.model_dump(), strict=True)
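For anyone wiring this up outside the test suite, a hedged consumption sketch follows. It assumes the same Router configuration as the test above and the OpenAI-style choices[0].delta.content chunk shape that ChatCompletionChunk.model_validate checks; it is not part of this commit.

import asyncio

import litellm


async def main() -> None:
    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
            {"model_name": "groq-llama", "litellm_params": {"model": "groq/llama3-8b-8192"}},
        ]
    )
    # Fan the prompt out to both deployments; the fastest one's stream is returned.
    stream = await router.abatch_completion_fastest_response(
        model="gpt-3.5-turbo, groq-llama",
        messages=[{"role": "user", "content": "is litellm becoming a better product ?"}],
        max_tokens=15,
        stream=True,
    )
    text = ""
    async for chunk in stream:
        # delta.content is None on role/finish chunks, so guard before appending.
        text += chunk.choices[0].delta.content or ""
    print(text)


if __name__ == "__main__":
    asyncio.run(main())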