mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(utils.py): support sync streaming for custom llm provider
This commit is contained in:
parent
9f97436308
commit
b4e3a77ad0
5 changed files with 139 additions and 10 deletions
|
@ -17,13 +17,80 @@ sys.path.insert(
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, AsyncIterator, Iterator, Union
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm
|
||||
from litellm import CustomLLM, acompletion, completion, get_llm_provider
|
||||
from litellm import (
|
||||
ChatCompletionDeltaChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
CustomLLM,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
acompletion,
|
||||
completion,
|
||||
get_llm_provider,
|
||||
)
|
||||
from litellm.utils import ModelResponseIterator
|
||||
|
||||
|
||||
class CustomModelResponseIterator:
|
||||
def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
|
||||
self.streaming_response = streaming_response
|
||||
|
||||
def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
|
||||
return GenericStreamingChunk(
|
||||
text="hello world",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=ChatCompletionUsageBlock(
|
||||
prompt_tokens=10, completion_tokens=20, total_tokens=30
|
||||
),
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> GenericStreamingChunk:
|
||||
try:
|
||||
chunk: Any = self.streaming_response.__next__() # type: ignore
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.chunk_parser(chunk=chunk)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__() # type: ignore
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> GenericStreamingChunk:
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.chunk_parser(chunk=chunk)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
|
||||
class MyCustomLLM(CustomLLM):
|
||||
|
@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM):
|
|||
mock_response="Hi!",
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class MyCustomAsyncLLM(CustomLLM):
|
||||
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
|
||||
return litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
|
@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM):
|
|||
mock_response="Hi!",
|
||||
) # type: ignore
|
||||
|
||||
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"is_finished": True,
|
||||
"text": "Hello world",
|
||||
"tool_use": None,
|
||||
"usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
|
||||
}
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
model_response=generic_streaming_chunk # type: ignore
|
||||
)
|
||||
custom_iterator = CustomModelResponseIterator(
|
||||
streaming_response=completion_stream
|
||||
)
|
||||
return custom_iterator
|
||||
|
||||
|
||||
def test_get_llm_provider():
|
||||
""""""
|
||||
from litellm.utils import custom_llm_setup
|
||||
|
||||
my_custom_llm = MyCustomLLM()
|
||||
|
@ -74,7 +158,7 @@ def test_simple_completion():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_acompletion():
|
||||
my_custom_llm = MyCustomAsyncLLM()
|
||||
my_custom_llm = MyCustomLLM()
|
||||
litellm.custom_provider_map = [
|
||||
{"provider": "custom_llm", "custom_handler": my_custom_llm}
|
||||
]
|
||||
|
@ -84,3 +168,22 @@ async def test_simple_acompletion():
|
|||
)
|
||||
|
||||
assert resp.choices[0].message.content == "Hi!"
|
||||
|
||||
|
||||
def test_simple_completion_streaming():
|
||||
my_custom_llm = MyCustomLLM()
|
||||
litellm.custom_provider_map = [
|
||||
{"provider": "custom_llm", "custom_handler": my_custom_llm}
|
||||
]
|
||||
resp = completion(
|
||||
model="custom_llm/my-fake-model",
|
||||
messages=[{"role": "user", "content": "Hello world!"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
for chunk in resp:
|
||||
print(chunk)
|
||||
if chunk.choices[0].finish_reason is None:
|
||||
assert isinstance(chunk.choices[0].delta.content, str)
|
||||
else:
|
||||
assert chunk.choices[0].finish_reason == "stop"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue