forked from phoenix/litellm-mirror

fix: fix streaming with httpx client

prevent overwriting streams in parallel streaming calls

parent aada7b4bd3
commit 93c3635b64
9 changed files with 182 additions and 82 deletions
@@ -5,6 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.caching import Cache
from litellm._logging import (
    set_verbose,
@@ -212,6 +213,7 @@ add_function_to_prompt: bool = (
)
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
module_level_aclient = AsyncHTTPHandler()
model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = (
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
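
The hunk above adds a module-level `module_level_aclient = AsyncHTTPHandler()` next to the existing `client_session` globals, so every async streaming call can share one client. A minimal sketch of that shared-client idea, assuming `AsyncHTTPHandler` (from the `litellm.llms.custom_httpx.http_handler` module imported above) is a thin wrapper over `httpx.AsyncClient`; the wrapper name and timeout defaults below are illustrative, not litellm's actual implementation:

# Illustrative sketch only: a minimal AsyncHTTPHandler-style wrapper.
import httpx
from typing import Optional


class SharedAsyncHTTPHandler:
    def __init__(self, timeout: Optional[httpx.Timeout] = None):
        # One AsyncClient per handler: connections are pooled and reused
        # across every request made through this object.
        self.client = httpx.AsyncClient(
            timeout=timeout or httpx.Timeout(timeout=600.0, connect=5.0)
        )

    async def post(
        self,
        url: str,
        data: Optional[str] = None,
        headers: Optional[dict] = None,
        stream: bool = False,
    ) -> httpx.Response:
        req = self.client.build_request("POST", url, content=data, headers=headers)
        return await self.client.send(req, stream=stream)


# One handler for the whole module: concurrent streaming calls share a
# connection pool instead of each call stashing its own client on a shared
# provider object, which is the overwrite the commit message refers to.
module_level_aclient = SharedAsyncHTTPHandler()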
@@ -1,7 +1,7 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V1 - covers cohere + anthropic claude-3 support

from functools import partial
import os, types
import json
from enum import Enum
@@ -145,6 +145,37 @@ class AmazonCohereChatConfig:
        return optional_params


async def make_call(
    client: Optional[AsyncHTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    if client is None:
        client = AsyncHTTPHandler()  # Create a new client if none provided

    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise BedrockError(status_code=response.status_code, message=response.text)

    decoder = AWSEventStreamDecoder(model=model)

    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class BedrockLLM(BaseLLM):
    """
    Example call
@@ -968,39 +999,24 @@ class BedrockLLM(BaseLLM):
        headers={},
        client: Optional[AsyncHTTPHandler] = None,
    ) -> CustomStreamWrapper:
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            self.client = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.client = client  # type: ignore
        # The call is not made here; instead, we prepare the necessary objects for the stream.

        response = await self.client.post(api_base, headers=headers, data=data, stream=True)  # type: ignore

        if response.status_code != 200:
            raise BedrockError(status_code=response.status_code, message=response.text)

        decoder = AWSEventStreamDecoder(model=model)

        completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            completion_stream=None,
            make_call=partial(
                make_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )

        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=streaming_response,
            additional_args={"complete_input_dict": data},
        )

        return streaming_response

    def embedding(self, *args, **kwargs):
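
The hunk above stops the Bedrock streaming path from opening the HTTP stream eagerly (and from saving the client on `self.client`); the wrapper now receives `completion_stream=None` plus a `make_call` partial it can invoke later. A minimal sketch of that deferral pattern with `functools.partial`; `fetch_stream` and the fake chunks below are illustrative stand-ins, not the Bedrock code:

# Illustrative sketch of deferring stream creation with functools.partial.
import asyncio
from functools import partial
from typing import AsyncIterator


async def fetch_stream(client: object, api_base: str, data: str) -> AsyncIterator[str]:
    # Stand-in for make_call: the real version POSTs with stream=True,
    # checks the status code, and returns decoded chunks.
    async def _gen():
        for i in range(3):
            await asyncio.sleep(0)
            yield f"chunk-{i} from {api_base}"

    return _gen()


async def main() -> None:
    # Nothing is sent yet: partial only binds the arguments.
    make_call = partial(fetch_stream, api_base="https://example/invoke", data="{}")

    # The request happens here, when the consumer asks for the stream,
    # and the resulting stream belongs to this consumer alone.
    stream = await make_call(client=None)
    async for chunk in stream:
        print(chunk)


asyncio.run(main())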
@@ -1,4 +1,4 @@
import httpx, asyncio
import httpx, asyncio, traceback
from typing import Optional, Union, Mapping, Any

# https://www.python-httpx.org/advanced/timeouts
@@ -48,11 +48,17 @@ class AsyncHTTPHandler:
        headers: Optional[dict] = None,
        stream: bool = False,
    ):
        req = self.client.build_request(
            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
        )
        response = await self.client.send(req, stream=stream)
        return response
        try:
            req = self.client.build_request(
                "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
            )
            response = await self.client.send(req, stream=stream)
            response.raise_for_status()
            return response
        except httpx.HTTPStatusError as e:
            raise
        except Exception as e:
            raise

    def __del__(self) -> None:
        try:
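
With `raise_for_status()` the handler now surfaces HTTP errors at request time instead of handing the caller a failed response to inspect. A small usage sketch against plain `httpx` showing the same build_request, send(stream=True), raise_for_status flow plus the matching cleanup; the URL and payload below are placeholders:

# Usage sketch with plain httpx; the endpoint and payload are placeholders.
import asyncio
import httpx


async def stream_post(url: str, payload: str) -> None:
    client = httpx.AsyncClient(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
    try:
        req = client.build_request("POST", url, content=payload)
        response = await client.send(req, stream=True)
        try:
            # Raises httpx.HTTPStatusError for 4xx/5xx, mirroring the handler above.
            response.raise_for_status()
            async for line in response.aiter_lines():
                print(line)
        finally:
            # Streaming responses must be closed explicitly once consumed.
            await response.aclose()
    finally:
        await client.aclose()


# asyncio.run(stream_post("https://example.invalid/stream", '{"stream": true}'))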
@@ -1,5 +1,6 @@
# What is this?
## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
from functools import partial
import os, types
import json
from enum import Enum
@@ -123,7 +124,7 @@ class DatabricksConfig:
        original_chunk = None  # this is used for function/tool calling
        chunk_data = chunk_data.replace("data:", "")
        chunk_data = chunk_data.strip()
        if len(chunk_data) == 0:
        if len(chunk_data) == 0 or chunk_data == "[DONE]":
            return {
                "text": "",
                "is_finished": is_finished,
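
The extra `or chunk_data == "[DONE]"` guard treats the SSE terminator as an empty chunk instead of trying to JSON-decode it. A minimal sketch of that server-sent-events line handling; the parser and the fake payload fields below are illustrative, not the Databricks handler:

# Illustrative SSE line parser; field names in the fake payloads are made up.
import json
from typing import Optional


def parse_sse_line(line: str) -> Optional[dict]:
    """Return a decoded chunk, or None for keep-alives and the [DONE] sentinel."""
    chunk_data = line.replace("data:", "", 1).strip()
    if len(chunk_data) == 0 or chunk_data == "[DONE]":
        # Empty lines are keep-alives; "[DONE]" marks the end of the stream.
        return None
    return json.loads(chunk_data)


print(parse_sse_line('data: {"text": "hello"}'))  # {'text': 'hello'}
print(parse_sse_line("data: [DONE]"))             # None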
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
        return optional_params


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise DatabricksError(status_code=response.status_code, message=response.text)

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class DatabricksChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()
@@ -354,29 +381,21 @@
        litellm_params=None,
        logger_fn=None,
        headers={},
    ):
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        client: Optional[AsyncHTTPHandler] = None,
    ) -> CustomStreamWrapper:

        data["stream"] = True
        try:
            response = await self.async_handler.post(
                api_base, headers=headers, data=json.dumps(data), stream=True
            )
            response.raise_for_status()

            completion_stream = response.aiter_lines()
        except httpx.HTTPStatusError as e:
            raise DatabricksError(
                status_code=e.response.status_code, message=response.text
            )
        except httpx.TimeoutException as e:
            raise DatabricksError(status_code=408, message="Timeout error occurred.")
        except Exception as e:
            raise DatabricksError(status_code=500, message=str(e))

        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="databricks",
            logging_obj=logging_obj,
@@ -475,6 +494,8 @@
            },
        )
        if acompletion == True:
            if client is not None and isinstance(client, HTTPHandler):
                client = None
            if (
                stream is not None and stream == True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -496,6 +517,7 @@
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
                client=client,
            )
        else:
            return self.acompletion_function(
@@ -1,7 +1,7 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/


from functools import partial
import os, types
import json
from enum import Enum
@@ -51,6 +51,32 @@ class PredibaseError(Exception):
        )  # Call the base class constructor with the parameters it needs


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise PredibaseError(status_code=response.status_code, message=response.text)

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class PredibaseConfig:
    """
    Reference: https://docs.predibase.com/user-guide/inference/rest_api
@@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
    def _validate_environment(
        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
            )
        if tenant_id is None:
            raise ValueError(
                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
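
`_validate_environment` now fails fast when the tenant id is missing. A hypothetical caller-side helper showing how the id might be resolved before the request; only the `PREDIBASE_TENANT_ID` variable name comes from the error message above, the rest is illustrative:

# Hypothetical helper; only the env-var name PREDIBASE_TENANT_ID is taken
# from the error message in the hunk above.
import os
from typing import Optional


def resolve_predibase_tenant_id(tenant_id: Optional[str] = None) -> str:
    resolved = tenant_id or os.getenv("PREDIBASE_TENANT_ID")
    if resolved is None:
        raise ValueError(
            "Missing Predibase Tenant ID - pass tenant_id=... or set PREDIBASE_TENANT_ID"
        )
    return resolved


# resolve_predibase_tenant_id("my-tenant")  -> "my-tenant"
# resolve_predibase_tenant_id(None)         -> env value, or raises ValueError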
@@ -304,7 +336,7 @@
        logger_fn=None,
        headers: dict = {},
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)
        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
        completion_url = ""
        input_text = ""
        base_url = "https://serving.app.predibase.com"
@@ -488,26 +520,19 @@
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        data["stream"] = True
        response = await self.async_handler.post(
            url=api_base,
            headers=headers,
            data=json.dumps(data),
            stream=True,
        )

        if response.status_code != 200:
            raise PredibaseError(
                status_code=response.status_code, message=response.text
            )

        completion_stream = response.aiter_lines()

        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="predibase",
            logging_obj=logging_obj,
@@ -251,7 +251,7 @@ async def async_handle_prediction_response(
    logs = ""
    while True and (status not in ["succeeded", "failed", "canceled"]):
        print_verbose(f"replicate: polling endpoint: {prediction_url}")
        await asyncio.sleep(0.5)
        await asyncio.sleep(0.5)  # prevent replicate rate limit errors
        response = await http_handler.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
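
The new comment documents why a fixed half-second sleep sits inside the polling loop. A generic sketch of the same poll-with-delay pattern; the status fetcher and URL handling below are illustrative, not the Replicate client:

# Generic polling sketch; fetch_status is a hypothetical stand-in for the
# Replicate prediction-status request in the hunk above.
import asyncio
import random


async def fetch_status(prediction_url: str) -> str:
    await asyncio.sleep(0)  # pretend network call
    return random.choice(["starting", "processing", "succeeded"])


async def poll_until_done(prediction_url: str, delay: float = 0.5) -> str:
    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(delay)  # space out requests to avoid rate limits
        status = await fetch_status(prediction_url)
    return status


print(asyncio.run(poll_until_done("https://api.example/predictions/123")))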
@@ -361,6 +361,7 @@ async def acompletion(
        )  # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
        return response
    except Exception as e:
        traceback.print_exc()
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model,
@@ -1385,8 +1385,21 @@ def test_bedrock_claude_3_streaming():


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
    "model",
    [
        "claude-3-opus-20240229",
        "cohere.command-r-plus-v1:0",  # bedrock
        "databricks/databricks-dbrx-instruct",  # databricks
        "predibase/llama-3-8b-instruct",  # predibase
        "replicate/meta/meta-llama-3-8b-instruct",  # replicate
    ],
)
@pytest.mark.asyncio
async def test_claude_3_streaming_finish_reason(sync_mode):
async def test_parallel_streaming_requests(sync_mode, model):
    """
    Important prod test.
    """
    try:
        import threading

@@ -1398,7 +1411,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):

        def sync_test_streaming():
            response: litellm.CustomStreamWrapper = litellm.acompletion(  # type: ignore
                model="claude-3-opus-20240229",
                model=model,
                messages=messages,
                stream=True,
                max_tokens=10,
@@ -1415,7 +1428,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):

        async def test_streaming():
            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
                model="claude-3-opus-20240229",
                model=model,
                messages=messages,
                stream=True,
                max_tokens=10,
@@ -1424,8 +1437,9 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
            # Add any assertions here to check the response
            num_finish_reason = 0
            async for chunk in response:
                print(f"chunk: {chunk}")
                print(f"type of chunk: {type(chunk)}")
                if isinstance(chunk, ModelResponse):
                    print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
                    if chunk.choices[0].finish_reason is not None:
                        num_finish_reason += 1
            assert num_finish_reason == 1
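
The renamed `test_parallel_streaming_requests` drives the same model from several tasks at once, which is exactly the scenario named in the commit message. A condensed sketch of running two streaming `litellm.acompletion` calls concurrently and checking that each stream finishes independently; the prompt is a placeholder and error handling is omitted:

# Condensed sketch of the parallel-streaming scenario; the prompt is a
# placeholder, and provider credentials are assumed to be in the environment.
import asyncio
import litellm


async def one_stream(model: str) -> int:
    response = await litellm.acompletion(
        model=model,
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
        max_tokens=10,
    )
    num_finish_reason = 0
    async for chunk in response:
        if chunk.choices[0].finish_reason is not None:
            num_finish_reason += 1
    return num_finish_reason


async def main() -> None:
    # Two concurrent streams against the same provider; each wrapper must keep
    # its own completion_stream or the chunks would interleave and overwrite.
    results = await asyncio.gather(
        *(one_stream("claude-3-opus-20240229") for _ in range(2))
    )
    assert all(count == 1 for count in results)


# asyncio.run(main())  # requires provider credentials in the environment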
@@ -32,7 +32,7 @@ from dataclasses import (
)

import litellm._service_logger  # for storing API inputs, outputs, and metadata
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from litellm.caching import DualCache
from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
@@ -10214,8 +10214,10 @@ class CustomStreamWrapper:
        custom_llm_provider=None,
        logging_obj=None,
        stream_options=None,
        make_call: Optional[Callable] = None,
    ):
        self.model = model
        self.make_call = make_call
        self.custom_llm_provider = custom_llm_provider
        self.logging_obj = logging_obj
        self.completion_stream = completion_stream
@@ -11766,8 +11768,20 @@ class CustomStreamWrapper:
            custom_llm_provider=self.custom_llm_provider,
        )

    async def fetch_stream(self):
        if self.completion_stream is None and self.make_call is not None:
            # Call make_call to get the completion stream
            self.completion_stream = await self.make_call(
                client=litellm.module_level_aclient
            )
            self._stream_iter = self.completion_stream.__aiter__()

        return self.completion_stream

    async def __anext__(self):
        try:
            if self.completion_stream is None:
                await self.fetch_stream()
            if (
                self.custom_llm_provider == "openai"
                or self.custom_llm_provider == "azure"
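
`fetch_stream` is where the deferral pays off: each `CustomStreamWrapper` creates its own `completion_stream` on first `__anext__`, using the shared `litellm.module_level_aclient`, so two concurrent wrappers can no longer clobber one another's stream. A self-contained sketch of that lazy-iterator shape; the class and helper names below are illustrative, only the structure mirrors the hunk above:

# Self-contained sketch of the lazy per-wrapper stream; LazyStreamWrapper and
# fake_make_call are illustrative stand-ins for CustomStreamWrapper/make_call.
import asyncio
from typing import AsyncIterator, Awaitable, Callable, Optional


class LazyStreamWrapper:
    def __init__(self, make_call: Callable[[], Awaitable[AsyncIterator[str]]]):
        self.make_call = make_call
        self.completion_stream: Optional[AsyncIterator[str]] = None

    def __aiter__(self) -> "LazyStreamWrapper":
        return self

    async def fetch_stream(self) -> AsyncIterator[str]:
        if self.completion_stream is None:
            # Created on first use and stored on *this* wrapper,
            # never on a shared handler object.
            self.completion_stream = await self.make_call()
        return self.completion_stream

    async def __anext__(self) -> str:
        if self.completion_stream is None:
            await self.fetch_stream()
        return await self.completion_stream.__anext__()


async def fake_make_call(tag: str) -> AsyncIterator[str]:
    async def _gen():
        for i in range(3):
            await asyncio.sleep(0)
            yield f"{tag}:{i}"

    return _gen()


async def main() -> None:
    a = LazyStreamWrapper(lambda: fake_make_call("a"))
    b = LazyStreamWrapper(lambda: fake_make_call("b"))

    async def drain(w: LazyStreamWrapper) -> list:
        return [chunk async for chunk in w]

    out_a, out_b = await asyncio.gather(drain(a), drain(b))
    # Each wrapper drains its own stream even when consumed concurrently.
    assert out_a == ["a:0", "a:1", "a:2"] and out_b == ["b:0", "b:1", "b:2"]
    print(out_a, out_b)


asyncio.run(main())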