forked from phoenix/litellm-mirror
fix(sagemaker.py): filter out templated prompt if in model response
parent c9b83ff853
commit a64bd2ca1e
2 changed files with 11 additions and 3 deletions
@@ -158,6 +158,7 @@ def completion(
         )
     except Exception as e:
         raise SagemakerError(status_code=500, message=f"{str(e)}")
+
     response = response["Body"].read().decode("utf8")
     ## LOGGING
     logging_obj.post_call(
@@ -171,10 +172,17 @@ def completion(
     completion_response = json.loads(response)
     try:
         completion_response_choices = completion_response[0]
+        completion_output = ""
         if "generation" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+            completion_output += completion_response_choices["generation"]
         elif "generated_text" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+            completion_output += completion_response_choices["generated_text"]
+
+            # check if the prompt template is part of output, if so - filter it out
+        if completion_output.startswith(prompt) and "<s>" in prompt:
+            completion_output = completion_output.replace(prompt, "", 1)
+
+        model_response["choices"][0]["message"]["content"] = completion_output
     except:
         raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)

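For reference, a minimal standalone sketch of the filtering behavior added above. The prompt and raw response strings here are made-up sample values, not taken from the commit; they only simulate a SageMaker endpoint that echoes the templated prompt at the start of its generation.

import json

# hypothetical templated prompt and raw endpoint response (sample values only)
prompt = "<s>[INST] What is the capital of France? [/INST]"
raw_response = json.dumps(
    [{"generation": prompt + " The capital of France is Paris."}]
)

completion_response = json.loads(raw_response)
completion_response_choices = completion_response[0]

completion_output = ""
if "generation" in completion_response_choices:
    completion_output += completion_response_choices["generation"]
elif "generated_text" in completion_response_choices:
    completion_output += completion_response_choices["generated_text"]

# check if the prompt template is part of output, if so - filter it out
if completion_output.startswith(prompt) and "<s>" in prompt:
    completion_output = completion_output.replace(prompt, "", 1)

print(completion_output)  # -> " The capital of France is Paris."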
@@ -298,7 +298,7 @@ async def test_async_custom_handler_embedding_optional_param():
     customHandler_optional_params = MyCustomHandler()
     litellm.callbacks = [customHandler_optional_params]
     response = await litellm.aembedding(
-        model="text-embedding-ada-002",
+        model="azure/azure-embedding-model",
         input = ["hello world"],
         user = "John"
     )