mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Update Baseten handler to handle TGI (Text Generation Inference) calls
This commit is contained in:
parent
a86e771c23
commit
4927e5879f
8 changed files with 79 additions and 38 deletions
|
@ -735,7 +735,8 @@ def get_optional_params( # use the openai defaults
|
|||
elif custom_llm_provider == "baseten":
|
||||
optional_params["temperature"] = temperature
|
||||
optional_params["stream"] = stream
|
||||
optional_params["top_p"] = top_p
|
||||
if top_p != 1:
|
||||
optional_params["top_p"] = top_p
|
||||
optional_params["top_k"] = top_k
|
||||
optional_params["num_beams"] = num_beams
|
||||
if max_tokens != float("inf"):
|
||||
|
@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
|
|||
return chunk["choices"][0]["delta"]["content"]
|
||||
|
||||
def handle_baseten_chunk(self, chunk):
    """Extract the generated text from one raw Baseten streaming chunk.

    Two payload styles are recognised:

    * Lines prefixed with ``data:`` whose JSON body carries the text under
      ``token.text`` (the TGI-style calls this handler was extended for).
    * Plain JSON bodies carrying the text under ``model_output`` (either a
      string, or a dict whose ``data`` list holds the text as its first
      element) or under a top-level ``completion`` string.

    Args:
        chunk: UTF-8 encoded bytes of a single streamed chunk.

    Returns:
        The text carried by the chunk, or ``""`` when the chunk is empty,
        carries no recognised text field, or cannot be parsed. Parsing
        problems are reported via ``traceback.print_exc()`` rather than
        raised, so a single bad chunk does not abort the stream.
    """
    try:
        chunk = chunk.decode("utf-8")
        if not chunk:
            return ""

        if chunk.startswith("data:"):
            # json.loads tolerates the optional space after the "data:" prefix.
            event = json.loads(chunk[len("data:"):])
            if "token" in event and "text" in event["token"]:
                return event["token"]["text"]
            # Events without a token (or with no text in it) contribute nothing.
            return ""

        data_json = json.loads(chunk)
        has_model_output = "model_output" in data_json
        if has_model_output:
            model_output = data_json["model_output"]
            if (
                isinstance(model_output, dict)
                and "data" in model_output
                and isinstance(model_output["data"], list)
            ):
                return model_output["data"][0]
            if isinstance(model_output, str):
                return model_output

        # "completion" is a top-level key, so it must be honoured whether or
        # not "model_output" is present; previously it was only consulted
        # inside the "model_output" branch and completion-only payloads
        # silently produced "".
        if "completion" in data_json and isinstance(data_json["completion"], str):
            return data_json["completion"]

        if has_model_output:
            # "model_output" exists but has a shape we do not understand.
            raise ValueError(f"Unable to parse response. Original response: {chunk}")
        return ""
    except Exception:
        # Best-effort streaming: log and skip the chunk. Catching Exception
        # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
        traceback.print_exc()
        return ""
|
||||
|
||||
def __next__(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue