mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
update baseten handler to handle TGI calls
parent a86e771c23
commit 4927e5879f
8 changed files with 79 additions and 38 deletions
The hunks below show the changes to the Baseten handler (class BasetenLLM):

```diff
@@ -60,9 +60,12 @@ class BasetenLLM:
             else:
                 prompt += f"{message['content']}"
         data = {
-            "prompt": prompt,
+            # "prompt": prompt,
+            "inputs": prompt, # in case it's a TGI deployed model
             # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
+            # **optional_params,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
         }

         ## LOGGING
```
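The first hunk reworks the request body so one payload serves both Baseten's default model handlers and TGI (Text Generation Inference) deployments: the prompt is sent under the TGI-style `inputs` key, generation options move under `parameters`, and `stream` becomes a top-level flag. A minimal sketch of that payload shape; the helper `build_baseten_payload` is hypothetical, not part of litellm:

```python
from typing import Any, Dict


def build_baseten_payload(prompt: str, optional_params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the payload shape in the hunk above."""
    return {
        "inputs": prompt,               # TGI reads the prompt from "inputs"
        "parameters": optional_params,  # TGI reads generation options from "parameters"
        "stream": bool(optional_params.get("stream", False)),
    }


print(build_baseten_payload("What is TGI?", {"max_new_tokens": 64, "stream": True}))
```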
```diff
@@ -76,8 +79,9 @@ class BasetenLLM:
             self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
             headers=self.headers,
             data=json.dumps(data),
+            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
             return response.iter_lines()
         else:
             ## LOGGING
```
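The second hunk passes `stream` through to `requests.post` so a streamed body is not buffered, and treats the response as a stream when either the caller asked for one or the server replies with a `text/event-stream` Content-Type (which a TGI deployment sets for streamed output). A sketch of that detection logic, assuming a placeholder URL and an `Api-Key` authorization scheme; `post_completion` is a hypothetical name, not litellm's API:

```python
import json

import requests


def post_completion(url: str, api_key: str, payload: dict):
    """Sketch of the request/stream-detection logic in the hunk above."""
    wants_stream = bool(payload.get("stream", False))
    response = requests.post(
        url,
        headers={"Authorization": f"Api-Key {api_key}"},  # assumed auth scheme
        data=json.dumps(payload),
        stream=wants_stream,  # don't buffer the body if we expect SSE chunks
    )
    # Stream if the server says "text/event-stream" or the caller asked for it.
    if "text/event-stream" in response.headers.get("Content-Type", "") or wants_stream:
        return response.iter_lines()
    return response.json()
```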
```diff
@@ -117,9 +121,23 @@ class BasetenLLM:
                 model_response["choices"][0]["message"][
                     "content"
                 ] = completion_response["completion"]
+            elif isinstance(completion_response, list) and len(completion_response) > 0:
+                if "generated_text" not in completion_response:
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code
+                    )
+                model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
             else:
-                raise ValueError(
-                    f"Unable to parse response. Original response: {response.text}"
+                raise BasetenError(
+                    message=f"Unable to parse response. Original response: {response.text}",
+                    status_code=response.status_code
                 )

             ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
```
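The third hunk parses the TGI response shape: a list whose first element carries `generated_text` and, optionally, per-token `details` whose logprobs are summed into a single sequence-level score. Note the guard tests `"generated_text" not in completion_response` against the list itself; checking the first element, as in the sketch below, is presumably the intent. `parse_tgi_response` is a hypothetical standalone helper:

```python
from typing import List


def parse_tgi_response(completion_response: List[dict]) -> dict:
    """Sketch of the TGI list parsing above; guards on the first element."""
    if not completion_response or "generated_text" not in completion_response[0]:
        raise ValueError(f"Unable to parse response: {completion_response!r}")
    result = {"content": completion_response[0]["generated_text"]}
    details = completion_response[0].get("details", {})
    if "tokens" in details:
        # Collapse per-token logprobs into one sequence-level log-probability.
        result["logprobs"] = sum(token["logprob"] for token in details["tokens"])
    return result


sample = [{"generated_text": "hi", "details": {"tokens": [{"logprob": -0.5}, {"logprob": -1.0}]}}]
print(parse_tgi_response(sample))
```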