Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

fix: fix streaming with httpx client

Prevent overwriting streams in parallel streaming calls.

Parent: aada7b4bd3
Commit: 93c3635b64
9 changed files with 182 additions and 82 deletions
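The crux of the fix: instead of opening the httpx stream eagerly inside each provider's async_streaming method (where the response ended up stored on a shared handler attribute that a concurrent request could overwrite), CustomStreamWrapper now receives a make_call callable built with functools.partial and only opens the stream the first time it is iterated. Below is a minimal, self-contained sketch of that deferred-stream pattern; every name in it (LazyStreamWrapper, fake_response_lines, open_stream, "module-level-client") is an illustrative stand-in, not litellm's actual API.

# Minimal sketch of the deferred-stream pattern introduced in this commit.
# All names below are illustrative stand-ins, not litellm's actual API.
import asyncio
from functools import partial
from typing import AsyncIterator, Callable, Optional


async def fake_response_lines(prompt: str) -> AsyncIterator[str]:
    # Stand-in for iterating an httpx streaming response (response.aiter_lines()).
    for i in range(3):
        await asyncio.sleep(0)
        yield f"chunk {i} for {prompt!r}"


async def open_stream(client: str, prompt: str) -> AsyncIterator[str]:
    # Stand-in for a provider-level make_call(): performs the POST with
    # stream=True and returns the line iterator for *this* request only.
    return fake_response_lines(prompt)


class LazyStreamWrapper:
    # Holds a make_call callable; the stream is opened on first iteration,
    # so each wrapper owns its own stream and parallel calls cannot clobber
    # one another through a shared handler attribute.
    def __init__(self, make_call: Optional[Callable] = None):
        self.make_call = make_call
        self.completion_stream: Optional[AsyncIterator[str]] = None

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.completion_stream is None and self.make_call is not None:
            self.completion_stream = await self.make_call(client="module-level-client")
        return await self.completion_stream.__anext__()


async def main():
    wrappers = [
        LazyStreamWrapper(make_call=partial(open_stream, prompt=p))
        for p in ("hello", "world")
    ]

    async def consume(wrapper):
        return [chunk async for chunk in wrapper]

    # Both streams are consumed concurrently without overwriting each other.
    print(await asyncio.gather(*(consume(w) for w in wrappers)))


asyncio.run(main())

The real diff follows, reconstructed as unified hunks per file.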
@@ -5,6 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from litellm.caching import Cache
 from litellm._logging import (
     set_verbose,
@@ -212,6 +213,7 @@ add_function_to_prompt: bool = (
 )
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
+module_level_aclient = AsyncHTTPHandler()
 model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
 model_cost_map_url: str = (
     "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"

@@ -1,7 +1,7 @@
 # What is this?
 ## Initial implementation of calling bedrock via httpx client (allows for async calls).
 ## V1 - covers cohere + anthropic claude-3 support
-
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -145,6 +145,37 @@ class AmazonCohereChatConfig:
         return optional_params
 
 
+async def make_call(
+    client: Optional[AsyncHTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = AsyncHTTPHandler()  # Create a new client if none provided
+
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise BedrockError(status_code=response.status_code, message=response.text)
+
+    decoder = AWSEventStreamDecoder(model=model)
+    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class BedrockLLM(BaseLLM):
     """
     Example call
@@ -968,39 +999,24 @@ class BedrockLLM(BaseLLM):
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
     ) -> CustomStreamWrapper:
-        if client is None:
-            _params = {}
-            if timeout is not None:
-                if isinstance(timeout, float) or isinstance(timeout, int):
-                    timeout = httpx.Timeout(timeout)
-                _params["timeout"] = timeout
-            self.client = AsyncHTTPHandler(**_params)  # type: ignore
-        else:
-            self.client = client  # type: ignore
-
-        response = await self.client.post(api_base, headers=headers, data=data, stream=True)  # type: ignore
-
-        if response.status_code != 200:
-            raise BedrockError(status_code=response.status_code, message=response.text)
-
-        decoder = AWSEventStreamDecoder(model=model)
-
-        completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+        # The call is not made here; instead, we prepare the necessary objects for the stream.
         streaming_response = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="bedrock",
             logging_obj=logging_obj,
         )
 
-        ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key="",
-            original_response=streaming_response,
-            additional_args={"complete_input_dict": data},
-        )
-
         return streaming_response
 
     def embedding(self, *args, **kwargs):

@@ -1,4 +1,4 @@
-import httpx, asyncio
+import httpx, asyncio, traceback
 from typing import Optional, Union, Mapping, Any
 
 # https://www.python-httpx.org/advanced/timeouts
@@ -48,11 +48,17 @@ class AsyncHTTPHandler:
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
-        req = self.client.build_request(
-            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
-        )
-        response = await self.client.send(req, stream=stream)
-        return response
+        try:
+            req = self.client.build_request(
+                "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
+            )
+            response = await self.client.send(req, stream=stream)
+            response.raise_for_status()
+            return response
+        except httpx.HTTPStatusError as e:
+            raise
+        except Exception as e:
+            raise
 
     def __del__(self) -> None:
         try:

@@ -1,5 +1,6 @@
 # What is this?
 ## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -123,7 +124,7 @@ class DatabricksConfig:
         original_chunk = None  # this is used for function/tool calling
         chunk_data = chunk_data.replace("data:", "")
         chunk_data = chunk_data.strip()
-        if len(chunk_data) == 0:
+        if len(chunk_data) == 0 or chunk_data == "[DONE]":
             return {
                 "text": "",
                 "is_finished": is_finished,
@@ -221,6 +222,32 @@ class DatabricksEmbeddingConfig:
         return optional_params
 
 
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise DatabricksError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class DatabricksChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -354,29 +381,21 @@ class DatabricksChatCompletion(BaseLLM):
         litellm_params=None,
         logger_fn=None,
         headers={},
-    ):
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
         data["stream"] = True
-        try:
-            response = await self.async_handler.post(
-                api_base, headers=headers, data=json.dumps(data), stream=True
-            )
-            response.raise_for_status()
-
-            completion_stream = response.aiter_lines()
-        except httpx.HTTPStatusError as e:
-            raise DatabricksError(
-                status_code=e.response.status_code, message=response.text
-            )
-        except httpx.TimeoutException as e:
-            raise DatabricksError(status_code=408, message="Timeout error occurred.")
-        except Exception as e:
-            raise DatabricksError(status_code=500, message=str(e))
 
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="databricks",
             logging_obj=logging_obj,
@@ -475,6 +494,8 @@ class DatabricksChatCompletion(BaseLLM):
             },
         )
         if acompletion == True:
+            if client is not None and isinstance(client, HTTPHandler):
+                client = None
             if (
                 stream is not None and stream == True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
@@ -496,6 +517,7 @@ class DatabricksChatCompletion(BaseLLM):
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=headers,
+                client=client,
             )
         else:
             return self.acompletion_function(

@@ -1,7 +1,7 @@
 # What is this?
 ## Controller file for Predibase Integration - https://predibase.com/
 
-
+from functools import partial
 import os, types
 import json
 from enum import Enum
@@ -51,6 +51,32 @@ class PredibaseError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
+async def make_call(
+    client: AsyncHTTPHandler,
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    response = await client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise PredibaseError(status_code=response.status_code, message=response.text)
+
+    completion_stream = response.aiter_lines()
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class PredibaseConfig:
     """
     Reference: https://docs.predibase.com/user-guide/inference/rest_api
@@ -126,11 +152,17 @@ class PredibaseChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def _validate_environment(self, api_key: Optional[str], user_headers: dict) -> dict:
+    def _validate_environment(
+        self, api_key: Optional[str], user_headers: dict, tenant_id: Optional[str]
+    ) -> dict:
         if api_key is None:
             raise ValueError(
                 "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
             )
+        if tenant_id is None:
+            raise ValueError(
+                "Missing Predibase Tenant ID - Required for making the request. Set dynamically (e.g. `completion(..tenant_id=<MY-ID>)`) or in env - `PREDIBASE_TENANT_ID`."
+            )
         headers = {
             "content-type": "application/json",
             "Authorization": "Bearer {}".format(api_key),
@@ -304,7 +336,7 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers: dict = {},
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        headers = self._validate_environment(api_key, headers)
+        headers = self._validate_environment(api_key, headers, tenant_id=tenant_id)
         completion_url = ""
         input_text = ""
         base_url = "https://serving.app.predibase.com"
@@ -488,26 +520,19 @@ class PredibaseChatCompletion(BaseLLM):
         logger_fn=None,
         headers={},
     ) -> CustomStreamWrapper:
-        self.async_handler = AsyncHTTPHandler(
-            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-        )
         data["stream"] = True
-        response = await self.async_handler.post(
-            url=api_base,
-            headers=headers,
-            data=json.dumps(data),
-            stream=True,
-        )
-
-        if response.status_code != 200:
-            raise PredibaseError(
-                status_code=response.status_code, message=response.text
-            )
-
-        completion_stream = response.aiter_lines()
 
         streamwrapper = CustomStreamWrapper(
-            completion_stream=completion_stream,
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                api_base=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
             model=model,
             custom_llm_provider="predibase",
             logging_obj=logging_obj,

@@ -251,7 +251,7 @@ async def async_handle_prediction_response(
     logs = ""
     while True and (status not in ["succeeded", "failed", "canceled"]):
         print_verbose(f"replicate: polling endpoint: {prediction_url}")
-        await asyncio.sleep(0.5)
+        await asyncio.sleep(0.5)  # prevent replicate rate limit errors
         response = await http_handler.get(prediction_url, headers=headers)
         if response.status_code == 200:
             response_data = response.json()

@@ -361,6 +361,7 @@ async def acompletion(
         )  # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
         return response
     except Exception as e:
+        traceback.print_exc()
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
             model=model,

@@ -1385,8 +1385,21 @@ def test_bedrock_claude_3_streaming():
 
 
 @pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "claude-3-opus-20240229",
+        "cohere.command-r-plus-v1:0",  # bedrock
+        "databricks/databricks-dbrx-instruct",  # databricks
+        "predibase/llama-3-8b-instruct",  # predibase
+        "replicate/meta/meta-llama-3-8b-instruct",  # replicate
+    ],
+)
 @pytest.mark.asyncio
-async def test_claude_3_streaming_finish_reason(sync_mode):
+async def test_parallel_streaming_requests(sync_mode, model):
+    """
+    Important prod test.
+    """
     try:
         import threading
 
@@ -1398,7 +1411,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
 
     def sync_test_streaming():
         response: litellm.CustomStreamWrapper = litellm.acompletion(  # type: ignore
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             stream=True,
             max_tokens=10,
@@ -1415,7 +1428,7 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
 
     async def test_streaming():
         response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
-            model="claude-3-opus-20240229",
+            model=model,
            messages=messages,
             stream=True,
             max_tokens=10,
@@ -1424,8 +1437,9 @@ async def test_claude_3_streaming_finish_reason(sync_mode):
         # Add any assertions here to-check the response
         num_finish_reason = 0
         async for chunk in response:
-            print(f"chunk: {chunk}")
+            print(f"type of chunk: {type(chunk)}")
             if isinstance(chunk, ModelResponse):
+                print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
                 if chunk.choices[0].finish_reason is not None:
                     num_finish_reason += 1
         assert num_finish_reason == 1

@@ -32,7 +32,7 @@ from dataclasses import (
 )
 
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
-from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 from litellm.caching import DualCache
 from litellm.types.utils import CostPerToken, ProviderField, ModelInfo
 
@@ -10214,8 +10214,10 @@ class CustomStreamWrapper:
         custom_llm_provider=None,
         logging_obj=None,
         stream_options=None,
+        make_call: Optional[Callable] = None,
     ):
         self.model = model
+        self.make_call = make_call
         self.custom_llm_provider = custom_llm_provider
         self.logging_obj = logging_obj
         self.completion_stream = completion_stream
@@ -11766,8 +11768,20 @@ class CustomStreamWrapper:
             custom_llm_provider=self.custom_llm_provider,
         )
 
+    async def fetch_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = await self.make_call(
+                client=litellm.module_level_aclient
+            )
+            self._stream_iter = self.completion_stream.__aiter__()
+
+        return self.completion_stream
+
     async def __anext__(self):
         try:
+            if self.completion_stream is None:
+                await self.fetch_stream()
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"

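For context, the calling pattern this commit protects (and which the new test_parallel_streaming_requests exercises) is several concurrent litellm.acompletion(..., stream=True) calls, each iterated independently. A rough sketch of such a caller follows; it assumes the relevant provider credentials are set in the environment, and the model name is just one value from the test's parametrize list.

# Rough sketch of the concurrent streaming callers this fix protects.
# Assumes provider credentials (e.g. ANTHROPIC_API_KEY) are set in the env.
import asyncio
import litellm


async def stream_once(idx: int):
    response = await litellm.acompletion(
        model="claude-3-opus-20240229",
        messages=[{"role": "user", "content": f"request {idx}"}],
        stream=True,
        max_tokens=10,
    )
    # Each response is its own CustomStreamWrapper; before this fix, parallel
    # iterations could end up reading a stream opened by a different request
    # on a shared handler.
    async for chunk in response:
        print(idx, chunk.choices[0])


async def main():
    await asyncio.gather(*(stream_once(i) for i in range(3)))


asyncio.run(main())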