feat(utils.py): support async streaming for custom llm provider

Author: Krrish Dholakia, 2024-07-25 17:11:57 -07:00
Parent: b4e3a77ad0
Commit: 060249c7e0
3 changed files with 38 additions and 2 deletions

File 1 of 3:

@@ -17,8 +17,10 @@ from enum import Enum
 from functools import partial
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
+    Coroutine,
     Iterator,
     List,
     Literal,

File 2 of 3:

@@ -17,7 +17,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncIterator, Iterator, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
@@ -75,7 +75,7 @@ class CustomModelResponseIterator:
     # Async iterator
     def __aiter__(self):
         self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
-        return self
+        return self.streaming_response

     async def __anext__(self) -> GenericStreamingChunk:
         try:
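
Side note on the __aiter__ change above: returning the wrapped async generator directly works because "async for" only needs the object returned by __aiter__ to expose __anext__, which an async generator already does. A minimal standalone sketch of that delegation, using illustrative names (fake_stream, ResponseIterator) rather than litellm's real classes:

import asyncio
from typing import AsyncIterator


async def fake_stream() -> AsyncIterator[dict]:
    # Stands in for the provider's streaming_response async generator.
    yield {"text": "Hello", "is_finished": False}
    yield {"text": " world", "is_finished": True}


class ResponseIterator:
    # Illustrative stand-in for CustomModelResponseIterator.
    def __init__(self, streaming_response):
        self.streaming_response = streaming_response

    def __aiter__(self):
        # Hand iteration off to the underlying async generator, mirroring the
        # "return self.streaming_response" change in the diff.
        return self.streaming_response


async def main():
    async for chunk in ResponseIterator(fake_stream()):
        print(chunk)


asyncio.run(main())
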
@ -126,6 +126,18 @@ class MyCustomLLM(CustomLLM):
) )
return custom_iterator return custom_iterator
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": "Hello world",
"tool_use": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}
yield generic_streaming_chunk # type: ignore
def test_get_llm_provider(): def test_get_llm_provider():
"""""" """"""
@ -187,3 +199,23 @@ def test_simple_completion_streaming():
assert isinstance(chunk.choices[0].delta.content, str) assert isinstance(chunk.choices[0].delta.content, str)
else: else:
assert chunk.choices[0].finish_reason == "stop" assert chunk.choices[0].finish_reason == "stop"
@pytest.mark.asyncio
async def test_simple_completion_async_streaming():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = await litellm.acompletion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
stream=True,
)
async for chunk in resp:
print(chunk)
if chunk.choices[0].finish_reason is None:
assert isinstance(chunk.choices[0].delta.content, str)
else:
assert chunk.choices[0].finish_reason == "stop"

File 3 of 3 (utils.py):

@@ -10132,6 +10132,7 @@ class CustomStreamWrapper:
         try:
             if self.completion_stream is None:
                 await self.fetch_stream()
+
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
@ -10156,6 +10157,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "triton" or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx" or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints or self.custom_llm_provider in litellm.openai_compatible_endpoints
or self.custom_llm_provider in litellm._custom_providers
): ):
async for chunk in self.completion_stream: async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}") print_verbose(f"value of async chunk: {chunk}")