mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
fixes to streaming for ai21, baseten, and openai text completions
commit d542066d4b (parent d11cb3e2ea)
9 changed files with 273 additions and 117 deletions
@@ -648,6 +648,7 @@ def get_optional_params( # use the openai defaults
        optional_params["top_k"] = top_k
    elif custom_llm_provider == "baseten":
        optional_params["temperature"] = temperature
        optional_params["stream"] = stream
        optional_params["top_p"] = top_p
        optional_params["top_k"] = top_k
        optional_params["num_beams"] = num_beams
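The net effect of this hunk is that Baseten requests now carry a stream flag alongside the sampling parameters. A minimal standalone sketch of the same mapping (the function name and default values below are hypothetical, for illustration only):

def baseten_optional_params(temperature=1.0, stream=False, top_p=1.0, top_k=40, num_beams=1):
    # Hypothetical trimmed-down version of the baseten branch above.
    return {
        "temperature": temperature,
        "stream": stream,
        "top_p": top_p,
        "top_k": top_k,
        "num_beams": num_beams,
    }

print(baseten_optional_params(temperature=0.7, stream=True))
# {'temperature': 0.7, 'stream': True, 'top_p': 1.0, 'top_k': 40, 'num_beams': 1}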
@@ -1561,6 +1562,35 @@ class CustomStreamWrapper:
            else:
                return ""
        return ""

    def handle_ai21_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        data_json = json.loads(chunk)
        try:
            return data_json["completions"][0]["data"]["text"]
        except:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_openai_text_completion_chunk(self, chunk):
        try:
            return chunk["choices"][0]["text"]
        except:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

    def handle_baseten_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        data_json = json.loads(chunk)
        if "model_output" in data_json:
            if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
                return data_json["model_output"]["data"][0]
            elif isinstance(data_json["model_output"], str):
                return data_json["model_output"]
            elif "completion" in data_json and isinstance(data_json["completion"], str):
                return data_json["completion"]
            else:
                raise ValueError(f"Unable to parse response. Original response: {chunk}")
        else:
            return ""

    def __next__(self):
        completion_obj = {"role": "assistant", "content": ""}
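To make the parsing contract concrete, here is a self-contained sketch that feeds fabricated byte payloads through logic equivalent to handle_ai21_chunk and handle_baseten_chunk (the JSON shapes are invented to match the branches above, not captured provider output):

import json

def parse_ai21(chunk: bytes) -> str:
    # Same shape handle_ai21_chunk expects: completions[0].data.text
    data_json = json.loads(chunk.decode("utf-8"))
    return data_json["completions"][0]["data"]["text"]

def parse_baseten(chunk: bytes) -> str:
    # Covers the dict-with-data-list and plain-string branches above.
    data_json = json.loads(chunk.decode("utf-8"))
    model_output = data_json.get("model_output", "")
    if isinstance(model_output, dict) and isinstance(model_output.get("data"), list):
        return model_output["data"][0]
    return model_output if isinstance(model_output, str) else ""

ai21_chunk = json.dumps({"completions": [{"data": {"text": "hi from ai21"}}]}).encode()
baseten_chunk = json.dumps({"model_output": {"data": ["hi from baseten"]}}).encode()
print(parse_ai21(ai21_chunk))        # hi from ai21
print(parse_baseten(baseten_chunk))  # hi from baseten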
@@ -1584,6 +1614,15 @@ class CustomStreamWrapper:
        elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
            chunk = next(self.completion_stream)
            completion_obj["content"] = self.handle_huggingface_chunk(chunk)
        elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
            chunk = next(self.completion_stream)
            completion_obj["content"] = self.handle_baseten_chunk(chunk)
        elif self.custom_llm_provider and self.custom_llm_provider == "ai21": # ai21 doesn't provide streaming
            chunk = next(self.completion_stream)
            completion_obj["content"] = self.handle_ai21_chunk(chunk)
        elif self.model in litellm.open_ai_text_completion_models:
            chunk = next(self.completion_stream)
            completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
        # return this for all models
        return {"choices": [{"delta": completion_obj}]}
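Because every provider branch returns the same {"choices": [{"delta": ...}]} shape, downstream consumers can iterate the wrapper uniformly. A sketch of such a consumer (the model id is a placeholder; response is assumed to be the stream returned by a litellm.completion call with stream=True):

import litellm

messages = [{"role": "user", "content": "Hello"}]

# Placeholder model id; any provider that hits one of the branches above works.
response = litellm.completion(model="<model-id>", messages=messages, stream=True)

collected = []
for chunk in response:
    # Each chunk has the shape returned at the end of __next__ above.
    content = chunk["choices"][0]["delta"].get("content", "")
    if content:
        collected.append(content)
print("".join(collected))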