feat(utils.py): support sync streaming for custom llm provider

Krrish Dholakia 2024-07-25 16:47:32 -07:00
parent 9f97436308
commit b4e3a77ad0
5 changed files with 139 additions and 10 deletions
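
In short: a custom handler registered via litellm.custom_provider_map can now serve completion(..., stream=True) by implementing streaming() and producing GenericStreamingChunk dicts. A minimal usage sketch (the handler class, provider name "my_streaming_llm", and model "my-model" are illustrative, mirroring the test added below):

    from typing import Iterator

    import litellm
    from litellm import CustomLLM, GenericStreamingChunk

    class MyStreamingLLM(CustomLLM):
        def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
            # A real handler would yield many partial chunks; a single
            # terminal chunk is enough to exercise the sync streaming path.
            yield {
                "finish_reason": "stop",
                "index": 0,
                "is_finished": True,
                "text": "Hello world",
                "tool_use": None,
                "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
            }

    litellm.custom_provider_map = [
        {"provider": "my_streaming_llm", "custom_handler": MyStreamingLLM()}
    ]

    # "<provider>/<model>" routes the call to the registered handler.
    for chunk in litellm.completion(
        model="my_streaming_llm/my-model",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    ):
        print(chunk.choices[0].delta.content or "", end="")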


@@ -17,13 +17,80 @@ sys.path.insert(
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Iterator, Union
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
import litellm
-from litellm import CustomLLM, acompletion, completion, get_llm_provider
+from litellm import (
+    ChatCompletionDeltaChunk,
+    ChatCompletionUsageBlock,
+    CustomLLM,
+    GenericStreamingChunk,
+    ModelResponse,
+    acompletion,
+    completion,
+    get_llm_provider,
+)
+from litellm.utils import ModelResponseIterator
+
+
+class CustomModelResponseIterator:
+    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
+        return GenericStreamingChunk(
+            text="hello world",
+            tool_use=None,
+            is_finished=True,
+            finish_reason="stop",
+            usage=ChatCompletionUsageBlock(
+                prompt_tokens=10, completion_tokens=20, total_tokens=30
+            ),
+            index=0,
+        )
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> GenericStreamingChunk:
+        try:
+            chunk: Any = self.streaming_response.__next__()  # type: ignore
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
+        return self
+
+    async def __anext__(self) -> GenericStreamingChunk:
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
class MyCustomLLM(CustomLLM):
@@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM):
            mock_response="Hi!",
        )  # type: ignore
-class MyCustomAsyncLLM(CustomLLM):
    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
@@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM):
            mock_response="Hi!",
        )  # type: ignore
+
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        completion_stream = ModelResponseIterator(
+            model_response=generic_streaming_chunk  # type: ignore
+        )
+        custom_iterator = CustomModelResponseIterator(
+            streaming_response=completion_stream
+        )
+        return custom_iterator
def test_get_llm_provider():
    """"""
    from litellm.utils import custom_llm_setup

    my_custom_llm = MyCustomLLM()
@@ -74,7 +158,7 @@ def test_simple_completion():
@pytest.mark.asyncio
async def test_simple_acompletion():
-    my_custom_llm = MyCustomAsyncLLM()
+    my_custom_llm = MyCustomLLM()
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": my_custom_llm}
    ]
@@ -84,3 +168,22 @@ async def test_simple_acompletion():
    )
    assert resp.choices[0].message.content == "Hi!"
+
+
+def test_simple_completion_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = completion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"
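
Note: the diff above shows only the test file. The ModelResponseIterator imported from litellm.utils belongs to the utils.py change this commit names; judging from its use in streaming(), it wraps a single chunk as a one-shot sync/async iterator. A hedged sketch of that contract (illustrative only, not the actual utils.py code):

    # Hypothetical one-shot iterator matching how the test above uses
    # litellm.utils.ModelResponseIterator: emit the wrapped value once, then stop.
    class OneShotResponseIterator:
        def __init__(self, model_response):
            self.model_response = model_response
            self.is_done = False

        # Sync side: the first __next__ returns the value, the next call stops.
        def __iter__(self):
            return self

        def __next__(self):
            if self.is_done:
                raise StopIteration
            self.is_done = True
            return self.model_response

        # Async side: the same one-shot contract for `async for` consumers.
        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.is_done:
                raise StopAsyncIteration
            self.is_done = True
            return self.model_response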