diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index 96d06bc20..2bfa9f82a 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -158,6 +158,7 @@ def completion(
         )
     except Exception as e:
         raise SagemakerError(status_code=500, message=f"{str(e)}")
+    response = response["Body"].read().decode("utf8")
 
     ## LOGGING
     logging_obj.post_call(
@@ -171,10 +172,17 @@ def completion(
     completion_response = json.loads(response)
     try:
         completion_response_choices = completion_response[0]
+        completion_output = ""
         if "generation" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+            completion_output += completion_response_choices["generation"]
         elif "generated_text" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+            completion_output += completion_response_choices["generated_text"]
+
+        # check if the prompt template is part of output, if so - filter it out
+        if completion_output.startswith(prompt) and "<s>" in prompt:
+            completion_output = completion_output.replace(prompt, "", 1)
+
+        model_response["choices"][0]["message"]["content"] = completion_output
     except:
         raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
 
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index 8209520e4..c50f2b1ee 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -298,7 +298,7 @@ async def test_async_custom_handler_embedding_optional_param():
     customHandler_optional_params = MyCustomHandler()
     litellm.callbacks = [customHandler_optional_params]
     response = await litellm.aembedding(
-        model="text-embedding-ada-002",
+        model="azure/azure-embedding-model",
         input = ["hello world"],
         user = "John"
     )