feat(sagemaker.py): support huggingface embedding models

This commit is contained in:
Krrish Dholakia 2023-12-06 11:41:00 -08:00
parent aefa4f36f9
commit 94f065f83c
3 changed files with 150 additions and 5 deletions

View file

@ -196,6 +196,133 @@ def completion(
model_response.usage = usage
return model_response
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass
def embedding(model: str,
input: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None):
"""
Supports Huggingface Jumpstart embeddings like GPT-6B
"""
### BOTO3 INIT
import boto3
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
client = boto3.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automaticaly reads env variables
# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME") or
"us-west-2" # default to us-west-2 if user not specified
)
client = boto3.client(
service_name="sagemaker-runtime",
region_name=region_name,
)
# pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
inference_params = deepcopy(optional_params)
inference_params.pop("stream", None)
## Load Config
config = litellm.SagemakerConfig.get_config()
for k, v in config.items():
if k not in inference_params: # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
#### HF EMBEDDING LOGIC
data = json.dumps({
"text_inputs": input
}).encode('utf-8')
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
EndpointName={model},
ContentType="application/json",
Body={data},
CustomAttributes="accept_eula=true",
)"""
logging_obj.pre_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data, "request_str": request_str},
)
## EMBEDDING CALL
try:
response = client.invoke_endpoint(
EndpointName=model,
ContentType="application/json",
Body=data,
CustomAttributes="accept_eula=true",
)
except Exception as e:
raise SagemakerError(status_code=500, message=f"{str(e)}")
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=response,
)
response = json.loads(response["Body"].read().decode("utf8"))
## LOGGING
logging_obj.post_call(
input=input,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response}")
if "embedding" not in response:
raise SagemakerError(status_code=500, message="embedding not found in response")
embeddings = response['embedding']
if not isinstance(embeddings, list):
raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = 0
for text in input:
input_tokens+=len(encoding.encode(text))
model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
return model_response