feat(utils.py): support sync streaming for custom llm provider

Krrish Dholakia 2024-07-25 16:47:32 -07:00
parent 9f97436308
commit b4e3a77ad0
5 changed files with 139 additions and 10 deletions
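
The hunks below touch litellm's provider registry, the CustomLLM base class, the completion() dispatcher, the custom-LLM tests, and CustomStreamWrapper. Net effect: a handler registered in litellm.custom_provider_map can now serve stream=True calls through the ordinary completion() entry point, with CustomStreamWrapper adapting the handler's GenericStreamingChunk output into OpenAI-style chunks. A condensed sketch of that flow, assuming a plain generator is acceptable wherever an Iterator[GenericStreamingChunk] is expected (the test file below builds a dedicated iterator class instead):

from typing import Iterator

import litellm
from litellm import CustomLLM, completion
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import custom_llm_setup


class MyStreamingLLM(CustomLLM):
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        # One terminal chunk for brevity; a real handler yields many partial chunks.
        yield {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }


litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": MyStreamingLLM()}
]
custom_llm_setup()  # populate litellm._custom_providers, as the tests below import it to do

resp = completion(
    model="custom_llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
    stream=True,
)
for chunk in resp:  # OpenAI-style chunks produced by CustomStreamWrapper
    print(chunk.choices[0].delta.content)

The "custom_llm/" prefix in the model string is what routes the call to the registered handler.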


@@ -913,6 +913,7 @@ adapters: List[AdapterItem] = []

 ### CUSTOM LLMs ###
 from .types.llms.custom_llm import CustomLLMItem
+from .types.utils import GenericStreamingChunk

 custom_provider_map: List[CustomLLMItem] = []
 _custom_providers: List[str] = (


@@ -15,7 +15,17 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -23,8 +33,7 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.databricks import GenericStreamingChunk
-from litellm.types.utils import ProviderField
+from litellm.types.utils import GenericStreamingChunk, ProviderField
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -51,13 +60,13 @@ class CustomLLM(BaseLLM):

     def completion(self, *args, **kwargs) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    def streaming(self, *args, **kwargs):
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

     async def acompletion(self, *args, **kwargs) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    async def astreaming(self, *args, **kwargs):
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
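
Both new signatures use GenericStreamingChunk from litellm.types.utils, replacing the Databricks-specific import above. Judging from the tests later in this diff, it is a TypedDict-style chunk; the field values here are illustrative, taken directly from those tests:

from litellm.types.utils import GenericStreamingChunk

chunk: GenericStreamingChunk = {
    "text": "Hello world",    # delta text carried by this chunk
    "is_finished": True,      # True on the terminal chunk
    "finish_reason": "stop",  # set alongside is_finished
    "tool_use": None,         # tool-call payload, when present
    "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
    "index": 0,
}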


@@ -2713,6 +2713,14 @@ def completion(
             ## CALL FUNCTION
             response = handler_fn()
+            if stream is True:
+                return CustomStreamWrapper(
+                    completion_stream=response,
+                    model=model,
+                    custom_llm_provider=custom_llm_provider,
+                    logging_obj=logging,
+                )
+
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
             )

@@ -17,13 +17,80 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncIterator, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
 from dotenv import load_dotenv

 import litellm
-from litellm import CustomLLM, acompletion, completion, get_llm_provider
+from litellm import (
+    ChatCompletionDeltaChunk,
+    ChatCompletionUsageBlock,
+    CustomLLM,
+    GenericStreamingChunk,
+    ModelResponse,
+    acompletion,
+    completion,
+    get_llm_provider,
+)
+from litellm.utils import ModelResponseIterator
+
+
+class CustomModelResponseIterator:
+    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
+        return GenericStreamingChunk(
+            text="hello world",
+            tool_use=None,
+            is_finished=True,
+            finish_reason="stop",
+            usage=ChatCompletionUsageBlock(
+                prompt_tokens=10, completion_tokens=20, total_tokens=30
+            ),
+            index=0,
+        )
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> GenericStreamingChunk:
+        try:
+            chunk: Any = self.streaming_response.__next__()  # type: ignore
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
+        return self
+
+    async def __anext__(self) -> GenericStreamingChunk:
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


 class MyCustomLLM(CustomLLM):
@@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore

-
-class MyCustomAsyncLLM(CustomLLM):
     async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
@@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore

+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        completion_stream = ModelResponseIterator(
+            model_response=generic_streaming_chunk  # type: ignore
+        )
+        custom_iterator = CustomModelResponseIterator(
+            streaming_response=completion_stream
+        )
+        return custom_iterator
+

 def test_get_llm_provider():
     """"""
     from litellm.utils import custom_llm_setup

     my_custom_llm = MyCustomLLM()
@@ -74,7 +158,7 @@ def test_simple_completion():
 @pytest.mark.asyncio
 async def test_simple_acompletion():
-    my_custom_llm = MyCustomAsyncLLM()
+    my_custom_llm = MyCustomLLM()
     litellm.custom_provider_map = [
         {"provider": "custom_llm", "custom_handler": my_custom_llm}
     ]
@@ -84,3 +168,22 @@ async def test_simple_acompletion():
     )
     assert resp.choices[0].message.content == "Hi!"
+
+
+def test_simple_completion_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = completion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"
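
These tests exercise only the sync path, matching the commit title. Given the astreaming signature added to CustomLLM above, an async handler would plausibly mirror the sync one; a hypothetical sketch, not exercised by this commit (full async wiring may require further changes):

from typing import AsyncIterator


class MyCustomAsyncStreamingLLM(CustomLLM):
    async def astreaming(
        self, *args, **kwargs
    ) -> AsyncIterator[GenericStreamingChunk]:
        # Same chunk shape as the sync test; a real handler would await a backend.
        yield {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }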


@@ -9262,7 +9262,10 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
+            if self.custom_llm_provider and (
+                self.custom_llm_provider == "anthropic"
+                or self.custom_llm_provider in litellm._custom_providers
+            ):
                 from litellm.types.utils import GenericStreamingChunk as GChunk

                 if self.received_finish_reason is not None:
@@ -10981,3 +10984,8 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         self.is_done = True
         return self.model_response
+
+
+class CustomModelResponseIterator(Iterable):
+    def __init__(self) -> None:
+        super().__init__()
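
The CustomModelResponseIterator added at the end of utils.py is an empty Iterable stub in this commit. Purely as a labeled assumption about its intended shape (mirroring the iterator class in the test file above), a minimal concrete version could look like:

from collections.abc import Iterable
from typing import Iterator, List

from litellm.types.utils import GenericStreamingChunk


class CustomModelResponseIterator(Iterable):
    """Hypothetical fleshed-out version of the stub above; not part of this commit."""

    def __init__(self, chunks: List[GenericStreamingChunk]) -> None:
        super().__init__()
        self.chunks = chunks

    def __iter__(self) -> Iterator[GenericStreamingChunk]:
        # Iterate over pre-built GenericStreamingChunk items.
        return iter(self.chunks)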