feat(utils.py): support sync streaming for custom llm provider

Krrish Dholakia 2024-07-25 16:47:32 -07:00
parent fe503386ab
commit bf23aac11d
5 changed files with 139 additions and 10 deletions

View file

@@ -913,6 +913,7 @@ adapters: List[AdapterItem] = []
 ### CUSTOM LLMs ###
 from .types.llms.custom_llm import CustomLLMItem
+from .types.utils import GenericStreamingChunk
 custom_provider_map: List[CustomLLMItem] = []
 _custom_providers: List[str] = (

View file

@@ -15,7 +15,17 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -23,8 +33,7 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.databricks import GenericStreamingChunk
-from litellm.types.utils import ProviderField
+from litellm.types.utils import GenericStreamingChunk, ProviderField
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@ -51,13 +60,13 @@ class CustomLLM(BaseLLM):
def completion(self, *args, **kwargs) -> ModelResponse: def completion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!") raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(self, *args, **kwargs): def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!") raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(self, *args, **kwargs) -> ModelResponse: async def acompletion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!") raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(self, *args, **kwargs): async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!") raise CustomLLMError(status_code=500, message="Not implemented yet!")
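
Not part of the diff: for orientation, a minimal sketch of a handler written against the new typed signatures. The class name MyFakeLLM and the canned chunk values are illustrative assumptions; the chunk shape mirrors how GenericStreamingChunk is used in the test file further down in this commit.

from typing import AsyncIterator, Iterator

from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk


class MyFakeLLM(CustomLLM):
    # Sync streaming: produce GenericStreamingChunk values one at a time.
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        yield GenericStreamingChunk(
            text="Hello world",
            tool_use=None,
            is_finished=True,
            finish_reason="stop",
            usage={"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
            index=0,
        )

    # Async streaming: same chunk shape, consumed via async iteration.
    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        for chunk in self.streaming(*args, **kwargs):
            yield chunk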

View file

@@ -2713,6 +2713,14 @@ def completion(
             ## CALL FUNCTION
             response = handler_fn()

+            if stream is True:
+                return CustomStreamWrapper(
+                    completion_stream=response,
+                    model=model,
+                    custom_llm_provider=custom_llm_provider,
+                    logging_obj=logging,
+                )
+
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"

View file

@@ -17,13 +17,80 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncIterator, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
 from dotenv import load_dotenv

 import litellm
-from litellm import CustomLLM, acompletion, completion, get_llm_provider
+from litellm import (
+    ChatCompletionDeltaChunk,
+    ChatCompletionUsageBlock,
+    CustomLLM,
+    GenericStreamingChunk,
+    ModelResponse,
+    acompletion,
+    completion,
+    get_llm_provider,
+)
+from litellm.utils import ModelResponseIterator
+
+
+class CustomModelResponseIterator:
+    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
+        return GenericStreamingChunk(
+            text="hello world",
+            tool_use=None,
+            is_finished=True,
+            finish_reason="stop",
+            usage=ChatCompletionUsageBlock(
+                prompt_tokens=10, completion_tokens=20, total_tokens=30
+            ),
+            index=0,
+        )
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> GenericStreamingChunk:
+        try:
+            chunk: Any = self.streaming_response.__next__()  # type: ignore
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
+        return self
+
+    async def __anext__(self) -> GenericStreamingChunk:
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+
 class MyCustomLLM(CustomLLM):
@@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore

-
-class MyCustomAsyncLLM(CustomLLM):
     async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
@@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore

+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        completion_stream = ModelResponseIterator(
+            model_response=generic_streaming_chunk  # type: ignore
+        )
+        custom_iterator = CustomModelResponseIterator(
+            streaming_response=completion_stream
+        )
+        return custom_iterator
+

 def test_get_llm_provider():
     """"""
     from litellm.utils import custom_llm_setup

     my_custom_llm = MyCustomLLM()
@@ -74,7 +158,7 @@ def test_simple_completion():
 @pytest.mark.asyncio
 async def test_simple_acompletion():
-    my_custom_llm = MyCustomAsyncLLM()
+    my_custom_llm = MyCustomLLM()

     litellm.custom_provider_map = [
         {"provider": "custom_llm", "custom_handler": my_custom_llm}
     ]
@@ -84,3 +168,22 @@ async def test_simple_acompletion():
     )
     assert resp.choices[0].message.content == "Hi!"
+
+
+def test_simple_completion_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = completion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"

View file

@@ -9262,7 +9262,10 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
+            if self.custom_llm_provider and (
+                self.custom_llm_provider == "anthropic"
+                or self.custom_llm_provider in litellm._custom_providers
+            ):
                 from litellm.types.utils import GenericStreamingChunk as GChunk

                 if self.received_finish_reason is not None:
@@ -10981,3 +10984,8 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         self.is_done = True
         return self.model_response
+
+
+class CustomModelResponseIterator(Iterable):
+    def __init__(self) -> None:
+        super().__init__()
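
To make the CustomStreamWrapper change above concrete (an illustrative sketch only, not the wrapper's actual code): chunks from a registered custom provider now go through the same GenericStreamingChunk handling as anthropic, so the caller sees ordinary OpenAI-style streaming chunks. Roughly the shape the streaming test earlier in this commit asserts on; everything here except the GenericStreamingChunk fields is an assumption.

from litellm.types.utils import GenericStreamingChunk


def to_openai_style_choice(chunk: GenericStreamingChunk) -> dict:
    # Text surfaces as delta.content; finish_reason stays None until the provider
    # marks the chunk finished (e.g. "stop").
    return {
        "index": chunk["index"],
        "delta": {"content": chunk["text"]},
        "finish_reason": chunk["finish_reason"] if chunk["is_finished"] else None,
    }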