(feat) completion:sagemaker - support chat models

This commit is contained in:
ishaan-jaff 2023-11-27 10:11:10 -08:00
parent b732d4c394
commit f7ae01da8a

View file

@ -121,10 +121,10 @@ def completion(
else:
prompt += f"{message['content']}"
data = {
data = json.dumps({
"inputs": prompt,
"parameters": inference_params
}
}).encode('utf-8')
## LOGGING
request_str = f"""
@ -144,7 +144,7 @@ def completion(
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=json.dumps(data),
Body=data,
CustomAttributes="accept_eula=true",
)
response = response["Body"].read().decode("utf8")
@ -158,17 +158,14 @@ def completion(
print_verbose(f"raw model_response: {response}")
## RESPONSE OBJECT
completion_response = json.loads(response)
if "error" in completion_response:
raise SagemakerError(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
if len(completion_response[0]["generation"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response[0]["generation"]
except:
raise SagemakerError(message=json.dumps(completion_response), status_code=response.status_code)
try:
completion_response_choices = completion_response[0]
if "generation" in completion_response_choices:
model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(