update baseten handler to handle TGI calls

Krrish Dholakia 2023-08-30 19:14:48 -07:00
parent a86e771c23
commit 4927e5879f
8 changed files with 79 additions and 38 deletions


@@ -60,9 +60,12 @@ class BasetenLLM:
             else:
                 prompt += f"{message['content']}"
         data = {
-            "prompt": prompt,
+            # "prompt": prompt,
+            "inputs": prompt, # in case it's a TGI deployed model
+            # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
+            # **optional_params,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
         }
         ## LOGGING
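
Note on this hunk: it swaps Baseten's older `prompt`-keyed body for the TGI request schema, where the prompt travels under `inputs` and generation kwargs under `parameters`. Below is a minimal sketch of that payload shape, assuming a plain JSON body posted with `requests`; `build_tgi_payload` is an illustrative helper, not part of the handler.

import json

def build_tgi_payload(prompt: str, optional_params: dict) -> str:
    # TGI-deployed models read the prompt from "inputs" and generation
    # kwargs (max_new_tokens, temperature, ...) from "parameters".
    data = {
        "inputs": prompt,
        "parameters": optional_params,
        "stream": optional_params.get("stream") == True,
    }
    return json.dumps(data)

# build_tgi_payload("Say hi", {"max_new_tokens": 32})
# -> '{"inputs": "Say hi", "parameters": {"max_new_tokens": 32}, "stream": false}'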
@@ -76,8 +79,9 @@ class BasetenLLM:
             self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
             headers=self.headers,
             data=json.dumps(data),
+            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
             return response.iter_lines()
         else:
             ## LOGGING
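
Note on this hunk: the streaming check now also keys off the response's Content-Type, since TGI streams Server-Sent Events labeled `text/event-stream`; the handler streams when either the server says so or the caller passed `stream=True`. A small sketch of that decision, assuming a `requests.Response`-style object; `should_stream` is a hypothetical helper, not part of the handler.

import requests

def should_stream(response: requests.Response, optional_params: dict) -> bool:
    # TGI marks streamed output as Server-Sent Events via this content type.
    # (The hunk indexes response.headers['Content-Type'] directly; .get() here
    # is just a defensive variation for the sketch.)
    content_type = response.headers.get("Content-Type", "")
    caller_requested = optional_params.get("stream") == True
    return "text/event-stream" in content_type or caller_requested

# if should_stream(response, optional_params):
#     return response.iter_lines()  # hand the raw SSE lines back to the caller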
@@ -117,9 +121,23 @@ class BasetenLLM:
                     model_response["choices"][0]["message"][
                         "content"
                     ] = completion_response["completion"]
+                elif isinstance(completion_response, list) and len(completion_response) > 0:
+                    if "generated_text" not in completion_response:
+                        raise BasetenError(
+                            message=f"Unable to parse response. Original response: {response.text}",
+                            status_code=response.status_code
+                        )
+                    model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+                    ## GETTING LOGPROBS
+                    if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                        sum_logprob = 0
+                        for token in completion_response[0]["details"]["tokens"]:
+                            sum_logprob += token["logprob"]
+                        model_response["choices"][0]["message"]["logprobs"] = sum_logprob
                 else:
-                    raise ValueError(
-                        f"Unable to parse response. Original response: {response.text}"
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code
                     )
             ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
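
Note on this hunk: the new `elif` branch handles TGI's response shape, a single-element JSON list whose dict carries `generated_text` plus optional `details.tokens` with per-token logprobs, which the handler sums into one score. Below is a sketch of that parsing on an illustrative payload; the values are made up, and real responses typically include `details` only when the caller requests them.

# Illustrative TGI-style response body (not a real model output).
completion_response = [
    {
        "generated_text": " Hello there!",
        "details": {
            "tokens": [
                {"text": " Hello", "logprob": -0.42},
                {"text": " there", "logprob": -0.88},
                {"text": "!", "logprob": -1.10},
            ]
        },
    }
]

first = completion_response[0]
content = first["generated_text"]

# Collapse per-token logprobs into a single sequence-level logprob, as the hunk does.
sum_logprob = 0.0
if "details" in first and "tokens" in first["details"]:
    for token in first["details"]["tokens"]:
        sum_logprob += token["logprob"]

print(content, sum_logprob)  # " Hello there!" and roughly -2.4, up to float rounding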