From b4e3a77ad0b823fb5ab44f6ee92a48e2b929993d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 25 Jul 2024 16:47:32 -0700
Subject: [PATCH] feat(utils.py): support sync streaming for custom llm
 provider

---
 litellm/__init__.py              |   1 +
 litellm/llms/custom_llm.py       |  19 ++++--
 litellm/main.py                  |   8 +++
 litellm/tests/test_custom_llm.py | 111 +++++++++++++++++++++++++++++--
 litellm/utils.py                 |  10 ++-
 5 files changed, 139 insertions(+), 10 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 0527ef199..b6aacad1a 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -913,6 +913,7 @@ adapters: List[AdapterItem] = []
 
 ### CUSTOM LLMs ###
 from .types.llms.custom_llm import CustomLLMItem
+from .types.utils import GenericStreamingChunk
 
 custom_provider_map: List[CustomLLMItem] = []
 _custom_providers: List[str] = (
diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py
index 5e9933194..f00d02ab7 100644
--- a/litellm/llms/custom_llm.py
+++ b/litellm/llms/custom_llm.py
@@ -15,7 +15,17 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import httpx  # type: ignore
 import requests  # type: ignore
@@ -23,8 +33,7 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.databricks import GenericStreamingChunk
-from litellm.types.utils import ProviderField
+from litellm.types.utils import GenericStreamingChunk, ProviderField
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
 
 from .base import BaseLLM
@@ -51,13 +60,13 @@ class CustomLLM(BaseLLM):
     def completion(self, *args, **kwargs) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    def streaming(self, *args, **kwargs):
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
     async def acompletion(self, *args, **kwargs) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def astreaming(self, *args, **kwargs):
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
 
diff --git a/litellm/main.py b/litellm/main.py
index 51e7c611c..c3be01373 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2713,6 +2713,14 @@ def completion(
 
             ## CALL FUNCTION
             response = handler_fn()
+            if stream is True:
+                return CustomStreamWrapper(
+                    completion_stream=response,
+                    model=model,
+                    custom_llm_provider=custom_llm_provider,
+                    logging_obj=logging,
+                )
+
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py
index fd46c892e..4cc355e4b 100644
--- a/litellm/tests/test_custom_llm.py
+++ b/litellm/tests/test_custom_llm.py
@@ -17,13 +17,80 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncIterator, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
 from dotenv import load_dotenv
 
 import litellm
-from litellm import CustomLLM, acompletion, completion, get_llm_provider
+from litellm import (
+    ChatCompletionDeltaChunk,
+    ChatCompletionUsageBlock,
+    CustomLLM,
+    GenericStreamingChunk,
+    ModelResponse,
+    acompletion,
+    completion,
+    get_llm_provider,
+)
+from litellm.utils import ModelResponseIterator
+
+
+class CustomModelResponseIterator:
+    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
+        return GenericStreamingChunk(
+            text="hello world",
+            tool_use=None,
+            is_finished=True,
+            finish_reason="stop",
+            usage=ChatCompletionUsageBlock(
+                prompt_tokens=10, completion_tokens=20, total_tokens=30
+            ),
+            index=0,
+        )
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> GenericStreamingChunk:
+        try:
+            chunk: Any = self.streaming_response.__next__()  # type: ignore
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
+        return self
+
+    async def __anext__(self) -> GenericStreamingChunk:
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
 
 
 class MyCustomLLM(CustomLLM):
@@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore
 
-
-class MyCustomAsyncLLM(CustomLLM):
     async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
@@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore
 
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        completion_stream = ModelResponseIterator(
+            model_response=generic_streaming_chunk  # type: ignore
+        )
+        custom_iterator = CustomModelResponseIterator(
+            streaming_response=completion_stream
+        )
+        return custom_iterator
+
 
 def test_get_llm_provider():
+    """"""
     from litellm.utils import custom_llm_setup
 
     my_custom_llm = MyCustomLLM()
@@ -74,7 +158,7 @@ def test_simple_completion():
 
 @pytest.mark.asyncio
 async def test_simple_acompletion():
-    my_custom_llm = MyCustomAsyncLLM()
+    my_custom_llm = MyCustomLLM()
     litellm.custom_provider_map = [
         {"provider": "custom_llm", "custom_handler": my_custom_llm}
     ]
@@ -84,3 +168,22 @@ async def test_simple_acompletion():
     )
 
     assert resp.choices[0].message.content == "Hi!"
+
+
+def test_simple_completion_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = completion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"
diff --git a/litellm/utils.py b/litellm/utils.py
index 0f1b0315d..c14ab36dd 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9262,7 +9262,10 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
+            if self.custom_llm_provider and (
+                self.custom_llm_provider == "anthropic"
+                or self.custom_llm_provider in litellm._custom_providers
+            ):
                 from litellm.types.utils import GenericStreamingChunk as GChunk
 
                 if self.received_finish_reason is not None:
@@ -10981,3 +10984,8 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         self.is_done = True
         return self.model_response
+
+
+class CustomModelResponseIterator(Iterable):
+    def __init__(self) -> None:
+        super().__init__()
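
Usage note (not part of the patch): the new test_simple_completion_streaming test above shows how the sync streaming path is driven end to end. A minimal sketch of the same flow, assuming a CustomLLM subclass (here a hypothetical MyFakeLLM) whose streaming() yields GenericStreamingChunk dicts, registered under the same "custom_llm" provider name the tests use:

    # Sketch only; mirrors litellm/tests/test_custom_llm.py from this patch.
    from typing import Iterator

    import litellm
    from litellm import CustomLLM, GenericStreamingChunk, completion


    class MyFakeLLM(CustomLLM):
        def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
            # A real handler would yield one chunk per delta from its backend;
            # this one emits a single finished chunk.
            yield {
                "finish_reason": "stop",
                "index": 0,
                "is_finished": True,
                "text": "Hello world",
                "tool_use": None,
                "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
            }


    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": MyFakeLLM()}
    ]

    resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,  # with this patch, returns a CustomStreamWrapper over streaming()
    )
    for chunk in resp:
        print(chunk.choices[0].delta.content or "", end="")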