mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
fix(sagemaker.py): filter out templated prompt if in model response
parent c9b83ff853
commit a64bd2ca1e
2 changed files with 11 additions and 3 deletions
@@ -158,6 +158,7 @@ def completion(
         )
     except Exception as e:
         raise SagemakerError(status_code=500, message=f"{str(e)}")

     response = response["Body"].read().decode("utf8")
     ## LOGGING
     logging_obj.post_call(
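For context, a minimal sketch of the call that produces the `response` object decoded above, assuming the standard boto3 `sagemaker-runtime` client; the endpoint name, region, and payload here are hypothetical illustrations, not taken from the commit:

import json
import boto3

# Hypothetical setup: a SageMaker runtime client and endpoint name.
client = boto3.client("sagemaker-runtime", region_name="us-west-2")
response = client.invoke_endpoint(
    EndpointName="my-llama2-endpoint",  # hypothetical endpoint name
    ContentType="application/json",
    Body=json.dumps({"inputs": "<s>[INST] Hey! [/INST]"}),
)
# invoke_endpoint returns a dict whose "Body" is a streaming body;
# reading and decoding it yields the raw JSON string parsed below.
raw = response["Body"].read().decode("utf8")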
@@ -171,10 +172,17 @@ def completion(
     completion_response = json.loads(response)
     try:
         completion_response_choices = completion_response[0]
+        completion_output = ""
         if "generation" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+            completion_output += completion_response_choices["generation"]
         elif "generated_text" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+            completion_output += completion_response_choices["generated_text"]
+
+        # check if the prompt template is part of output, if so - filter it out
+        if completion_output.startswith(prompt) and "<s>" in prompt:
+            completion_output = completion_output.replace(prompt, "", 1)
+
+        model_response["choices"][0]["message"]["content"] = completion_output
     except:
         raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)

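The fix itself is easy to exercise in isolation. A minimal sketch of the new filtering logic, using a hypothetical Llama-2-style templated prompt and an endpoint response that echoes the prompt back (both are assumptions for illustration, not from the commit):

# Hypothetical templated prompt; "<s>" is the marker the fix keys on.
prompt = "<s>[INST] What is the capital of France? [/INST]"

# Assume the endpoint echoed the prompt at the start of its generation.
completion_output = prompt + " The capital of France is Paris."

# Same check as the diff: strip the prompt only when the output starts
# with it and the prompt is actually templated.
if completion_output.startswith(prompt) and "<s>" in prompt:
    completion_output = completion_output.replace(prompt, "", 1)

print(completion_output)  # -> " The capital of France is Paris."

Note that replace(prompt, "", 1) removes only the first occurrence, so a later, legitimate repetition of the prompt text inside the generation is left intact, and the startswith guard avoids clipping outputs that merely contain the prompt somewhere in the middle.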