From f7ae01da8a19011f09a87df23b876d1c4689f9db Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 27 Nov 2023 10:11:10 -0800
Subject: [PATCH] (feat) completion:sagemaker - support chat models

---
 litellm/llms/sagemaker.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index ce36367de0..c0bf7cd341 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -121,10 +121,10 @@ def completion(
         else:
             prompt += f"{message['content']}"
 
-    data = {
+    data = json.dumps({
         "inputs": prompt,
         "parameters": inference_params
-    }
+    }).encode('utf-8')
 
     ## LOGGING
     request_str = f"""
@@ -144,7 +144,7 @@ def completion(
     response = client.invoke_endpoint(
         EndpointName=model,
         ContentType="application/json",
-        Body=json.dumps(data),
+        Body=data,
         CustomAttributes="accept_eula=true",
     )
     response = response["Body"].read().decode("utf8")
@@ -158,17 +158,14 @@ def completion(
     print_verbose(f"raw model_response: {response}")
     ## RESPONSE OBJECT
     completion_response = json.loads(response)
-    if "error" in completion_response:
-        raise SagemakerError(
-            message=completion_response["error"],
-            status_code=response.status_code,
-        )
-    else:
-        try:
-            if len(completion_response[0]["generation"]) > 0:
-                model_response["choices"][0]["message"]["content"] = completion_response[0]["generation"]
-        except:
-            raise SagemakerError(message=json.dumps(completion_response), status_code=response.status_code)
+    try:
+        completion_response_choices = completion_response[0]
+        if "generation" in completion_response_choices:
+            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+        elif "generated_text" in completion_response_choices:
+            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+    except:
+        raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
     prompt_tokens = len(
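
For readers skimming the patch, the change does two things: it serializes and UTF-8-encodes the request payload once up front (boto3's invoke_endpoint accepts bytes for Body), and it widens response parsing to accept either a "generation" key (typical of LLaMA-2 chat containers) or a "generated_text" key (typical of Hugging Face TGI containers). Below is a minimal standalone sketch of that parsing behavior; parse_sagemaker_completion is a hypothetical helper name for illustration, not part of litellm, and a plain ValueError stands in for the patch's SagemakerError.

import json

def parse_sagemaker_completion(raw_body: str) -> str:
    """Pull the completion text out of a raw SageMaker response body.

    Mirrors the patched logic: the endpoint returns a JSON list whose
    first element carries either a "generation" key or a
    "generated_text" key, depending on the model container.
    """
    completion_response = json.loads(raw_body)
    try:
        choice = completion_response[0]
        if "generation" in choice:
            return choice["generation"]
        if "generated_text" in choice:
            return choice["generated_text"]
        raise KeyError("no known completion key in response")
    except (IndexError, KeyError, TypeError) as err:
        # The patch raises SagemakerError with status_code=500 here;
        # a ValueError keeps this sketch self-contained.
        raise ValueError(
            f"Unable to parse SageMaker response: {json.dumps(completion_response)}"
        ) from err

# Both response shapes parse to "Hello!":
print(parse_sagemaker_completion('[{"generation": "Hello!"}]'))
print(parse_sagemaker_completion('[{"generated_text": "Hello!"}]'))

One design note: the old code probed completion_response[0]["generation"] directly and raised on any other shape, so TGI-backed endpoints always fell into the error path; keying on membership lets one code path serve both container families.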