fix(sagemaker.py): filter out templated prompt if present in model response

This commit is contained in:
Krrish Dholakia 2023-12-13 07:43:33 -08:00
parent c9b83ff853
commit a64bd2ca1e
2 changed files with 11 additions and 3 deletions

View file

@ -158,6 +158,7 @@ def completion(
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"].read().decode("utf8")
## LOGGING
logging_obj.post_call(
@ -171,10 +172,17 @@ def completion(
completion_response = json.loads(response)
try:
completion_response_choices = completion_response[0]
completion_output = ""
if "generation" in completion_response_choices:
model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
completion_output += completion_response_choices["generated_text"]
# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt:
completion_output = completion_output.replace(prompt, "", 1)
model_response["choices"][0]["message"]["content"] = completion_output
except:
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)

View file

@ -298,7 +298,7 @@ async def test_async_custom_handler_embedding_optional_param():
customHandler_optional_params = MyCustomHandler()
litellm.callbacks = [customHandler_optional_params]
response = await litellm.aembedding(
model="text-embedding-ada-002",
model="azure/azure-embedding-model",
input = ["hello world"],
user = "John"
)