fix(sagemaker.py): support model_id consistently; support dynamic args for async calls

This commit is contained in:
Krrish Dholakia 2024-03-29 09:05:00 -07:00
parent d547944556
commit 109cd93a39
3 changed files with 165 additions and 90 deletions

View file

@ -246,15 +246,28 @@ def completion(
model=model,
logging_obj=logging_obj,
data=data,
model_id=model_id,
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_region_name=aws_region_name,
)
return response
response = client.invoke_endpoint_with_response_stream(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
if model_id is not None:
response = client.invoke_endpoint_with_response_stream(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
else:
response = client.invoke_endpoint_with_response_stream(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
return response["Body"]
elif acompletion == True:
_data = {"inputs": prompt, "parameters": inference_params}
@ -265,31 +278,36 @@ def completion(
model=model,
logging_obj=logging_obj,
data=_data,
model_id=model_id,
aws_secret_access_key=aws_secret_access_key,
aws_access_key_id=aws_access_key_id,
aws_region_name=aws_region_name,
)
data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
"utf-8"
)
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
## COMPLETION CALL
try:
if model_id is not None:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
InferenceComponentName={model_id},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
response = client.invoke_endpoint(
EndpointName=model,
InferenceComponentName=model_id,
@ -298,6 +316,24 @@ def completion(
CustomAttributes="accept_eula=true",
)
else:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
"hf_model_name": hf_model_name,
},
)
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
@ -369,8 +405,12 @@ async def async_streaming(
encoding,
model_response: ModelResponse,
model: str,
model_id: Optional[str],
logging_obj: Any,
data,
aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
):
"""
Use aioboto3
@ -379,11 +419,6 @@ async def async_streaming(
session = aioboto3.Session()
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@ -410,12 +445,21 @@ async def async_streaming(
async with _client as client:
try:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
if model_id is not None:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
else:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
response = response["Body"]
@ -430,6 +474,10 @@ async def async_completion(
model: str,
logging_obj: Any,
data: dict,
model_id: Optional[str],
aws_secret_access_key: Optional[str],
aws_access_key_id: Optional[str],
aws_region_name: Optional[str],
):
"""
Use aioboto3
@ -438,11 +486,6 @@ async def async_completion(
session = aioboto3.Session()
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@ -468,33 +511,63 @@ async def async_completion(
)
async with _client as client:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=data["inputs"],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
encoded_data = json.dumps(data).encode("utf-8")
try:
response = await client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=encoded_data,
CustomAttributes="accept_eula=true",
)
if model_id is not None:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
InferenceComponentName={model_id},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=data["inputs"],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = await client.invoke_endpoint(
EndpointName=model,
InferenceComponentName=model_id,
ContentType="application/json",
Body=encoded_data,
CustomAttributes="accept_eula=true",
)
else:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)
""" # type: ignore
logging_obj.pre_call(
input=data["inputs"],
api_key="",
additional_args={
"complete_input_dict": data,
"request_str": request_str,
},
)
response = await client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=encoded_data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
error_message = f"{str(e)}"
if "Inference Component Name header is required" in error_message:
error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
raise SagemakerError(status_code=500, message=error_message)
response = await response["Body"].read()
response = response.decode("utf8")
## LOGGING