refactor(openai.py): support aiohttp streaming

Krrish Dholakia 2023-11-09 16:15:21 -08:00
parent bba62b56d3
commit c053782d96
5 changed files with 108 additions and 42 deletions

@@ -3998,14 +3998,15 @@ class CustomStreamWrapper:
             text = ""
             is_finished = False
             finish_reason = None
-            if str_line == "data: [DONE]":
+            if "data: [DONE]" in str_line:
                 # anyscale returns a [DONE] special char for streaming, this cannot be json loaded. This is the end of stream
                 text = ""
                 is_finished = True
                 finish_reason = "stop"
                 return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
-            elif str_line.startswith("data:"):
-                data_json = json.loads(str_line[5:])
+            elif str_line.startswith("data:") and len(str_line[5:]) > 0:
+                str_line = str_line[5:]
+                data_json = json.loads(str_line)
                 print_verbose(f"delta content: {data_json['choices'][0]['delta']}")
                 text = data_json["choices"][0]["delta"].get("content", "")
                 if data_json["choices"][0].get("finish_reason", None):
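The two changes in this hunk harden the SSE parsing: matching "data: [DONE]" as a substring tolerates providers that wrap the sentinel in extra bytes or whitespace, and the len(str_line[5:]) > 0 guard skips bare "data:" keep-alive lines instead of feeding an empty payload to json.loads. A minimal standalone sketch of the revised rule (the helper name and sample lines are illustrative, not from the commit):

import json

def parse_sse_line(str_line):
    # illustrative re-implementation of the branch logic above
    if "data: [DONE]" in str_line:
        # substring match tolerates stray bytes/whitespace around the sentinel
        return {"text": "", "is_finished": True, "finish_reason": "stop"}
    elif str_line.startswith("data:") and len(str_line[5:]) > 0:
        # the length guard skips bare "data:" keep-alive lines that would
        # otherwise crash json.loads on an empty payload
        data_json = json.loads(str_line[5:])
        choice = data_json["choices"][0]
        return {"text": choice["delta"].get("content", ""),
                "is_finished": bool(choice.get("finish_reason")),
                "finish_reason": choice.get("finish_reason")}
    return None

print(parse_sse_line("data: [DONE]"))  # -> end-of-stream sentinel
print(parse_sse_line("data:"))         # -> None, instead of a JSONDecodeError
print(parse_sse_line('data: {"choices": [{"delta": {"content": "hi"}}]}'))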
@@ -4104,72 +4105,61 @@ class CustomStreamWrapper:
                 raise Exception(chunk["error"])
             return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
         return ""

-    ## needs to handle the empty string case (even starting chunk can be an empty string)
-    def __next__(self):
+    def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
         try:
             while True: # loop until a non-empty string is found
                 # return this for all models
                 completion_obj = {"content": ""}
                 if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_anthropic_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.model == "replicate" or self.custom_llm_provider == "replicate":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_replicate_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif (
                     self.custom_llm_provider and self.custom_llm_provider == "together_ai"):
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_together_ai_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_huggingface_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
-                    chunk = next(self.completion_stream)
                     completion_obj["content"] = self.handle_baseten_chunk(chunk)
                 elif self.custom_llm_provider and self.custom_llm_provider == "ai21": # ai21 doesn't provide streaming
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_ai21_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "azure":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_azure_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_maritalk_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
-                    chunk = next(self.completion_stream)
                     completion_obj["content"] = chunk[0].outputs[0].text
                 elif self.custom_llm_provider and self.custom_llm_provider == "aleph_alpha": # aleph alpha doesn't provide streaming
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_aleph_alpha_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
                     try:
-                        chunk = next(self.completion_stream)
                         response_obj = self.handle_nlp_cloud_chunk(chunk)
                         completion_obj["content"] = response_obj["text"]
                         if response_obj["is_finished"]:
@@ -4184,7 +4174,7 @@ class CustomStreamWrapper:
                             self.sent_last_chunk = True
                 elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
                     try:
-                        chunk = next(self.completion_stream)
                         completion_obj["content"] = str(chunk)
                     except StopIteration as e:
                         if self.sent_last_chunk:
@@ -4193,13 +4183,11 @@ class CustomStreamWrapper:
                             model_response.choices[0].finish_reason = "stop"
                             self.sent_last_chunk = True
                 elif self.custom_llm_provider == "cohere":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_cohere_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "bedrock":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_bedrock_stream(chunk)
                     completion_obj["content"] = response_obj["text"]
                     if response_obj["is_finished"]:
@@ -4242,19 +4230,16 @@ class CustomStreamWrapper:
                         self.completion_stream = self.completion_stream[chunk_size:]
                         time.sleep(0.05)
                 elif self.custom_llm_provider == "ollama":
-                    chunk = next(self.completion_stream)
                     if "error" in chunk:
                         exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=chunk["error"])
                     completion_obj = chunk
                 elif self.custom_llm_provider == "openai":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_openai_chat_completion_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     print_verbose(f"completion obj content: {completion_obj['content']}")
                     if response_obj["is_finished"]:
                         model_response.choices[0].finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "text-completion-openai":
-                    chunk = next(self.completion_stream)
                     response_obj = self.handle_openai_text_completion_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
                     print_verbose(f"completion obj content: {completion_obj['content']}")
@@ -4267,7 +4252,7 @@ class CustomStreamWrapper:
                 if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
                     hold, model_response_str = self.check_special_tokens(completion_obj["content"])
                     if hold is False:
-                        completion_obj["content"] = model_response_str
+                        completion_obj["content"] = model_response_str
                         if self.sent_first_chunk == False:
                             completion_obj["role"] = "assistant"
                             self.sent_first_chunk = True
@@ -4275,11 +4260,15 @@ class CustomStreamWrapper:
                         # LOGGING
                         threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                         return model_response
+                    else:
+                        return
                 elif model_response.choices[0].finish_reason:
                     model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
                     # LOGGING
                     threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                     return model_response
+                else:
+                    return
         except StopIteration:
             raise StopIteration
         except Exception as e:
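The else: return branches added above change chunk_creator's contract: instead of looping until the next non-empty chunk arrives (which a push-based async reader cannot do), it now returns None for empty or held chunks and leaves the skipping to the caller. A toy sketch of that producer/consumer contract, with purely illustrative names:

def parse(line):
    # stand-in for chunk_creator: None means "nothing to emit yet"
    return line.strip() or None

def drive(lines):
    # stand-in for the __next__/__anext__ drivers added below
    for line in lines:
        processed = parse(line)
        if processed is None:  # skip keep-alives / held tokens
            continue
        yield processed

print(list(drive(["a", "  ", "b"])))  # -> ['a', 'b']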
@@ -4288,11 +4277,27 @@ class CustomStreamWrapper:
             # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
             threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
             return exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
+
+    ## needs to handle the empty string case (even starting chunk can be an empty string)
+    def __next__(self):
+        chunk = next(self.completion_stream)
+        return self.chunk_creator(chunk=chunk)

     async def __anext__(self):
         try:
             return next(self)
         except StopIteration:
-            raise StopAsyncIteration
+            if self.custom_llm_provider == "openai":
+                async for chunk in self.completion_stream.content:
+                    if chunk == "None" or chunk is None:
+                        raise Exception
+                    processed_chunk = self.chunk_creator(chunk=chunk)
+                    if processed_chunk is None:
+                        continue
+                    return processed_chunk
+                raise StopAsyncIteration
+            else: # temporary patch for non-aiohttp async calls
+                return next(self)
+        except Exception as e:
+            # Handle any exceptions that might occur during streaming
+            raise StopAsyncIteration
class TextCompletionStreamWrapper:
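For context, a minimal end-to-end usage sketch of the new async path (assuming a litellm build that contains this commit; the model and prompt are illustrative):

import asyncio
import litellm

async def main():
    # acompletion(..., stream=True) hands back a CustomStreamWrapper;
    # async-for drives the new __anext__, which drains the aiohttp
    # response via completion_stream.content for OpenAI and falls back
    # to the synchronous next() patch for other providers.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say hi"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk)  # each chunk is a ModelResponse built by chunk_creator

asyncio.run(main())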