fixes to streaming for ai21, baseten, and openai text completions

This commit is contained in:
Krrish Dholakia 2023-08-28 09:38:40 -07:00
parent d11cb3e2ea
commit d542066d4b
9 changed files with 273 additions and 117 deletions

View file

@ -648,6 +648,7 @@ def get_optional_params( # use the openai defaults
optional_params["top_k"] = top_k
elif custom_llm_provider == "baseten":
optional_params["temperature"] = temperature
optional_params["stream"] = stream
optional_params["top_p"] = top_p
optional_params["top_k"] = top_k
optional_params["num_beams"] = num_beams
@ -1561,6 +1562,35 @@ class CustomStreamWrapper:
else:
return ""
return ""
def handle_ai21_chunk(self, chunk):
    """Extract the completion text from one raw AI21 stream chunk.

    Args:
        chunk: UTF-8 encoded bytes containing a JSON document.

    Returns:
        The value found at ``completions[0]["data"]["text"]`` in the
        decoded JSON.

    Raises:
        ValueError: If the decoded JSON does not contain that path.
            Invalid UTF-8 or malformed JSON propagate unchanged from
            ``bytes.decode`` / ``json.loads`` (both are ``ValueError``
            subclasses).
    """
    chunk = chunk.decode("utf-8")
    data_json = json.loads(chunk)
    try:
        return data_json["completions"][0]["data"]["text"]
    except (KeyError, IndexError, TypeError) as err:
        # Only lookup failures are translated; a bare ``except`` would also
        # swallow KeyboardInterrupt/SystemExit. Chain the cause so the
        # missing key/index stays visible in the traceback.
        raise ValueError(f"Unable to parse response. Original response: {chunk}") from err
def handle_openai_text_completion_chunk(self, chunk):
    """Extract the generated text from one OpenAI text-completion chunk.

    Args:
        chunk: A subscriptable chunk object; the text is read from
            ``chunk["choices"][0]["text"]``.

    Returns:
        The value stored at ``chunk["choices"][0]["text"]``.

    Raises:
        ValueError: If that path cannot be read from the chunk.
    """
    try:
        return chunk["choices"][0]["text"]
    except Exception as err:
        # ``Exception`` (not a bare ``except``) so KeyboardInterrupt and
        # SystemExit propagate; the cause is chained for debugging.
        raise ValueError(f"Unable to parse response. Original response: {chunk}") from err
def handle_baseten_chunk(self, chunk):
    """Extract the generated text from one raw Baseten stream chunk.

    The decoded JSON is inspected in this order:

    1. No ``"model_output"`` key: return ``""``.
    2. ``model_output`` is a dict whose ``"data"`` is a non-empty list:
       return its first element.
    3. ``model_output`` is a string: return it.
    4. A top-level string ``"completion"``: return it.

    Args:
        chunk: UTF-8 encoded bytes containing a JSON document.

    Returns:
        The extracted value, or ``""`` when the payload has no
        ``"model_output"`` key.

    Raises:
        ValueError: If ``"model_output"`` is present but none of the
            shapes above match (including an empty ``data`` list with no
            ``completion`` fallback).
    """
    chunk = chunk.decode("utf-8")
    data_json = json.loads(chunk)
    if "model_output" not in data_json:
        return ""

    model_output = data_json["model_output"]
    if isinstance(model_output, dict):
        data = model_output.get("data")
        # Require a non-empty list: indexing ``[0]`` on an empty list used
        # to escape as an uncaught IndexError instead of a parse error.
        if isinstance(data, list) and data:
            return data[0]
    elif isinstance(model_output, str):
        return model_output

    completion = data_json.get("completion")
    if isinstance(completion, str):
        return completion
    raise ValueError(f"Unable to parse response. Original response: {chunk}")
def __next__(self):
completion_obj = {"role": "assistant", "content": ""}
@ -1584,6 +1614,15 @@ class CustomStreamWrapper:
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_huggingface_chunk(chunk)
elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_baseten_chunk(chunk)
elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_ai21_chunk(chunk)
elif self.model in litellm.open_ai_text_completion_models:
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
# return this for all models
return {"choices": [{"delta": completion_obj}]}