Merge pull request #3944 from BerriAI/litellm_fix_parallel_streaming

fix: fix streaming with httpx client
Krish Dholakia 2024-05-31 21:42:37 -07:00 committed by GitHub
commit e7ff3adc26
9 changed files with 182 additions and 82 deletions

View file

@@ -5,6 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.caching import Cache
 from litellm._logging import (
     set_verbose,
@@ -213,6 +214,7 @@ add_function_to_prompt: bool = (
 )
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
+module_level_aclient = AsyncHTTPHandler()
 model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
 model_cost_map_url: str = (
     "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"

View file

@@ -1,7 +1,7 @@
 # What is this?
 ## Initial implementation of calling bedrock via httpx client (allows for async calls).
 ## V1 - covers cohere + anthropic claude-3 support
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -145,6 +145,37 @@ class AmazonCohereChatConfig:
         return optional_params
+
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = AsyncHTTPHandler()  # Create a new client if none provided
+
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise BedrockError(status_code=response.status_code, message=response.text)
+
+    decoder = AWSEventStreamDecoder(model=model)
+    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
 class BedrockLLM(BaseLLM):
     """
     Example call
@@ -1012,39 +1043,24 @@
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
     ) -> CustomStreamWrapper:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    timeout = httpx.Timeout(timeout)
-                _params["timeout"] = timeout
-            self.client = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            self.client = client  # type: ignore
-
-        response = await self.client.post(api_base, headers=headers, data=data, stream=True)  # type: ignore
-
-        if response.status_code != 200:
-            raise BedrockError(status_code=response.status_code, message=response.text)
-
-        decoder = AWSEventStreamDecoder(model=model)
-        completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
-
+        # The call is not made here; instead, we prepare the necessary objects for the stream.
         streaming_response = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="bedrock",
             logging_obj=logging_obj,
         )
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key="",
-            original_response=streaming_response,
-            additional_args={"complete_input_dict": data},
-        )
-
         return streaming_response

     def embedding(self, *args, **kwargs):
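
With this change the Bedrock `async_streaming` path no longer performs the HTTP request itself; it only binds the request arguments with `functools.partial` and lets the wrapper trigger the call when iteration starts. A rough sketch of that deferral pattern, with a placeholder `fake_make_call` and URL that are illustrative rather than litellm's actual internals:

```python
# Sketch: bind request arguments now, perform the (simulated) call later.
import asyncio
from functools import partial

async def fake_make_call(client, api_base: str, data: str):
    # The real helper would POST with stream=True and return decoded chunks;
    # here we just simulate an async stream of text pieces.
    async def chunks():
        for piece in ("hel", "lo ", "from ", api_base):
            yield piece
    return chunks()

async def main():
    # No network activity here, only argument binding.
    deferred = partial(fake_make_call, api_base="https://example.com", data='{"prompt": "hi"}')

    # The call happens only when something starts consuming the stream.
    stream = await deferred(client=None)
    async for chunk in stream:
        print(chunk, end="")
    print()

asyncio.run(main())
```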

View file

@@ -1,4 +1,4 @@
-import httpx, asyncio
+import httpx, asyncio, traceback
 from typing import Optional, Union, Mapping, Any

 # https://www.python-httpx.org/advanced/timeouts
@@ -48,11 +48,17 @@ class AsyncHTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
-        req = self.client.build_request(
-            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
-        )
-        response = await self.client.send(req, stream=stream)
-        return response
+        try:
+            req = self.client.build_request(
+                "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
+            )
+            response = await self.client.send(req, stream=stream)
+            response.raise_for_status()
+            return response
+        except httpx.HTTPStatusError as e:
+            raise
+        except Exception as e:
+            raise

     def __del__(self) -> None:
         try:
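
Adding `response.raise_for_status()` means a non-2xx response now surfaces from the handler as `httpx.HTTPStatusError` instead of being returned as a normal response. A hedged sketch of how a caller might translate that into a provider-specific error; `ProviderError` and the httpbin URL are made up for illustration:

```python
# Sketch: catching the HTTPStatusError that raise_for_status() produces.
import asyncio
import httpx

class ProviderError(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(f"{status_code}: {message}")

async def post_with_mapping(url: str, payload: dict) -> httpx.Response:
    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0, connect=5.0)) as client:
        try:
            response = await client.post(url, json=payload)
            response.raise_for_status()  # non-2xx -> httpx.HTTPStatusError
            return response
        except httpx.HTTPStatusError as e:
            # Map the transport-level error onto the provider's error type.
            raise ProviderError(status_code=e.response.status_code, message=e.response.text)
        except httpx.TimeoutException:
            raise ProviderError(status_code=408, message="Timeout error occurred.")

if __name__ == "__main__":
    try:
        asyncio.run(post_with_mapping("https://httpbin.org/status/500", {"hello": "world"}))
    except ProviderError as err:
        print("caught:", err)
```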

View file

@@ -1,5 +1,6 @@
 # What is this?
 ## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -123,7 +124,7 @@ class DatabricksConfig:
             original_chunk = None  # this is used for function/tool calling
             chunk_data = chunk_data.replace("data:", "")
             chunk_data = chunk_data.strip()
-            if len(chunk_data) == 0:
+            if len(chunk_data) == 0 or chunk_data == "[DONE]":
                 return {
                     "text": "",
                     "is_finished": is_finished,
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
         return optional_params
+
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise DatabricksError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -354,29 +381,21 @@ class DatabricksChatCompletion(BaseLLM):
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
         data["stream"] = True
-        try:
-            response = await self.async_handler.post(
-                api_base, headers=headers, data=json.dumps(data), stream=True
-            )
-            response.raise_for_status()
-            completion_stream = response.aiter_lines()
-        except httpx.HTTPStatusError as e:
-            raise DatabricksError(
-                status_code=e.response.status_code, message=response.text
-            )
-        except httpx.TimeoutException as e:
-            raise DatabricksError(status_code=408, message="Timeout error occurred.")
-        except Exception as e:
-            raise DatabricksError(status_code=500, message=str(e))
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="databricks",
             logging_obj=logging_obj,
@@ -475,6 +494,8 @@ class DatabricksChatCompletion(BaseLLM):
             },
         )
         if acompletion == True:
+            if client is not None and isinstance(client, HTTPHandler):
+                client = None
             if (
                 stream is not None and stream == True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -496,6 +517,7 @@ class DatabricksChatCompletion(BaseLLM):
                     litellm_params=litellm_params,
                     logger_fn=logger_fn,
                     headers=headers,
+                    client=client,
                 )
             else:
                 return self.acompletion_function(
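
One small but important detail above: Databricks' OpenAI-compatible stream terminates with a literal `data: [DONE]` line, and the parser now treats it like an empty keep-alive line instead of feeding it to `json.loads`. A rough sketch of that parsing rule; `parse_sse_line` and the returned dict shape are illustrative, not litellm's API:

```python
# Sketch: parsing one SSE line from an OpenAI-style streaming response.
import json
from typing import Optional

def parse_sse_line(line: str) -> Optional[dict]:
    chunk_data = line.replace("data:", "").strip()
    # Empty keep-alive lines and the terminal "[DONE]" sentinel carry no content.
    if len(chunk_data) == 0 or chunk_data == "[DONE]":
        return {"text": "", "is_finished": False, "finish_reason": ""}
    payload = json.loads(chunk_data)
    choice = payload["choices"][0]
    return {
        "text": choice.get("delta", {}).get("content") or "",
        "is_finished": choice.get("finish_reason") is not None,
        "finish_reason": choice.get("finish_reason") or "",
    }

# The sentinel no longer reaches json.loads:
print(parse_sse_line("data: [DONE]"))
print(parse_sse_line('data: {"choices": [{"delta": {"content": "hi"}, "finish_reason": null}]}'))
```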

View file

@@ -1,7 +1,7 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -51,6 +51,32 @@ class PredibaseError(Exception):
         )  # Call the base class constructor with the parameters it needs
+
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise PredibaseError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
 class PredibaseConfig:
     """
     Reference: https://docs.predibase.com/user-guide/inference/rest_api
@@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

-    def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
+    def _validate_environment(
+        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
+    ) -> dict:
         if api_key is None:
             raise ValueError(
                 "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
             )
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
         headers = {
             "content-type": "application/json",
             "Authorization": "Bearer {}".format(api_key),
@@ -304,7 +336,7 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers)
+        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
@@ -488,26 +520,19 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
     ) -> CustomStreamWrapper:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
         data["stream"] = True
-        response = await self.async_handler.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=True,
-        )
-        if response.status_code != 200:
-            raise PredibaseError(
-                status_code=response.status_code, message=response.text
-            )
-        completion_stream = response.aiter_lines()
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="predibase",
             logging_obj=logging_obj,
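
Per the new error message, the tenant ID can be passed as `tenant_id=...` or read from the `PREDIBASE_TENANT_ID` environment variable, and header construction fails fast without it. A minimal sketch of that validation order as a standalone helper (not the actual `_validate_environment` method):

```python
# Sketch: validate Predibase credentials before building request headers.
import os
from typing import Optional

def build_predibase_headers(api_key: Optional[str], tenant_id: Optional[str]) -> dict:
    # Fall back to the environment variable named in the error message above.
    tenant_id = tenant_id or os.getenv("PREDIBASE_TENANT_ID")

    if api_key is None:
        raise ValueError("Missing Predibase API Key - set it via params or environment variables")
    if tenant_id is None:
        raise ValueError("Missing Predibase Tenant ID - pass tenant_id=... or set PREDIBASE_TENANT_ID")

    return {
        "content-type": "application/json",
        "Authorization": "Bearer {}".format(api_key),
    }

# Example: fails fast when the tenant ID is absent everywhere.
try:
    build_predibase_headers(api_key="pb_test_key", tenant_id=None)
except ValueError as err:
    print(err)
```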

View file

@@ -251,7 +251,7 @@ async def async_handle_prediction_response(
     logs = ""
     while True and (status not in ["succeeded", "failed", "canceled"]):
         print_verbose(f"replicate: polling endpoint: {prediction_url}")
-        await asyncio.sleep(0.5)
+        await asyncio.sleep(0.5)  # prevent replicate rate limit errors
         response = await http_handler.get(prediction_url, headers=headers)
         if response.status_code == 200:
             response_data = response.json()
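
The comment is the whole story for this file: polling the prediction endpoint in a tight loop trips Replicate's rate limits, so each iteration sleeps before the next GET. A generic polling sketch of the same idea; `fetch_status` and the URL are stand-ins for the real HTTP call:

```python
# Sketch: poll a prediction URL with a fixed delay so the provider is not hammered.
import asyncio
import random

async def fetch_status(prediction_url: str) -> str:
    # Simulated provider response; the real code reads response.json()["status"].
    return random.choice(["starting", "processing", "succeeded"])

async def wait_for_prediction(prediction_url: str, poll_interval: float = 0.5) -> str:
    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(poll_interval)  # prevent rate limit errors from tight polling
        status = await fetch_status(prediction_url)
        print(f"polling {prediction_url}: status={status}")
    return status

asyncio.run(wait_for_prediction("https://api.example.com/v1/predictions/123"))
```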

View file

@@ -361,6 +361,7 @@ async def acompletion(
         )  # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
         return response
     except Exception as e:
+        traceback.print_exc()
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
             model=model,

View file

@@ -1385,8 +1385,21 @@ def test_bedrock_claude_3_streaming():
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "claude-3-opus-20240229",
+        "cohere.command-r-plus-v1:0",  # bedrock
+        "databricks/databricks-dbrx-instruct",  # databricks
+        "predibase/llama-3-8b-instruct",  # predibase
+        "replicate/meta/meta-llama-3-8b-instruct",  # replicate
+    ],
+)
 @pytest.mark.asyncio
-async def test_claude_3_streaming_finish_reason(sync_mode):
+async def test_parallel_streaming_requests(sync_mode, model):
+    """
+    Important prod test.
+    """
     try:
         import threading
@@ -1398,7 +1411,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
     def sync_test_streaming():
         response: litellm.CustomStreamWrapper = litellm.acompletion(  # type: ignore
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             stream=True,
             max_tokens=10,
@@ -1415,7 +1428,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
     async def test_streaming():
         response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             stream=True,
             max_tokens=10,
@@ -1424,8 +1437,9 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
         # Add any assertions here to check the response
         num_finish_reason = 0
         async for chunk in response:
-            print(f"chunk: {chunk}")
+            print(f"type of chunk: {type(chunk)}")
             if isinstance(chunk, ModelResponse):
+                print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
                 if chunk.choices[0].finish_reason is not None:
                     num_finish_reason += 1
         assert num_finish_reason == 1
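
The renamed test now runs the same streaming assertion across bedrock, databricks, predibase, and replicate; the parallelism comes from issuing several streams at once and requiring exactly one finish_reason per stream. A hedged sketch of that pattern with `asyncio.gather`; the model name and request count are arbitrary, this is not the test file's exact code, and it needs valid provider credentials to run:

```python
# Sketch: fire several streaming completions in parallel and validate each one.
import asyncio
import litellm
from litellm import ModelResponse

async def one_streaming_request(model: str) -> int:
    response = await litellm.acompletion(
        model=model,
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
        max_tokens=10,
    )
    num_finish_reason = 0
    async for chunk in response:
        if isinstance(chunk, ModelResponse) and chunk.choices[0].finish_reason is not None:
            num_finish_reason += 1
    return num_finish_reason

async def main():
    # Three concurrent streams against the same provider, mirroring the test's intent.
    results = await asyncio.gather(
        *[one_streaming_request("claude-3-opus-20240229") for _ in range(3)]
    )
    assert all(count == 1 for count in results), results
    print("each parallel stream finished exactly once:", results)

asyncio.run(main())
```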

View file

@@ -32,7 +32,7 @@ from dataclasses import (
 )
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 from litellm.caching import DualCache
 from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
@@ -10217,8 +10217,10 @@ class CustomStreamWrapper:
         custom_llm_provider=None,
         logging_obj=None,
         stream_options=None,
+        make_call: Optional[Callable] = None,
     ):
         self.model = model
+        self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
@@ -11769,8 +11771,20 @@ class CustomStreamWrapper:
                 custom_llm_provider=self.custom_llm_provider,
             )

+    async def fetch_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = await self.make_call(
+                client=litellm.module_level_aclient
+            )
+            self._stream_iter = self.completion_stream.__aiter__()
+        return self.completion_stream
+
     async def __anext__(self):
         try:
+            if self.completion_stream is None:
+                await self.fetch_stream()
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
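
Taken together, `make_call` and `fetch_stream` make the wrapper lazy: nothing goes over the wire until the first `__anext__`, at which point the shared `litellm.module_level_aclient` performs the request. A stripped-down sketch of the same lazy-iterator idea; `LazyStreamWrapper` is a toy class, not the real `CustomStreamWrapper`:

```python
# Sketch: an async iterator that defers creating its underlying stream
# until the first item is requested.
import asyncio
from typing import AsyncIterator, Awaitable, Callable, Optional

class LazyStreamWrapper:
    def __init__(self, make_call: Callable[[], Awaitable[AsyncIterator[str]]]):
        self.make_call = make_call
        self.completion_stream: Optional[AsyncIterator[str]] = None

    async def fetch_stream(self) -> AsyncIterator[str]:
        if self.completion_stream is None:
            # The expensive call happens here, on first use, not in __init__.
            self.completion_stream = await self.make_call()
        return self.completion_stream

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        if self.completion_stream is None:
            await self.fetch_stream()
        return await self.completion_stream.__anext__()

async def open_stream() -> AsyncIterator[str]:
    print("HTTP request happens now")

    async def gen():
        for token in ["hello", " ", "world"]:
            yield token

    return gen()

async def main():
    wrapper = LazyStreamWrapper(make_call=open_stream)
    print("wrapper constructed, no request yet")
    async for token in wrapper:
        print("chunk:", repr(token))

asyncio.run(main())
```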