Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Commit 93c3635b64 (parent aada7b4bd3)

fix: fix streaming with httpx client

prevent overwriting streams in parallel streaming calls

9 changed files with 182 additions and 82 deletions
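Only the Databricks handler's hunks are reproduced below (the commit touches 9 files in total). The bug named in the commit body is a shared-state race: the handler opened the httpx stream eagerly and kept it as handler state, so concurrent streaming calls could overwrite each other's stream. The toy sketch below (illustrative only; none of these names are litellm code) reproduces that failure mode:

import asyncio


class SharedHandler:
    def __init__(self) -> None:
        self.stream = None  # per-request state kept on a long-lived, shared object

    async def start(self, name: str):
        # Each call overwrites the shared attribute before the previous caller reads it back.
        self.stream = iter([f"{name}-chunk-{i}" for i in range(3)])
        await asyncio.sleep(0)  # yield control; another call can clobber self.stream here
        return self.stream


async def main() -> None:
    handler = SharedHandler()
    a, b = await asyncio.gather(handler.start("a"), handler.start("b"))
    print(a is b)  # True: both callers end up holding the same stream object


asyncio.run(main())

The diff removes that shared state: the handler no longer posts through a stored `self.async_handler`, and instead hands `CustomStreamWrapper` a `make_call` factory that opens a fresh stream on demand.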
@@ -1,5 +1,6 @@
 # What is this?
 ## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
 import os, types
 import json
 from enum import Enum
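The only change in this first hunk is the new `functools.partial` import; it is what lets the handler bind the request arguments up front and defer the actual HTTP call (see the `make_call` factory further down). A minimal, self-contained illustration of that deferred-call pattern (the function and values here are made up):

import asyncio
from functools import partial


async def open_stream(api_base: str, data: str) -> str:
    # Stand-in for the real request; it only echoes what it was bound with.
    return f"POST {api_base} <- {data}"


# Bind the arguments now; nothing runs until the caller awaits the result.
deferred = partial(open_stream, api_base="https://example.invalid/serving", data='{"stream": true}')

print(asyncio.run(deferred()))  # POST https://example.invalid/serving <- {"stream": true}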
@@ -123,7 +124,7 @@ class DatabricksConfig:
         original_chunk = None  # this is used for function/tool calling
         chunk_data = chunk_data.replace("data:", "")
         chunk_data = chunk_data.strip()
-        if len(chunk_data) == 0:
+        if len(chunk_data) == 0 or chunk_data == "[DONE]":
             return {
                 "text": "",
                 "is_finished": is_finished,
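This hunk hardens the chunk parser: the SSE terminator line `data: [DONE]` is now treated like an empty keep-alive chunk instead of being handed to the JSON decoder. A simplified, standalone version of that parsing step (field names beyond "text"/"is_finished" are assumptions about the OpenAI-style delta payload):

import json


def parse_sse_line(chunk_data: str) -> dict:
    # Mirror of the guarded path in the diff: strip the SSE "data:" prefix,
    # then short-circuit on blank lines and on the "[DONE]" terminator so
    # json.loads() never sees them.
    chunk_data = chunk_data.replace("data:", "").strip()
    if len(chunk_data) == 0 or chunk_data == "[DONE]":
        return {"text": "", "is_finished": False}
    payload = json.loads(chunk_data)
    return {
        "text": payload["choices"][0]["delta"].get("content") or "",
        "is_finished": False,
    }


print(parse_sse_line("data: [DONE]"))                                       # {'text': '', 'is_finished': False}
print(parse_sse_line('data: {"choices": [{"delta": {"content": "Hi"}}]}'))  # {'text': 'Hi', 'is_finished': False}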
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
         return optional_params
 
 
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise DatabricksError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
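The new module-level `make_call` is the per-request stream factory: it is only awaited when a consumer actually wants the stream, and every invocation produces a fresh `aiter_lines()` iterator. The sketch below is not the real `CustomStreamWrapper` (all names are made up); it only shows the shape of a wrapper that consumes such a factory lazily:

import asyncio
from typing import AsyncIterator, Awaitable, Callable, Optional


class LazyStream:
    # Toy stand-in for a stream wrapper: it holds a zero-argument factory and
    # only calls it when iteration starts, so two wrapper instances can never
    # share (or overwrite) a single stream object.
    def __init__(self, factory: Callable[[], Awaitable[AsyncIterator[str]]]) -> None:
        self._factory = factory
        self._stream: Optional[AsyncIterator[str]] = None

    async def __aiter__(self):
        if self._stream is None:
            self._stream = await self._factory()  # opened per wrapper, not per handler
        async for line in self._stream:
            yield line


async def fake_make_call(tag: str) -> AsyncIterator[str]:
    async def lines():
        for i in range(2):
            yield f"{tag}-chunk-{i}"

    return lines()


async def main() -> None:
    a = LazyStream(lambda: fake_make_call("a"))
    b = LazyStream(lambda: fake_make_call("b"))
    print([line async for line in a], [line async for line in b])  # each gets its own stream


asyncio.run(main())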
@@ -354,29 +381,21 @@ class DatabricksChatCompletion(BaseLLM):
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+
         data["stream"] = True
-        try:
-            response = await self.async_handler.post(
-                api_base, headers=headers, data=json.dumps(data), stream=True
-            )
-            response.raise_for_status()
-
-            completion_stream = response.aiter_lines()
-        except httpx.HTTPStatusError as e:
-            raise DatabricksError(
-                status_code=e.response.status_code, message=response.text
-            )
-        except httpx.TimeoutException as e:
-            raise DatabricksError(status_code=408, message="Timeout error occurred.")
-        except Exception as e:
-            raise DatabricksError(status_code=500, message=str(e))
-
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="databricks",
             logging_obj=logging_obj,
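With `completion_stream=None` plus a bound `make_call`, the wrapper decides when to open the stream; note that `client` is left unbound in the `partial`, so whichever code finally invokes `make_call` can still supply the `AsyncHTTPHandler` to use. From the caller's side the visible effect is that parallel streaming requests stop interfering. A hedged usage sketch (the model id and required environment variables are placeholders; check the litellm Databricks docs for exact values):

import asyncio
import litellm


async def stream_one(tag: str) -> str:
    response = await litellm.acompletion(
        model="databricks/databricks-dbrx-instruct",  # placeholder model id
        messages=[{"role": "user", "content": f"say '{tag}'"}],
        stream=True,
    )
    chunks = []
    async for chunk in response:
        chunks.append(chunk.choices[0].delta.content or "")
    return "".join(chunks)


async def main() -> None:
    # Before this fix, parallel calls could end up consuming the same stream.
    print(await asyncio.gather(stream_one("a"), stream_one("b")))


# asyncio.run(main())  # requires Databricks credentials to be configured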
@@ -475,6 +494,8 @@ class DatabricksChatCompletion(BaseLLM):
             },
         )
         if acompletion == True:
+            if client is not None and isinstance(client, HTTPHandler):
+                client = None
             if (
                 stream is not None and stream == True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
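The new guard drops a caller-supplied client when it is the synchronous `HTTPHandler`, since the async completion path needs an `AsyncHTTPHandler` (or none, in which case one is created downstream). In isolation the pattern looks like this (class names below are stand-ins, not the litellm types):

from typing import Optional, Union


class SyncClient: ...
class AsyncClient: ...


def pick_async_client(client: Optional[Union[SyncClient, AsyncClient]]) -> Optional[AsyncClient]:
    # Keep an injected client only if it matches the async execution mode;
    # a sync handler is ignored so a fresh async client can be created instead.
    if client is not None and isinstance(client, SyncClient):
        return None
    return client


print(pick_async_client(SyncClient()))   # None
print(pick_async_client(AsyncClient()))  # <AsyncClient object ...>
print(pick_async_client(None))           # None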
@@ -496,6 +517,7 @@ class DatabricksChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
             )
         else:
             return self.acompletion_function(