forked from phoenix/litellm-mirror
feat(utils.py): support sync streaming for custom llm provider
parent 9f97436308
commit b4e3a77ad0
5 changed files with 139 additions and 10 deletions
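In short: a handler registered in litellm.custom_provider_map can now serve synchronous stream=True calls, with completion() wrapping whatever the handler's streaming() method returns in a CustomStreamWrapper. A minimal sketch of the handler side, assembled from the test code in this diff; the class name MyFakeLLM is illustrative, and the generator form is an assumption (the test below returns an explicit iterator object instead):

from typing import Iterator

import litellm
from litellm import CustomLLM, GenericStreamingChunk


class MyFakeLLM(CustomLLM):
    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        # One hard-coded, already-finished chunk; a real handler would yield one
        # chunk per event coming back from its backend.
        yield {
            "text": "Hello world",
            "tool_use": None,
            "is_finished": True,
            "finish_reason": "stop",
            "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30},
            "index": 0,
        }


# Register the handler under a provider prefix, exactly as the tests in this diff do.
litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": MyFakeLLM()}
]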
@@ -913,6 +913,7 @@ adapters: List[AdapterItem] = []

 ### CUSTOM LLMs ###
 from .types.llms.custom_llm import CustomLLMItem
+from .types.utils import GenericStreamingChunk

 custom_provider_map: List[CustomLLMItem] = []
 _custom_providers: List[str] = (
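The new re-export lets handler and test code pull the chunk type straight from the package root instead of reaching into litellm.types.utils; both names refer to the same TypedDict. A two-line check to make the relationship explicit:

from litellm import GenericStreamingChunk
from litellm.types.utils import GenericStreamingChunk as _SameType

assert GenericStreamingChunk is _SameType  # the root-level name is a re-export, not a copy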
@@ -15,7 +15,17 @@ import time
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -23,8 +33,7 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.databricks import GenericStreamingChunk
-from litellm.types.utils import ProviderField
+from litellm.types.utils import GenericStreamingChunk, ProviderField
 from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage

 from .base import BaseLLM
@@ -51,13 +60,13 @@ class CustomLLM(BaseLLM):
     def completion(self, *args, **kwargs) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    def streaming(self, *args, **kwargs):
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

     async def acompletion(self, *args, **kwargs) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")

-    async def astreaming(self, *args, **kwargs):
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")


@@ -2713,6 +2713,14 @@ def completion(

             ## CALL FUNCTION
             response = handler_fn()
+            if stream is True:
+                return CustomStreamWrapper(
+                    completion_stream=response,
+                    model=model,
+                    custom_llm_provider=custom_llm_provider,
+                    logging_obj=logging,
+                )
+
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
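With stream=True, the raw GenericStreamingChunk objects produced by the handler are not returned directly; CustomStreamWrapper converts them into OpenAI-style streaming chunks, which is what the new test at the bottom of the test file asserts. From the caller's side it looks roughly like this (model name and message are taken from that test, and a handler is assumed to already be registered under the custom_llm prefix as sketched above):

from litellm import completion

resp = completion(
    model="custom_llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
    stream=True,
)

for chunk in resp:
    # Each item is an OpenAI-style streaming chunk, not a raw GenericStreamingChunk.
    finish = chunk.choices[0].finish_reason
    if finish is None:
        print(chunk.choices[0].delta.content, end="")
    else:
        print(f"\n[finish_reason={finish}]")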
@@ -17,13 +17,80 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, AsyncIterator, Iterator, Union
 from unittest.mock import AsyncMock, MagicMock, patch

 import httpx
 from dotenv import load_dotenv

 import litellm
-from litellm import CustomLLM, acompletion, completion, get_llm_provider
+from litellm import (
+    ChatCompletionDeltaChunk,
+    ChatCompletionUsageBlock,
+    CustomLLM,
+    GenericStreamingChunk,
+    ModelResponse,
+    acompletion,
+    completion,
+    get_llm_provider,
+)
+from litellm.utils import ModelResponseIterator
+
+
+class CustomModelResponseIterator:
+    def __init__(self, streaming_response: Union[Iterator, AsyncIterator]):
+        self.streaming_response = streaming_response
+
+    def chunk_parser(self, chunk: Any) -> GenericStreamingChunk:
+        return GenericStreamingChunk(
+            text="hello world",
+            tool_use=None,
+            is_finished=True,
+            finish_reason="stop",
+            usage=ChatCompletionUsageBlock(
+                prompt_tokens=10, completion_tokens=20, total_tokens=30
+            ),
+            index=0,
+        )
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> GenericStreamingChunk:
+        try:
+            chunk: Any = self.streaming_response.__next__()  # type: ignore
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()  # type: ignore
+        return self
+
+    async def __anext__(self) -> GenericStreamingChunk:
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            return self.chunk_parser(chunk=chunk)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


 class MyCustomLLM(CustomLLM):
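CustomModelResponseIterator above is a test double: chunk_parser ignores whatever the wrapped stream yields and always returns the same finished GenericStreamingChunk. In a real handler, chunk_parser is where a provider's wire format would be mapped onto those fields. A hypothetical sketch; the {"delta", "done"} wire format and parse_wire_chunk are invented for illustration and are not part of this diff:

from litellm import GenericStreamingChunk


def parse_wire_chunk(chunk: dict) -> GenericStreamingChunk:
    # Map an imaginary provider payload onto the fields CustomStreamWrapper reads.
    done = bool(chunk.get("done", False))
    return GenericStreamingChunk(
        text=chunk.get("delta", ""),
        tool_use=None,
        is_finished=done,
        finish_reason="stop" if done else "",
        usage=None,
        index=0,
    )


print(parse_wire_chunk({"delta": "Hel", "done": False}))
print(parse_wire_chunk({"delta": "lo", "done": True}))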
@@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore

-
-class MyCustomAsyncLLM(CustomLLM):
     async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
@@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM):
             mock_response="Hi!",
         )  # type: ignore

+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": "Hello world",
+            "tool_use": None,
+            "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
+        }
+
+        completion_stream = ModelResponseIterator(
+            model_response=generic_streaming_chunk  # type: ignore
+        )
+        custom_iterator = CustomModelResponseIterator(
+            streaming_response=completion_stream
+        )
+        return custom_iterator
+
+
 def test_get_llm_provider():
     """"""
     from litellm.utils import custom_llm_setup

     my_custom_llm = MyCustomLLM()
@@ -74,7 +158,7 @@ def test_simple_completion():

 @pytest.mark.asyncio
 async def test_simple_acompletion():
-    my_custom_llm = MyCustomAsyncLLM()
+    my_custom_llm = MyCustomLLM()
     litellm.custom_provider_map = [
         {"provider": "custom_llm", "custom_handler": my_custom_llm}
     ]
@@ -84,3 +168,22 @@ async def test_simple_acompletion():
     )

     assert resp.choices[0].message.content == "Hi!"
+
+
+def test_simple_completion_streaming():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = completion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    for chunk in resp:
+        print(chunk)
+        if chunk.choices[0].finish_reason is None:
+            assert isinstance(chunk.choices[0].delta.content, str)
+        else:
+            assert chunk.choices[0].finish_reason == "stop"
@@ -9262,7 +9262,10 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
+            if self.custom_llm_provider and (
+                self.custom_llm_provider == "anthropic"
+                or self.custom_llm_provider in litellm._custom_providers
+            ):
                 from litellm.types.utils import GenericStreamingChunk as GChunk

                 if self.received_finish_reason is not None:
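The widened condition only matches providers whose names appear in litellm._custom_providers. The test file's test_get_llm_provider populates that list via custom_llm_setup(); here is a sketch of that registration step, assuming (as that test suggests) that custom_llm_setup copies provider names from custom_provider_map into _custom_providers:

import litellm
from litellm import CustomLLM
from litellm.utils import custom_llm_setup


class MyHandler(CustomLLM):
    pass  # unimplemented methods raise CustomLLMError(500) by default


litellm.custom_provider_map = [
    {"provider": "custom_llm", "custom_handler": MyHandler()}
]
custom_llm_setup()

# Once registered, streaming responses for "custom_llm/..." models take the
# GenericStreamingChunk branch added above in CustomStreamWrapper.
assert "custom_llm" in litellm._custom_providers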
@@ -10981,3 +10984,8 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         self.is_done = True
         return self.model_response
+
+
+class CustomModelResponseIterator(Iterable):
+    def __init__(self) -> None:
+        super().__init__()