fix: fix streaming with httpx client

prevent overwriting streams in parallel streaming calls
This commit is contained in:
Krrish Dholakia 2024-05-31 10:55:18 -07:00
parent aada7b4bd3
commit 93c3635b64
9 changed files with 182 additions and 82 deletions

View file

@ -1,7 +1,7 @@
# What is this?
## Controller file for Predibase Integration - https://predibase.com/
from functools import partial
import os, types
import json
from enum import Enum
@ -51,6 +51,32 @@ class PredibaseError(Exception):
) # Call the base class constructor with the parameters it needs
async def make_call(
client: AsyncHTTPHandler,
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise PredibaseError(status_code=response.status_code, message=response.text)
completion_stream = response.aiter_lines()
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class PredibaseConfig:
"""
Reference: https://docs.predibase.com/user-guide/inference/rest_api
@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
def _validate_environment(
self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
) -> dict:
if api_key is None:
raise ValueError(
"Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
)
if tenant_id is None:
raise ValueError(
"Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
)
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
@ -304,7 +336,7 @@ class PredibaseChatCompletion(BaseLLM):
logger_fn=None,
headers: dict = {},
) -> Union[ModelResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
completion_url = ""
input_text = ""
base_url = "https://serving.app.predibase.com"
@ -488,26 +520,19 @@ class PredibaseChatCompletion(BaseLLM):
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
self.async_handler = AsyncHTTPHandler(
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
)
data["stream"] = True
response = await self.async_handler.post(
url=api_base,
headers=headers,
data=json.dumps(data),
stream=True,
)
if response.status_code != 200:
raise PredibaseError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
streamwrapper = CustomStreamWrapper(
completion_stream=completion_stream,
completion_stream=None,
make_call=partial(
make_call,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="predibase",
logging_obj=logging_obj,