feat(utils.py): support async streaming for custom llm provider
parent b4e3a77ad0
commit 060249c7e0

3 changed files with 38 additions and 2 deletions
@@ -17,8 +17,10 @@ from enum import Enum
 from functools import partial
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
+    Coroutine,
     Iterator,
     List,
     Literal,
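The two new typing names line up with the async streaming surface added later in this commit: an async streaming handler is naturally annotated as an AsyncGenerator, and the wrapper that schedules it as a Coroutine. A throwaway sketch of signatures along those lines (every name below is illustrative, not from litellm):

from typing import Any, AsyncGenerator, Coroutine


async def stream_words(text: str) -> AsyncGenerator[str, None]:
    # Async generator: yields str chunks, accepts nothing via asend().
    for word in text.split():
        yield word


def collect_stream(text: str) -> Coroutine[Any, Any, list]:
    # Returns a coroutine that, once awaited, drains the generator.
    async def _collect() -> list:
        return [w async for w in stream_words(text)]

    return _collect()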
@@ -17,7 +17,7 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncIterator, Iterator, Union
+from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
@@ -75,7 +75,7 @@ class CustomModelResponseIterator:
     # Async iterator
     def __aiter__(self):
         self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
-        return self
+        return self.streaming_response

     async def __anext__(self) -> GenericStreamingChunk:
         try:
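The one-line change above swaps what `__aiter__` hands back: the wrapper now returns the underlying `streaming_response`'s iterator rather than itself, so `async for` pulls items straight from the provider stream instead of going through the wrapper's own `__anext__`. A minimal, self-contained sketch of that delegation (class and attribute names mirror the diff; the fake stream is illustrative):

import asyncio
from typing import AsyncIterator


async def fake_stream() -> AsyncIterator[str]:
    # Stand-in for the provider's streaming_response.
    for piece in ("Hello", " world"):
        yield piece


class IteratorSketch:
    def __init__(self, streaming_response):
        self.streaming_response = streaming_response

    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        # Returning the underlying iterator delegates iteration to it directly;
        # returning `self` would route every item through __anext__ on this class.
        return self.streaming_response


async def main():
    async for piece in IteratorSketch(fake_stream()):
        print(piece, end="")


asyncio.run(main())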
@@ -126,6 +126,18 @@ class MyCustomLLM(CustomLLM):
         )
         return custom_iterator

+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        yield generic_streaming_chunk  # type: ignore
+

 def test_get_llm_provider():
     """"""
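The new `astreaming` handler is an async generator that yields one hard-coded `GenericStreamingChunk` and finishes. Under the same assumptions (the `CustomLLM` base class and the chunk keys shown in the hunk), a slightly less canned sketch that streams a reply word by word; the word-splitting and the per-chunk field values are illustrative:

from typing import AsyncIterator


class WordStreamingLLM:  # a real handler would subclass litellm's CustomLLM
    async def astreaming(self, *args, **kwargs) -> AsyncIterator[dict]:
        words = "Hello world from a custom provider".split()
        for i, word in enumerate(words):
            last = i == len(words) - 1
            yield {
                "text": word if last else word + " ",
                "index": 0,
                "is_finished": last,
                "finish_reason": "stop" if last else "",  # non-final value assumed
                "tool_use": None,
                "usage": None,
            }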
@@ -187,3 +199,23 @@ def test_simple_completion_streaming():
             assert isinstance(chunk.choices[0].delta.content, str)
         else:
             assert chunk.choices[0].finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_simple_completion_async_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = await litellm.acompletion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    async for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"
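The test drives the whole path end to end: register the handler in `litellm.custom_provider_map`, call `litellm.acompletion(..., stream=True)`, and iterate the result with `async for`. The same flow works outside pytest; a sketch, assuming `MyCustomLLM` from the hunk above is importable:

import asyncio

import litellm


async def main():
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": MyCustomLLM()}  # handler from the test module
    ]
    resp = await litellm.acompletion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
        stream=True,
    )
    async for chunk in resp:
        print(chunk.choices[0].delta.content or chunk.choices[0].finish_reason)


asyncio.run(main())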
@@ -10132,6 +10132,7 @@ class CustomStreamWrapper:
         try:
             if self.completion_stream is None:
                 await self.fetch_stream()
+
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
@@ -10156,6 +10157,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
+                or self.custom_llm_provider in litellm._custom_providers
             ):
                 async for chunk in self.completion_stream:
                     print_verbose(f"value of async chunk: {chunk}")
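The added membership test is what lets the async branch of `CustomStreamWrapper` treat a registered custom provider like the natively async providers and `async for` over the handler's `astreaming` output. Presumably `litellm._custom_providers` is populated when the handler is registered; that step is not shown in this diff. A reduced sketch of the dispatch shape, with stand-in data:

# Stand-ins for state litellm keeps elsewhere; only the membership checks
# mirror the condition in the diff.
_custom_providers = {"custom_llm"}  # assumed to be filled in at registration time
openai_compatible_endpoints: set = set()


def uses_native_async_stream(custom_llm_provider: str) -> bool:
    return (
        custom_llm_provider in ("openai", "azure", "triton", "watsonx")
        or custom_llm_provider in openai_compatible_endpoints
        or custom_llm_provider in _custom_providers  # the clause added here
    )


assert uses_native_async_stream("custom_llm")
assert not uses_native_async_stream("some-sync-only-provider")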