(feat) completion:sagemaker - support chat models

commit 4bd13bc006
parent cf73d1e3a4
Author: ishaan-jaff
Date:   2023-11-27 10:11:10 -08:00


@@ -121,10 +121,10 @@ def completion(
         else:
             prompt += f"{message['content']}"
-    data = {
+    data = json.dumps({
         "inputs": prompt,
         "parameters": inference_params
-    }
+    }).encode('utf-8')
     ## LOGGING
     request_str = f"""
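This hunk moves serialization to where the payload is built: the body is JSON-encoded and converted to UTF-8 bytes once, instead of at call time. A minimal standalone sketch of the resulting request body (the prompt and parameters below are placeholder values, not ones from this commit):

import json

# Placeholder prompt and parameters, standing in for the values completion()
# builds above; only the serialization step is the point here.
prompt = "Hello, how are you?"
inference_params = {"max_new_tokens": 256, "temperature": 0.7}

# Serialize once and encode to bytes up front, so the invoke_endpoint call
# below can pass the body through unchanged.
data = json.dumps({
    "inputs": prompt,
    "parameters": inference_params,
}).encode("utf-8")

print(data)  # b'{"inputs": "Hello, how are you?", "parameters": ...}'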
@@ -144,7 +144,7 @@ def completion(
     response = client.invoke_endpoint(
         EndpointName=model,
         ContentType="application/json",
-        Body=json.dumps(data),
+        Body=data,
         CustomAttributes="accept_eula=true",
     )
     response = response["Body"].read().decode("utf8")
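Because the body is now pre-encoded, invoke_endpoint receives bytes directly rather than serializing a second time. A self-contained sketch of the changed boto3 call; the endpoint name and region are assumptions for illustration, and running it requires AWS credentials and a live SageMaker endpoint:

import json
import boto3

client = boto3.client("sagemaker-runtime", region_name="us-west-2")

# Body is already bytes, per the encoding change in the first hunk.
data = json.dumps({"inputs": "Hello", "parameters": {"max_new_tokens": 64}}).encode("utf-8")

response = client.invoke_endpoint(
    EndpointName="my-llama-2-endpoint",  # placeholder endpoint name
    ContentType="application/json",
    Body=data,
    CustomAttributes="accept_eula=true",
)
print(response["Body"].read().decode("utf8"))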
@@ -158,17 +158,14 @@ def completion(
     print_verbose(f"raw model_response: {response}")
     ## RESPONSE OBJECT
     completion_response = json.loads(response)
-    if "error" in completion_response:
-        raise SagemakerError(
-            message=completion_response["error"],
-            status_code=response.status_code,
-        )
-    else:
-        try:
-            if len(completion_response[0]["generation"]) > 0:
-                model_response["choices"][0]["message"]["content"] = completion_response[0]["generation"]
-        except:
-            raise SagemakerError(message=json.dumps(completion_response), status_code=response.status_code)
+    try:
+        completion_response_choices = completion_response[0]
+        if "generation" in completion_response_choices:
+            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+        elif "generated_text" in completion_response_choices:
+            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+    except:
+        raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
     prompt_tokens = len(
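This hunk is what enables chat models: the parser now accepts both a "generation" key (the shape Llama 2 chat models on SageMaker typically return) and a "generated_text" key (the shape of Hugging Face text-generation containers). It also fixes the error path: the old code read status_code off the decoded response string, which has no such attribute, while the new code raises SagemakerError with a fixed 500. A sketch of the branching, using illustrative sample payloads rather than captured SageMaker output:

import json

# Illustrative sample payloads, not real endpoint responses.
llama_style = '[{"generation": "Hi there!"}]'
tgi_style = '[{"generated_text": "Hi there!"}]'

def extract_content(raw: str) -> str:
    # Mirrors the new parsing logic: index into the first choice, then
    # branch on whichever key is present.
    choices = json.loads(raw)[0]
    if "generation" in choices:
        return choices["generation"]
    elif "generated_text" in choices:
        return choices["generated_text"]
    raise ValueError(f"Unable to parse sagemaker RAW RESPONSE {raw}")

assert extract_content(llama_style) == "Hi there!"
assert extract_content(tgi_style) == "Hi there!"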