update baseten handler to handle TGI calls

This commit is contained in:
Krrish Dholakia 2023-08-30 19:14:48 -07:00
parent a86e771c23
commit 4927e5879f
8 changed files with 79 additions and 38 deletions

View file

@ -735,7 +735,8 @@ def get_optional_params( # use the openai defaults
elif custom_llm_provider == "baseten":
optional_params["temperature"] = temperature
optional_params["stream"] = stream
optional_params["top_p"] = top_p
if top_p != 1:
optional_params["top_p"] = top_p
optional_params["top_k"] = top_k
optional_params["num_beams"] = num_beams
if max_tokens != float("inf"):
@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
return chunk["choices"][0]["delta"]["content"]
def handle_baseten_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
if "model_output" in data_json:
if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
return data_json["model_output"]["data"][0]
elif isinstance(data_json["model_output"], str):
return data_json["model_output"]
elif "completion" in data_json and isinstance(data_json["completion"], str):
return data_json["completion"]
try:
chunk = chunk.decode("utf-8")
if len(chunk) > 0:
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
if "token" in data_json and "text" in data_json["token"]:
return data_json["token"]["text"]
else:
return ""
data_json = json.loads(chunk)
if "model_output" in data_json:
if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
return data_json["model_output"]["data"][0]
elif isinstance(data_json["model_output"], str):
return data_json["model_output"]
elif "completion" in data_json and isinstance(data_json["completion"], str):
return data_json["completion"]
else:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
else:
return ""
else:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
else:
return ""
except:
traceback.print_exc()
return ""
def __next__(self):