fix: streaming with httpx client

Prevent overwriting streams in parallel streaming calls.
Krrish Dholakia 2024-05-31 10:55:18 -07:00
parent aada7b4bd3
commit 93c3635b64
9 changed files with 182 additions and 82 deletions
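Why the change matters: the async streaming path used to create an AsyncHTTPHandler and open the HTTP stream eagerly, stashing both on the handler instance, so two streaming calls running in parallel through the same handler could overwrite each other's stream. The diff below instead hands CustomStreamWrapper a deferred factory built with functools.partial, and each wrapper opens its own stream on first use. A minimal, self-contained sketch of that pattern, with illustrative names (fake_stream, FixedWrapper) rather than litellm's real internals:

import asyncio
from functools import partial
from typing import AsyncIterator, Callable


async def fake_stream(tag: str) -> AsyncIterator[str]:
    # Stand-in for `response.aiter_lines()` on an httpx streaming response.
    for i in range(3):
        await asyncio.sleep(0)
        yield f"{tag}-chunk-{i}"


class FixedWrapper:
    # Post-fix shape: each wrapper holds its own zero-argument factory,
    # so no stream lives on shared handler state.
    def __init__(self, make_call: Callable[[], AsyncIterator[str]]) -> None:
        self._make_call = make_call

    async def collect(self) -> list[str]:
        return [chunk async for chunk in self._make_call()]


async def main() -> None:
    w1 = FixedWrapper(partial(fake_stream, "call1"))
    w2 = FixedWrapper(partial(fake_stream, "call2"))
    # Run both streams in parallel; neither clobbers the other.
    out1, out2 = await asyncio.gather(w1.collect(), w2.collect())
    assert out1 == ["call1-chunk-0", "call1-chunk-1", "call1-chunk-2"]
    assert out2 == ["call2-chunk-0", "call2-chunk-1", "call2-chunk-2"]


asyncio.run(main())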

litellm/llms/databricks.py

@@ -1,5 +1,6 @@
 # What is this?
 ## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -123,7 +124,7 @@ class DatabricksConfig:
         original_chunk = None  # this is used for function/tool calling
         chunk_data = chunk_data.replace("data:", "")
         chunk_data = chunk_data.strip()
-        if len(chunk_data) == 0:
+        if len(chunk_data) == 0 or chunk_data == "[DONE]":
             return {
                 "text": "",
                 "is_finished": is_finished,
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
         return optional_params


+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise DatabricksError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
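The new module-level make_call only issues the POST and opens response.aiter_lines() when it is awaited, so all network work can be delayed until chunks are actually requested. A hedged sketch of the lazy-initialization pattern this enables (LazyStream is illustrative; litellm's real CustomStreamWrapper is more involved):

import asyncio
from typing import AsyncIterator, Awaitable, Callable, Optional


class LazyStream:
    def __init__(
        self,
        completion_stream: Optional[AsyncIterator[str]] = None,
        make_call: Optional[Callable[[], Awaitable[AsyncIterator[str]]]] = None,
    ) -> None:
        self._stream = completion_stream
        self._make_call = make_call

    def __aiter__(self) -> "LazyStream":
        return self

    async def __anext__(self) -> str:
        if self._stream is None:
            if self._make_call is None:
                raise StopAsyncIteration
            # The HTTP request happens here, on first use, not at construction.
            self._stream = await self._make_call()
        return await self._stream.__anext__()


async def demo() -> None:
    async def numbers() -> AsyncIterator[str]:
        for n in ("1", "2", "3"):
            yield n

    async def factory() -> AsyncIterator[str]:
        return numbers()  # stands in for awaiting the real make_call

    assert [c async for c in LazyStream(make_call=factory)] == ["1", "2", "3"]


asyncio.run(demo())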
@@ -354,29 +381,21 @@ class DatabricksChatCompletion(BaseLLM):
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
         data["stream"] = True
-        try:
-            response = await self.async_handler.post(
-                api_base, headers=headers, data=json.dumps(data), stream=True
-            )
-            response.raise_for_status()
-
-            completion_stream = response.aiter_lines()
-        except httpx.HTTPStatusError as e:
-            raise DatabricksError(
-                status_code=e.response.status_code, message=response.text
-            )
-        except httpx.TimeoutException as e:
-            raise DatabricksError(status_code=408, message="Timeout error occurred.")
-        except Exception as e:
-            raise DatabricksError(status_code=500, message=str(e))

         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="databricks",
             logging_obj=logging_obj,
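The make_call=partial(...) above freezes this request's parameters into a zero-argument awaitable; nothing touches the network until the wrapper invokes it, and each request replays only its own frozen arguments. In isolation (make_call_demo is a stand-in for the real helper):

import asyncio
from functools import partial


async def make_call_demo(*, api_base: str, data: str) -> str:
    # Returns a string where the real helper returns a chunk stream.
    return f"POST {api_base} body={data}"


call_a = partial(make_call_demo, api_base="https://host-a/api", data='{"a": 1}')
call_b = partial(make_call_demo, api_base="https://host-b/api", data='{"b": 2}')

# Deferred and isolated: each partial carries its own kwargs when awaited.
print(asyncio.run(call_a()))
print(asyncio.run(call_b()))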
@@ -475,6 +494,8 @@
             },
         )
         if acompletion == True:
+            if client is not None and isinstance(client, HTTPHandler):
+                client = None
             if (
                 stream is not None and stream == True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
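The new guard protects the async path from a caller-supplied synchronous client: litellm's HTTPHandler is the sync client and cannot drive an awaited streaming request, so it is dropped and a fresh AsyncHTTPHandler gets created downstream. The intent, sketched with stub classes standing in for litellm's real ones:

from typing import Optional, Union


class HTTPHandler:  # stub for litellm's sync client
    pass


class AsyncHTTPHandler:  # stub for litellm's async client
    pass


def pick_async_client(
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
) -> Optional[AsyncHTTPHandler]:
    # A sync handler can't serve the async path; discard it rather than misuse it.
    if client is not None and isinstance(client, HTTPHandler):
        return None
    return client


assert pick_async_client(HTTPHandler()) is None
assert isinstance(pick_async_client(AsyncHTTPHandler()), AsyncHTTPHandler)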
@@ -496,6 +517,7 @@
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
             )
         else:
             return self.acompletion_function(
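End to end, the effect is that parallel streaming completions through the async client each consume their own stream. A hedged usage sketch against litellm's public API (the model name and credential setup are illustrative; Databricks calls need the usual API base and key configured):

import asyncio
import litellm


async def consume(tag: str) -> None:
    response = await litellm.acompletion(
        model="databricks/databricks-dbrx-instruct",  # illustrative model
        messages=[{"role": "user", "content": f"hello from {tag}"}],
        stream=True,
    )
    async for chunk in response:
        print(tag, chunk.choices[0].delta.content or "")


async def main() -> None:
    # Before this fix, these two calls could overwrite each other's stream.
    await asyncio.gather(consume("A"), consume("B"))


asyncio.run(main())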