Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 10:14:26 +00:00
(feat) completion:sagemaker - support chat models
parent b732d4c394
commit f7ae01da8a

1 changed file with 11 additions and 14 deletions
@@ -121,10 +121,10 @@ def completion(
         else:
             prompt += f"{message['content']}"
 
-    data = {
+    data = json.dumps({
         "inputs": prompt,
         "parameters": inference_params
-    }
+    }).encode('utf-8')
 
     ## LOGGING
     request_str = f"""
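Note on the first hunk: the request body is now serialized and UTF-8 encoded once, up front, so `data` holds ready-to-send bytes instead of a dict. A minimal sketch of the resulting payload; the prompt and inference parameters here are illustrative, not the commit's values:

import json

# Illustrative inputs; in litellm these come from the messages list and kwargs.
prompt = "What is the capital of France?"
inference_params = {"max_new_tokens": 256, "temperature": 0.7}

# After this change, `data` is bytes, not a dict.
data = json.dumps({
    "inputs": prompt,
    "parameters": inference_params,
}).encode("utf-8")

print(type(data))  # <class 'bytes'>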
@@ -144,7 +144,7 @@ def completion(
     response = client.invoke_endpoint(
         EndpointName=model,
         ContentType="application/json",
-        Body=json.dumps(data),
+        Body=data,
         CustomAttributes="accept_eula=true",
     )
     response = response["Body"].read().decode("utf8")
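The second hunk follows from the first: `data` is already JSON-encoded bytes, so passing it through `json.dumps` again would raise a TypeError (bytes are not JSON serializable). A hedged sketch of the call site using boto3's SageMaker runtime client; the endpoint name and region are hypothetical:

import json
import boto3

data = json.dumps({"inputs": "Hello", "parameters": {}}).encode("utf-8")  # as in the first hunk

# Hypothetical client setup; litellm builds this from the caller's credentials.
client = boto3.client("sagemaker-runtime", region_name="us-west-2")

response = client.invoke_endpoint(
    EndpointName="jumpstart-llama-2-7b-chat",  # hypothetical endpoint name
    ContentType="application/json",
    Body=data,  # the pre-encoded bytes, no second json.dumps
    CustomAttributes="accept_eula=true",
)
raw = response["Body"].read().decode("utf8")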
@@ -158,17 +158,14 @@ def completion(
     print_verbose(f"raw model_response: {response}")
     ## RESPONSE OBJECT
     completion_response = json.loads(response)
-    if "error" in completion_response:
-        raise SagemakerError(
-            message=completion_response["error"],
-            status_code=response.status_code,
-        )
-    else:
-        try:
-            if len(completion_response[0]["generation"]) > 0:
-                model_response["choices"][0]["message"]["content"] = completion_response[0]["generation"]
-        except:
-            raise SagemakerError(message=json.dumps(completion_response), status_code=response.status_code)
+    try:
+        completion_response_choices = completion_response[0]
+        if "generation" in completion_response_choices:
+            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+        elif "generated_text" in completion_response_choices:
+            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+    except:
+        raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
 
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
     prompt_tokens = len(
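The third hunk is where chat-model support lands. The old path assumed a `generation` key and referenced `status_code` on what is by then a plain decoded string; the new path indexes the first element and accepts either `generation` or `generated_text`, falling back to a 500 error carrying the raw payload. A minimal illustration of the two response shapes the parser now handles; the sample payloads, and the attribution of `generation` to Llama-2-style chat endpoints and `generated_text` to text-generation-style endpoints, are assumptions:

import json

# Assumed example payloads for the two endpoint families.
llama_chat_style = '[{"generation": "Paris is the capital of France."}]'
tgi_text_style = '[{"generated_text": "Paris is the capital of France."}]'

for raw in (llama_chat_style, tgi_text_style):
    completion_response = json.loads(raw)
    completion_response_choices = completion_response[0]
    if "generation" in completion_response_choices:
        print(completion_response_choices["generation"])
    elif "generated_text" in completion_response_choices:
        print(completion_response_choices["generated_text"])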