Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
fix: fix streaming with httpx client
prevent overwriting streams in parallel streaming calls
This commit is contained in:
parent 8bcb53137e
commit 3896e3e88f

9 changed files with 182 additions and 82 deletions
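Only the hunks for the Predibase controller file are shown below. The bug: the async streaming path stored its HTTP handler on `self` (`self.async_handler = AsyncHTTPHandler(...)`) and opened the response stream eagerly, so parallel streaming calls through the same completion-handler instance could overwrite one another's stream. The fix defers stream creation: `CustomStreamWrapper` now receives a `make_call` factory bound with `functools.partial` and opens a fresh stream per call. A minimal, runnable sketch of that pattern (toy names, not litellm's API):

```python
import asyncio
from functools import partial


async def make_stream(prompt: str):
    # Stands in for client.post(..., stream=True) + response.aiter_lines().
    for i in range(3):
        await asyncio.sleep(0)
        yield f"{prompt}-chunk{i}"


class StreamWrapper:
    """Toy analogue of CustomStreamWrapper: the stream is created lazily."""

    def __init__(self, make_call):
        self.make_call = make_call  # no stream is opened yet

    async def consume(self) -> list:
        stream = self.make_call()  # each wrapper opens its own stream
        return [chunk async for chunk in stream]


async def main():
    # Two parallel streaming calls; neither can clobber the other's stream,
    # because nothing stream-related is stored on a shared object.
    a = StreamWrapper(make_call=partial(make_stream, "a"))
    b = StreamWrapper(make_call=partial(make_stream, "b"))
    print(await asyncio.gather(a.consume(), b.consume()))


asyncio.run(main())
```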
@@ -1,7 +1,7 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
 
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -51,6 +51,32 @@ class PredibaseError(Exception):
         ) # Call the base class constructor with the parameters it needs
 
 
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise PredibaseError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class PredibaseConfig:
     """
     Reference: https://docs.predibase.com/user-guide/inference/rest_api
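The new module-level `make_call` takes the HTTP client as its first parameter. At the streaming call site further down, `functools.partial` binds every argument except `client`, so whoever finally starts the stream supplies a fresh `AsyncHTTPHandler`. A hedged sketch of such a consumer (the import path is assumed; the surrounding plumbing is illustrative, not litellm's code):

```python
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler  # assumed import path


async def open_stream(bound_make_call):
    # Fresh client per stream; the timeout values mirror the removed shared
    # handler's. `client` is the only parameter the partial leaves unbound.
    client = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
    return await bound_make_call(client=client)
```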
@@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
+    def _validate_environment(
+        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
+    ) -> dict:
         if api_key is None:
             raise ValueError(
                 "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
             )
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
         headers = {
             "content-type": "application/json",
             "Authorization": "Bearer {}".format(api_key),
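`_validate_environment` now fails fast on a missing tenant ID, alongside the existing API-key check. Per the new error message, the tenant ID can arrive as a call parameter or via the `PREDIBASE_TENANT_ID` environment variable; a minimal sketch of that resolution order (the helper name is hypothetical):

```python
import os
from typing import Optional


def resolve_tenant_id(tenant_id: Optional[str] = None) -> Optional[str]:
    # Explicit parameter wins; otherwise fall back to the environment,
    # matching the two options named in the error message above.
    return tenant_id or os.getenv("PREDIBASE_TENANT_ID")
```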
@@ -304,7 +336,7 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers)
+        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
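The call site threads `tenant_id` into validation because the request cannot be built without it: Predibase serves models under a tenant-scoped path on `base_url`. A sketch of what such a URL builder might look like (the path template is an assumption for illustration; only `base_url` comes from the diff):

```python
def build_completion_url(base_url: str, tenant_id: str, model: str, stream: bool) -> str:
    # Hypothetical path template: tenant-scoped deployment endpoint with a
    # separate streaming route. Only base_url is taken from the diff above.
    suffix = "generate_stream" if stream else "generate"
    return f"{base_url}/{tenant_id}/deployments/v2/llms/{model}/{suffix}"


print(build_completion_url("https://serving.app.predibase.com", "my-tenant", "my-model", stream=True))
```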
@@ -488,26 +520,19 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
     ) -> CustomStreamWrapper:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
         data["stream"] = True
-        response = await self.async_handler.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=True,
-        )
-
-        if response.status_code != 200:
-            raise PredibaseError(
-                status_code=response.status_code, message=response.text
-            )
-
-        completion_stream = response.aiter_lines()
 
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="predibase",
             logging_obj=logging_obj,
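This hunk is the core of the fix: `CustomStreamWrapper` no longer receives a live `completion_stream`; it gets `completion_stream=None` plus a `make_call` partial with everything bound except `client`. A runnable sketch of that binding with placeholder values:

```python
from functools import partial


def make_call(client, api_base, headers, data, model, messages, logging_obj):
    # Toy stand-in; the real make_call posts the request and returns the stream.
    return f"stream opened via {client!r} against {api_base}"


bound = partial(
    make_call,
    api_base="https://serving.app.predibase.com",
    headers={"content-type": "application/json"},
    data="{}",
    model="my-model",
    messages=[],
    logging_obj=None,
)
# `client` is the only argument still required, so the wrapper can defer
# opening the stream until iteration starts - and do it with a fresh client.
print(bound(client="fresh-client"))
```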