mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)

refactor(openai.py): support aiohttp streaming

This commit is contained in:
parent bba62b56d3
commit c053782d96

5 changed files with 108 additions and 42 deletions
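In short, acompletion for OpenAI and Azure now returns a true async generator when stream=True, driven by aiohttp instead of a thread-pool wrapper around the synchronous client. A minimal usage sketch mirroring the updated tests in this commit (assumes a configured OPENAI_API_KEY; the model name is only an example):

import asyncio
from litellm import acompletion

async def main():
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    # stream=True: the awaited call hands back an async generator of delta chunks
    response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
    output = ""
    async for chunk in response:
        output += chunk["choices"][0]["delta"].get("content", "")
    print(output)

asyncio.run(main())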
@@ -1,11 +1,10 @@
 from typing import Optional, Union
 import types, requests
 from .base import BaseLLM
-from litellm.utils import ModelResponse, Choices, Message
+from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper
 from typing import Callable, Optional
 import aiohttp
-

 class OpenAIError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -219,7 +218,12 @@ class OpenAIChatCompletion(BaseLLM):
             )

             try:
-                if "stream" in optional_params and optional_params["stream"] == True:
+                if acompletion is True:
+                    if optional_params.get("stream", False):
+                        return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                    else:
+                        return self.acompletion(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                elif "stream" in optional_params and optional_params["stream"] == True:
                     response = self._client_session.post(
                         url=api_base,
                         json=data,
@@ -231,8 +235,7 @@ class OpenAIChatCompletion(BaseLLM):

                     ## RESPONSE OBJECT
                     return response.iter_lines()
-                elif acompletion is True:
-                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
                 else:
                     response = self._client_session.post(
                         url=api_base,
@@ -273,7 +276,12 @@ class OpenAIChatCompletion(BaseLLM):
             import traceback
             raise OpenAIError(status_code=500, message=traceback.format_exc())

-    async def acompletion(self, api_base: str, data: dict, headers: dict, model_response: ModelResponse):
+    async def acompletion(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict, headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
         async with aiohttp.ClientSession() as session:
             async with session.post(api_base, json=data, headers=headers) as response:
                 response_json = await response.json()
@@ -284,6 +292,25 @@ class OpenAIChatCompletion(BaseLLM):
                 ## RESPONSE OBJECT
                 return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)

+    async def async_streaming(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict, headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
+        async with aiohttp.ClientSession() as session:
+            async with session.post(api_base, json=data, headers=headers) as response:
+                # Check if the request was successful (status code 200)
+                if response.status != 200:
+                    raise OpenAIError(status_code=response.status, message=await response.text())
+
+                # Handle the streamed response
+                # async for line in response.content:
+                #     print(line)
+                streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
+                async for transformed_chunk in streamwrapper:
+                    yield transformed_chunk
+
     def embedding(self,
                 model: str,
                 input: list,
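For context, the new async_streaming handler follows the standard aiohttp pattern: open the request inside a ClientSession and yield chunks as they arrive, so the whole call stays an async generator. A simplified, self-contained sketch of that pattern (the URL and payload are placeholders; the real handler delegates chunk parsing to CustomStreamWrapper):

import aiohttp

async def stream_post(url: str, payload: dict, headers: dict):
    # Post the request and yield raw SSE lines as the server sends them
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=payload, headers=headers) as response:
            if response.status != 200:
                raise RuntimeError(await response.text())
            async for line in response.content:  # aiohttp exposes the body as an async byte-line iterator
                yield line

async def demo():
    # Hypothetical endpoint, for illustration only
    async for raw_line in stream_post("https://api.openai.com/v1/chat/completions", {}, {}):
        print(raw_line)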
@@ -137,9 +137,13 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):

     _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))

-    if custom_llm_provider == "openai" or custom_llm_provider == "azure": # currently implemented aiohttp calls for just azure and openai, soon all.
-        # Await normally
-        response = await completion(*args, **kwargs)
+    if (custom_llm_provider == "openai" or custom_llm_provider == "azure"): # currently implemented aiohttp calls for just azure and openai, soon all.
+        if kwargs.get("stream", False):
+            response = completion(*args, **kwargs)
+        else:
+            # Await normally
+            response = await completion(*args, **kwargs)
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
@@ -147,6 +151,7 @@ async def acompletion(model: str, messages: List = [], *args, **kwargs):
     # do not change this
     # for stream = True, always return an async generator
     # See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
+    # return response
     return(
         line
         async for line in response
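The `return (line async for line in response)` expression is an async generator expression: it re-wraps whatever iterable the provider path produced, so stream=True always hands the caller a bare async generator, matching the OpenAI acreate behavior linked above. A small self-contained illustration of that re-wrapping (toy names, not litellm code):

import asyncio, inspect

async def fake_provider_stream():
    for token in ("Hel", "lo", "!"):
        yield token

async def get_stream():
    response = fake_provider_stream()
    # re-wrap so the caller always sees a plain async generator
    return (line async for line in response)

async def main():
    stream = await get_stream()
    print(inspect.isasyncgen(stream))          # True
    print([tok async for tok in stream])       # ['Hel', 'lo', '!']

asyncio.run(main())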
@@ -515,7 +520,7 @@ def completion(
             )
             raise e

-        if "stream" in optional_params and optional_params["stream"] == True:
+        if optional_params.get("stream", False) and acompletion is False:
             response = CustomStreamWrapper(response, model, custom_llm_provider=custom_llm_provider, logging_obj=logging)
             return response
         ## LOGGING
@@ -37,7 +37,7 @@ def test_async_response():

    response = asyncio.run(test_get_response())
    # print(response)
-test_async_response()
+# test_async_response()

def test_get_response_streaming():
    import asyncio
@@ -45,8 +45,6 @@ def test_get_response_streaming():
        user_message = "Hello, how are you?"
        messages = [{"content": user_message, "role": "user"}]
        try:
-           import litellm
-           litellm.set_verbose = True
            response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
            print(type(response))

@@ -56,11 +54,11 @@ def test_get_response_streaming():
            print(is_async_generator)

            output = ""
+           i = 0
            async for chunk in response:
                token = chunk["choices"][0]["delta"].get("content", "")
                output += token
-           print(output)
+           print(f"output: {output}")

            assert output is not None, "output cannot be None."
            assert isinstance(output, str), "output needs to be of type str"
            assert len(output) > 0, "Length of output needs to be greater than 0."
@@ -71,3 +69,35 @@ def test_get_response_streaming():
    asyncio.run(test_async_call())

+# test_get_response_streaming()
+
+def test_get_response_non_openai_streaming():
+    import asyncio
+    async def test_async_call():
+        user_message = "Hello, how are you?"
+        messages = [{"content": user_message, "role": "user"}]
+        try:
+            response = await acompletion(model="command-nightly", messages=messages, stream=True)
+            print(type(response))
+
+            import inspect
+
+            is_async_generator = inspect.isasyncgen(response)
+            print(is_async_generator)
+
+            output = ""
+            i = 0
+            async for chunk in response:
+                token = chunk["choices"][0]["delta"].get("content", "")
+                output += token
+            print(f"output: {output}")
+            assert output is not None, "output cannot be None."
+            assert isinstance(output, str), "output needs to be of type str"
+            assert len(output) > 0, "Length of output needs to be greater than 0."
+
+        except Exception as e:
+            pytest.fail(f"An exception occurred: {e}")
+        return response
+    asyncio.run(test_async_call())
+
+test_get_response_non_openai_streaming()
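Like the existing streaming test, the new non-OpenAI test keeps a synchronous outer test function and drives the coroutine with asyncio.run, so it works under a plain pytest run without pytest-asyncio. The shape of that pattern, reduced to a self-contained sketch (the fake stream stands in for the real acompletion call):

import asyncio

def test_streaming_pattern():
    async def run():
        async def fake_stream():
            # stand-in for: response = await acompletion(..., stream=True)
            yield {"choices": [{"delta": {"content": "hi"}}]}
        output = ""
        async for chunk in fake_stream():
            output += chunk["choices"][0]["delta"].get("content", "")
        assert len(output) > 0, "Length of output needs to be greater than 0."
    asyncio.run(run())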
@@ -60,5 +60,4 @@ def test_stream_chunk_builder():
        print(role, content, finish_reason)
    except Exception as e:
        raise Exception("stream_chunk_builder failed to rebuild response", e)
-test_stream_chunk_builder()

@@ -3998,14 +3998,15 @@ class CustomStreamWrapper:
        text = ""
        is_finished = False
        finish_reason = None
-       if str_line == "data: [DONE]":
+       if "data: [DONE]" in str_line:
            # anyscale returns a [DONE] special char for streaming, this cannot be json loaded. This is the end of stream
            text = ""
            is_finished = True
            finish_reason = "stop"
            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-       elif str_line.startswith("data:"):
-           data_json = json.loads(str_line[5:])
+       elif str_line.startswith("data:") and len(str_line[5:]) > 0:
+           str_line = str_line[5:]
+           data_json = json.loads(str_line)
            print_verbose(f"delta content: {data_json['choices'][0]['delta']}")
            text = data_json["choices"][0]["delta"].get("content", "")
            if data_json["choices"][0].get("finish_reason", None):
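The looser checks matter once chunks come straight off an aiohttp byte stream: a line may carry the [DONE] sentinel with extra whitespace, or be a bare "data:" keep-alive with no JSON payload, so an exact-match comparison and an unconditional json.loads both break. A standalone sketch of the same guard logic (hypothetical helper, assumes raw bytes input):

import json

def parse_sse_line(raw: bytes):
    str_line = raw.decode("utf-8").strip()
    if "data: [DONE]" in str_line:
        # end-of-stream sentinel; not valid JSON, so short-circuit
        return {"text": "", "is_finished": True, "finish_reason": "stop"}
    if str_line.startswith("data:") and len(str_line[5:]) > 0:
        data_json = json.loads(str_line[5:])
        choice = data_json["choices"][0]
        return {"text": choice["delta"].get("content", ""),
                "is_finished": bool(choice.get("finish_reason")),
                "finish_reason": choice.get("finish_reason")}
    return None  # comments, keep-alives, empty lines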
@@ -4104,72 +4105,61 @@ class CustomStreamWrapper:
                raise Exception(chunk["error"])
            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
        return ""

-   ## needs to handle the empty string case (even starting chunk can be an empty string)
-   def __next__(self):
+   def chunk_creator(self, chunk):
        model_response = ModelResponse(stream=True, model=self.model)
        try:
            while True: # loop until a non-empty string is found
                # return this for all models
                completion_obj = {"content": ""}
                if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_anthropic_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.model == "replicate" or self.custom_llm_provider == "replicate":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_replicate_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif (
                    self.custom_llm_provider and self.custom_llm_provider == "together_ai"):
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_together_ai_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_huggingface_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-                   chunk = next(self.completion_stream)
                    completion_obj["content"] = self.handle_baseten_chunk(chunk)
                elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_ai21_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_azure_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_maritalk_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
-                   chunk = next(self.completion_stream)
                    completion_obj["content"] = chunk[0].outputs[0].text
                elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": #aleph alpha doesn't provide streaming
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_aleph_alpha_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                    try:
-                       chunk = next(self.completion_stream)
                        response_obj = self.handle_nlp_cloud_chunk(chunk)
                        completion_obj["content"] = response_obj["text"]
                        if response_obj["is_finished"]:
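The effect of this hunk is to turn the body of __next__ into chunk_creator(chunk): chunk retrieval (next(self.completion_stream)) moves out of the per-provider branches, so the same parsing code can be fed from either a synchronous iterator or an aiohttp async iterator. A toy version of that split, under those assumptions (illustrative names only):

def parse_chunk(chunk: str):
    # shared parsing logic, independent of how the chunk was fetched
    return chunk.strip() or None

def sync_consume(stream):
    for chunk in stream:                 # sync retrieval, as in __next__
        parsed = parse_chunk(chunk)
        if parsed is not None:
            yield parsed

async def async_consume(stream):
    async for chunk in stream:           # async retrieval, e.g. aiohttp response.content
        parsed = parse_chunk(chunk)
        if parsed is not None:
            yield parsed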
@@ -4184,7 +4174,7 @@ class CustomStreamWrapper:
                        self.sent_last_chunk = True
                elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
                    try:
-                       chunk = next(self.completion_stream)
                        completion_obj["content"] = str(chunk)
                    except StopIteration as e:
                        if self.sent_last_chunk:
@@ -4193,13 +4183,11 @@ class CustomStreamWrapper:
                            model_response.choices[0].finish_reason = "stop"
                            self.sent_last_chunk = True
                elif self.custom_llm_provider == "cohere":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_cohere_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider == "bedrock":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_bedrock_stream(chunk)
                    completion_obj["content"] = response_obj["text"]
                    if response_obj["is_finished"]:
@@ -4242,19 +4230,16 @@ class CustomStreamWrapper:
                        self.completion_stream = self.completion_stream[chunk_size:]
                        time.sleep(0.05)
                elif self.custom_llm_provider == "ollama":
-                   chunk = next(self.completion_stream)
                    if "error" in chunk:
                        exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=chunk["error"])
                    completion_obj = chunk
                elif self.custom_llm_provider == "openai":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_openai_chat_completion_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    print_verbose(f"completion obj content: {completion_obj['content']}")
                    if response_obj["is_finished"]:
                        model_response.choices[0].finish_reason = response_obj["finish_reason"]
                elif self.custom_llm_provider == "text-completion-openai":
-                   chunk = next(self.completion_stream)
                    response_obj = self.handle_openai_text_completion_chunk(chunk)
                    completion_obj["content"] = response_obj["text"]
                    print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -4267,7 +4252,7 @@ class CustomStreamWrapper:
                if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
                    hold, model_response_str = self.check_special_tokens(completion_obj["content"])
                    if hold is False:
                        completion_obj["content"] = model_response_str
                        if self.sent_first_chunk == False:
                            completion_obj["role"] = "assistant"
                            self.sent_first_chunk = True
@@ -4275,11 +4260,15 @@ class CustomStreamWrapper:
                        # LOGGING
                        threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                        return model_response
+                   else:
+                       return
                elif model_response.choices[0].finish_reason:
                    model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
                    # LOGGING
                    threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                    return model_response
+               else:
+                   return
        except StopIteration:
            raise StopIteration
        except Exception as e:
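With retrieval moved out of the loop, chunk_creator can no longer pull another chunk when the current one has nothing to emit, so the new else: return branches make it return None instead; the caller is then responsible for skipping those Nones (which the async consumer added in the next hunk does with `if processed_chunk is None: continue`). The contract in miniature, with illustrative names:

def emit_or_none(parsed_content: str):
    # mirrors the new contract: build a chunk object, or return None if there is nothing to emit
    if parsed_content:
        return {"choices": [{"delta": {"content": parsed_content}}]}
    return None

async def drain(raw_chunks):
    async for raw in raw_chunks:
        processed = emit_or_none(raw)
        if processed is None:
            continue            # the caller skips empty results, as __anext__ does
        yield processed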
@@ -4288,11 +4277,27 @@ class CustomStreamWrapper:
            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
            return exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)

+   ## needs to handle the empty string case (even starting chunk can be an empty string)
+   def __next__(self):
+       chunk = next(self.completion_stream)
+       return self.chunk_creator(chunk=chunk)
+
    async def __anext__(self):
        try:
-           return next(self)
-       except StopIteration:
+           if self.custom_llm_provider == "openai":
+               async for chunk in self.completion_stream.content:
+                   if chunk == "None" or chunk is None:
+                       raise Exception
+                   processed_chunk = self.chunk_creator(chunk=chunk)
+                   if processed_chunk is None:
+                       continue
+                   return processed_chunk
+               raise StopAsyncIteration
+           else: # temporary patch for non-aiohttp async calls
+               return next(self)
+       except Exception as e:
+           # Handle any exceptions that might occur during streaming
            raise StopAsyncIteration

class TextCompletionStreamWrapper: