Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 19:54:13 +00:00
(feat) completion:sagemaker - support chat models
This commit is contained in:
parent
cf73d1e3a4
commit
4bd13bc006
1 changed file with 11 additions and 14 deletions
|
@@ -121,10 +121,10 @@ def completion(
|
||||||
else:
|
else:
|
||||||
prompt += f"{message['content']}"
|
prompt += f"{message['content']}"
|
||||||
|
|
||||||
data = {
|
data = json.dumps({
|
||||||
"inputs": prompt,
|
"inputs": prompt,
|
||||||
"parameters": inference_params
|
"parameters": inference_params
|
||||||
}
|
}).encode('utf-8')
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
request_str = f"""
|
request_str = f"""
|
||||||
|
@@ -144,7 +144,7 @@ def completion(
|
||||||
response = client.invoke_endpoint(
|
response = client.invoke_endpoint(
|
||||||
EndpointName=model,
|
EndpointName=model,
|
||||||
ContentType="application/json",
|
ContentType="application/json",
|
||||||
Body=json.dumps(data),
|
Body=data,
|
||||||
CustomAttributes="accept_eula=true",
|
CustomAttributes="accept_eula=true",
|
||||||
)
|
)
|
||||||
response = response["Body"].read().decode("utf8")
|
response = response["Body"].read().decode("utf8")
|
||||||
|
@@ -158,17 +158,14 @@ def completion(
|
||||||
print_verbose(f"raw model_response: {response}")
|
print_verbose(f"raw model_response: {response}")
|
||||||
## RESPONSE OBJECT
|
## RESPONSE OBJECT
|
||||||
completion_response = json.loads(response)
|
completion_response = json.loads(response)
|
||||||
if "error" in completion_response:
|
try:
|
||||||
raise SagemakerError(
|
completion_response_choices = completion_response[0]
|
||||||
message=completion_response["error"],
|
if "generation" in completion_response_choices:
|
||||||
status_code=response.status_code,
|
model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
|
||||||
)
|
elif "generated_text" in completion_response_choices:
|
||||||
else:
|
model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
|
||||||
try:
|
except:
|
||||||
if len(completion_response[0]["generation"]) > 0:
|
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
|
||||||
model_response["choices"][0]["message"]["content"] = completion_response[0]["generation"]
|
|
||||||
except:
|
|
||||||
raise SagemakerError(message=json.dumps(completion_response), status_code=response.status_code)
|
|
||||||
|
|
||||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||||
prompt_tokens = len(
|
prompt_tokens = len(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue