fix streaming formatting for non-openai models

This commit is contained in:
Krrish Dholakia 2023-09-16 19:20:07 -07:00
parent a63784d5b3
commit e44c218c1b
6 changed files with 16 additions and 14 deletions

View file

@ -2510,11 +2510,11 @@ class CustomStreamWrapper:
model_response = ModelResponse(stream=True, model=self.model)
try:
# return this for all models
completion_obj = {"content": ""}
if self.sent_first_chunk == False:
model_response.choices[0].delta.role = "assistant"
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
completion_obj = {"content": ""} # default to role being assistant
if self.model in litellm.anthropic_models:
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_anthropic_chunk(chunk)
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
@ -2539,10 +2539,10 @@ class CustomStreamWrapper:
elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
chunk = next(self.completion_stream)
completion_obj["content"] = chunk[0].outputs[0].text
elif self.model in litellm.aleph_alpha_models: #aleph alpha doesn't provide streaming
elif self.custom_llm_provider and self.custom_llm_provider == "aleph-alpha": #aleph alpha doesn't provide streaming
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_aleph_alpha_chunk(chunk)
elif self.model in litellm.open_ai_text_completion_models:
elif self.custom_llm_provider and self.custom_llm_provider == "text-completion-openai":
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
elif self.model in litellm.nlp_cloud_models or self.custom_llm_provider == "nlp_cloud":
@ -2551,7 +2551,7 @@ class CustomStreamWrapper:
elif self.model in (litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models):
chunk = next(self.completion_stream)
completion_obj["content"] = str(chunk)
elif self.model in litellm.cohere_models or self.custom_llm_provider == "cohere":
elif self.custom_llm_provider == "cohere":
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_cohere_chunk(chunk)
elif self.custom_llm_provider == "bedrock":
@ -2566,7 +2566,8 @@ class CustomStreamWrapper:
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(completion_obj,)).start()
model_response.model = self.model
model_response.choices[0].delta["content"] = completion_obj["content"]
if len(completion_obj["content"]) > 0: # cannot set content of an OpenAI Object to be an empty string
model_response.choices[0].delta = Delta(**completion_obj)
return model_response
except StopIteration:
raise StopIteration