Mirror of https://github.com/BerriAI/litellm.git
fix(sagemaker.py): support model_id consistently. support dynamic args for async calls
commit 109cd93a39 (parent d547944556)
3 changed files with 165 additions and 90 deletions
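In short: `model_id` (SageMaker's `InferenceComponentName`) is now threaded through every call path (sync, sync streaming, async, async streaming), and the async helpers receive `aws_secret_access_key`, `aws_access_key_id`, and `aws_region_name` as explicit parameters instead of popping them out of `optional_params`. A minimal usage sketch (endpoint, component, and credential values below are placeholders, not taken from this commit):

    import litellm

    response = litellm.completion(
        model="sagemaker/my-endpoint",               # placeholder endpoint name
        messages=[{"role": "user", "content": "Hello"}],
        model_id="my-inference-component",           # forwarded as InferenceComponentName
        aws_region_name="us-west-2",                 # placeholder credentials/region
        aws_access_key_id="AKIA...",
        aws_secret_access_key="...",
    )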
sagemaker.py
@@ -246,15 +246,28 @@ def completion(
                 model=model,
                 logging_obj=logging_obj,
                 data=data,
+                model_id=model_id,
+                aws_secret_access_key=aws_secret_access_key,
+                aws_access_key_id=aws_access_key_id,
+                aws_region_name=aws_region_name,
             )
             return response
-        response = client.invoke_endpoint_with_response_stream(
-            EndpointName=model,
-            ContentType="application/json",
-            Body=data,
-            CustomAttributes="accept_eula=true",
-        )
+
+        if model_id is not None:
+            response = client.invoke_endpoint_with_response_stream(
+                EndpointName=model,
+                InferenceComponentName=model_id,
+                ContentType="application/json",
+                Body=data,
+                CustomAttributes="accept_eula=true",
+            )
+        else:
+            response = client.invoke_endpoint_with_response_stream(
+                EndpointName=model,
+                ContentType="application/json",
+                Body=data,
+                CustomAttributes="accept_eula=true",
+            )
         return response["Body"]
     elif acompletion == True:
         _data = {"inputs": prompt, "parameters": inference_params}
@@ -265,31 +278,36 @@ def completion(
            model=model,
            logging_obj=logging_obj,
            data=_data,
+           model_id=model_id,
+           aws_secret_access_key=aws_secret_access_key,
+           aws_access_key_id=aws_access_key_id,
+           aws_region_name=aws_region_name,
        )
    data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode(
        "utf-8"
    )
-   ## LOGGING
-   request_str = f"""
-   response = client.invoke_endpoint(
-       EndpointName={model},
-       ContentType="application/json",
-       Body={data},
-       CustomAttributes="accept_eula=true",
-   )
-   """  # type: ignore
-   logging_obj.pre_call(
-       input=prompt,
-       api_key="",
-       additional_args={
-           "complete_input_dict": data,
-           "request_str": request_str,
-           "hf_model_name": hf_model_name,
-       },
-   )
    ## COMPLETION CALL
    try:
+       if model_id is not None:
+           ## LOGGING
+           request_str = f"""
+           response = client.invoke_endpoint(
+               EndpointName={model},
+               InferenceComponentName={model_id},
+               ContentType="application/json",
+               Body={data},
+               CustomAttributes="accept_eula=true",
+           )
+           """  # type: ignore
+           logging_obj.pre_call(
+               input=prompt,
+               api_key="",
+               additional_args={
+                   "complete_input_dict": data,
+                   "request_str": request_str,
+                   "hf_model_name": hf_model_name,
+               },
+           )
            response = client.invoke_endpoint(
                EndpointName=model,
+               InferenceComponentName=model_id,
@@ -298,6 +316,24 @@ def completion(
                CustomAttributes="accept_eula=true",
            )
+       else:
+           ## LOGGING
+           request_str = f"""
+           response = client.invoke_endpoint(
+               EndpointName={model},
+               ContentType="application/json",
+               Body={data},
+               CustomAttributes="accept_eula=true",
+           )
+           """  # type: ignore
+           logging_obj.pre_call(
+               input=prompt,
+               api_key="",
+               additional_args={
+                   "complete_input_dict": data,
+                   "request_str": request_str,
+                   "hf_model_name": hf_model_name,
+               },
+           )
            response = client.invoke_endpoint(
                EndpointName=model,
                ContentType="application/json",
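The two hunks above are the sync non-streaming path: the same branch structure applied to `invoke_endpoint`, whose body is read in one piece rather than iterated. A self-contained sketch (placeholder names again):

    import json
    import boto3

    client = boto3.client("sagemaker-runtime", region_name="us-west-2")  # placeholder region
    body = json.dumps({"inputs": "Hello", "parameters": {"max_new_tokens": 32}}).encode("utf-8")

    response = client.invoke_endpoint(
        EndpointName="my-endpoint",  # placeholder
        ContentType="application/json",
        Body=body,
        CustomAttributes="accept_eula=true",
    )
    completion = json.loads(response["Body"].read().decode("utf-8"))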
@@ -369,8 +405,12 @@ async def async_streaming(
    encoding,
    model_response: ModelResponse,
    model: str,
+   model_id: Optional[str],
    logging_obj: Any,
    data,
+   aws_secret_access_key: Optional[str],
+   aws_access_key_id: Optional[str],
+   aws_region_name: Optional[str],
 ):
    """
    Use aioboto3
@@ -379,11 +419,6 @@ async def async_streaming(

    session = aioboto3.Session()

-   # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
-   aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
-   aws_access_key_id = optional_params.pop("aws_access_key_id", None)
-   aws_region_name = optional_params.pop("aws_region_name", None)
-
    if aws_access_key_id != None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
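With credentials passed in explicitly, the async helpers can hand them straight to aioboto3 instead of fishing them out of `optional_params`. A sketch of building the async client (values are placeholders; aioboto3's `session.client(...)` is an async context manager):

    import aioboto3

    async def invoke(data: bytes) -> bytes:
        session = aioboto3.Session()
        async with session.client(
            "sagemaker-runtime",
            region_name="us-west-2",          # placeholder
            aws_access_key_id="AKIA...",      # placeholder
            aws_secret_access_key="...",      # placeholder
        ) as client:
            response = await client.invoke_endpoint(
                EndpointName="my-endpoint",   # placeholder
                ContentType="application/json",
                Body=data,
                CustomAttributes="accept_eula=true",
            )
            return await response["Body"].read()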
@@ -410,12 +445,21 @@ async def async_streaming(

    async with _client as client:
        try:
-           response = await client.invoke_endpoint_with_response_stream(
-               EndpointName=model,
-               ContentType="application/json",
-               Body=data,
-               CustomAttributes="accept_eula=true",
-           )
+           if model_id is not None:
+               response = await client.invoke_endpoint_with_response_stream(
+                   EndpointName=model,
+                   InferenceComponentName=model_id,
+                   ContentType="application/json",
+                   Body=data,
+                   CustomAttributes="accept_eula=true",
+               )
+           else:
+               response = await client.invoke_endpoint_with_response_stream(
+                   EndpointName=model,
+                   ContentType="application/json",
+                   Body=data,
+                   CustomAttributes="accept_eula=true",
+               )
        except Exception as e:
            raise SagemakerError(status_code=500, message=f"{str(e)}")
        response = response["Body"]
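The awaited stream is an async event stream; each event carries a `PayloadPart` dict with raw bytes and is consumed with `async for`. A sketch of how a caller might drain the `response` returned by the branch above:

    async def drain(response) -> bytes:
        chunks = []
        async for event in response["Body"]:
            part = event.get("PayloadPart")
            if part:
                chunks.append(part["Bytes"])
        return b"".join(chunks)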
@@ -430,6 +474,10 @@ async def async_completion(
    model: str,
    logging_obj: Any,
    data: dict,
+   model_id: Optional[str],
+   aws_secret_access_key: Optional[str],
+   aws_access_key_id: Optional[str],
+   aws_region_name: Optional[str],
 ):
    """
    Use aioboto3
@@ -438,11 +486,6 @@ async def async_completion(

    session = aioboto3.Session()

-   # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
-   aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
-   aws_access_key_id = optional_params.pop("aws_access_key_id", None)
-   aws_region_name = optional_params.pop("aws_region_name", None)
-
    if aws_access_key_id != None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
@@ -468,33 +511,63 @@ async def async_completion(
    )

    async with _client as client:
-       ## LOGGING
-       request_str = f"""
-       response = client.invoke_endpoint(
-           EndpointName={model},
-           ContentType="application/json",
-           Body={data},
-           CustomAttributes="accept_eula=true",
-       )
-       """  # type: ignore
-       logging_obj.pre_call(
-           input=data["inputs"],
-           api_key="",
-           additional_args={
-               "complete_input_dict": data,
-               "request_str": request_str,
-           },
-       )
        encoded_data = json.dumps(data).encode("utf-8")
        try:
-           response = await client.invoke_endpoint(
-               EndpointName=model,
-               ContentType="application/json",
-               Body=encoded_data,
-               CustomAttributes="accept_eula=true",
-           )
+           if model_id is not None:
+               ## LOGGING
+               request_str = f"""
+               response = client.invoke_endpoint(
+                   EndpointName={model},
+                   InferenceComponentName={model_id},
+                   ContentType="application/json",
+                   Body={data},
+                   CustomAttributes="accept_eula=true",
+               )
+               """  # type: ignore
+               logging_obj.pre_call(
+                   input=data["inputs"],
+                   api_key="",
+                   additional_args={
+                       "complete_input_dict": data,
+                       "request_str": request_str,
+                   },
+               )
+               response = await client.invoke_endpoint(
+                   EndpointName=model,
+                   InferenceComponentName=model_id,
+                   ContentType="application/json",
+                   Body=encoded_data,
+                   CustomAttributes="accept_eula=true",
+               )
+           else:
+               ## LOGGING
+               request_str = f"""
+               response = client.invoke_endpoint(
+                   EndpointName={model},
+                   ContentType="application/json",
+                   Body={data},
+                   CustomAttributes="accept_eula=true",
+               )
+               """  # type: ignore
+               logging_obj.pre_call(
+                   input=data["inputs"],
+                   api_key="",
+                   additional_args={
+                       "complete_input_dict": data,
+                       "request_str": request_str,
+                   },
+               )
+               response = await client.invoke_endpoint(
+                   EndpointName=model,
+                   ContentType="application/json",
+                   Body=encoded_data,
+                   CustomAttributes="accept_eula=true",
+               )
        except Exception as e:
-           raise SagemakerError(status_code=500, message=f"{str(e)}")
+           error_message = f"{str(e)}"
+           if "Inference Component Name header is required" in error_message:
+               error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`"
+           raise SagemakerError(status_code=500, message=error_message)
        response = await response["Body"].read()
        response = response.decode("utf8")
        ## LOGGING
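After `read()` and `decode`, the payload is plain JSON whose shape depends on the serving container; TGI-style containers return a list with a `generated_text` field. A hedged sketch of the parse step that follows the trailing `## LOGGING` marker (the field name is an assumption about the container, not part of this diff):

    import json

    raw = '[{"generated_text": "Hello! How can I help?"}]'  # example payload shape
    parsed = json.loads(raw)
    completion_text = parsed[0]["generated_text"]           # assumes a TGI-style container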