diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py
index f00d02ab7..f1b2b28b4 100644
--- a/litellm/llms/custom_llm.py
+++ b/litellm/llms/custom_llm.py
@@ -17,8 +17,10 @@ from enum import Enum
 from functools import partial
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
+    Coroutine,
     Iterator,
     List,
     Literal,
diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py
index 4cc355e4b..af88b1f3a 100644
--- a/litellm/tests/test_custom_llm.py
+++ b/litellm/tests/test_custom_llm.py
@@ -17,7 +17,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncIterator, Iterator, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
@@ -75,7 +75,7 @@ class CustomModelResponseIterator:
     # Async iterator
     def __aiter__(self):
         self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
-        return self
+        return self.streaming_response
 
     async def __anext__(self) -> GenericStreamingChunk:
         try:
@@ -126,6 +126,18 @@ class MyCustomLLM(CustomLLM):
         )
         return custom_iterator
 
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        yield generic_streaming_chunk  # type: ignore
+
 
 def test_get_llm_provider():
     """"""
@@ -187,3 +199,23 @@ def test_simple_completion_streaming():
             assert isinstance(chunk.choices[0].delta.content, str)
         else:
             assert chunk.choices[0].finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_simple_completion_async_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = await litellm.acompletion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    async for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"
diff --git a/litellm/utils.py b/litellm/utils.py
index c14ab36dd..9158afb74 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10132,6 +10132,7 @@ class CustomStreamWrapper:
         try:
             if self.completion_stream is None:
                 await self.fetch_stream()
+
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
@@ -10156,6 +10157,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
+                or self.custom_llm_provider in litellm._custom_providers
             ):
                 async for chunk in self.completion_stream:
                     print_verbose(f"value of async chunk: {chunk}")
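For reference, a minimal end-to-end sketch of the async streaming path this diff exercises: a custom handler that implements `astreaming`, registered via `litellm.custom_provider_map` and consumed through `litellm.acompletion(..., stream=True)`. The handler and registration mirror the test added above; the `MyAsyncLLM` name, the `asyncio.run` driver, and the `litellm.types.utils` import path are assumptions, since the test file's import block is not shown in this diff.

```python
import asyncio
from typing import AsyncIterator

import litellm
from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk  # assumed import path


class MyAsyncLLM(CustomLLM):
    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
        # Yield a single terminal chunk, as the handler in the test diff does.
        yield {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello world",
            "tool_use": None,
            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
        }  # type: ignore


async def main() -> None:
    # Register the handler so "custom_llm/..." models route to it.
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": MyAsyncLLM()}
    ]
    resp = await litellm.acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )
    # The utils.py change routes "custom_llm" through the async chunk loop,
    # so this iteration receives the chunks yielded by astreaming.
    async for chunk in resp:
        print(chunk)


asyncio.run(main())
```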