feat(sagemaker.py): support huggingface embedding models

Krrish Dholakia 2023-12-06 11:41:00 -08:00
parent aefa4f36f9
commit 94f065f83c
3 changed files with 150 additions and 5 deletions

litellm/llms/sagemaker.py

@@ -196,6 +196,133 @@ def completion(
    model_response.usage = usage
    return model_response

def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
def embedding(model: str,
              input: list,
              model_response: ModelResponse,
              print_verbose: Callable,
              encoding,
              logging_obj,
              custom_prompt_dict={},
              optional_params=None,
              litellm_params=None,
              logger_fn=None):
    """
    Supports Hugging Face SageMaker JumpStart embedding models, e.g. GPT-J 6B
    """
    ### BOTO3 INIT
    import boto3

    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from the optional params,
    # since the endpoint call fails if they are passed through
    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
    aws_region_name = optional_params.pop("aws_region_name", None)

    if aws_access_key_id is not None:
        # aws_access_key_id is not None, assume the user is passing auth params directly to litellm.embedding
        client = boto3.client(
            service_name="sagemaker-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region_name,
        )
    else:
        # aws_access_key_id is None, assume the user is authing via env variables,
        # which boto3 reads automatically; we only need to read the region name from env,
        # since the majority of users auth via .env
        region_name = (
            get_secret("AWS_REGION_NAME")
            or "us-west-2"  # default to us-west-2 if the user did not specify one
        )
        client = boto3.client(
            service_name="sagemaker-runtime",
            region_name=region_name,
        )
    # pop "stream" if it's in the optional params, as 'stream' raises an error with sagemaker
    inference_params = deepcopy(optional_params)
    inference_params.pop("stream", None)

    ## Load Config
    config = litellm.SagemakerConfig.get_config()
    for k, v in config.items():
        if k not in inference_params:  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
            inference_params[k] = v
    #### HF EMBEDDING LOGIC
    data = json.dumps({
        "text_inputs": input
    }).encode('utf-8')

    ## LOGGING
    request_str = f"""
    response = client.invoke_endpoint(
        EndpointName={model},
        ContentType="application/json",
        Body={data},
        CustomAttributes="accept_eula=true",
    )"""
    logging_obj.pre_call(
        input=input,
        api_key="",
        additional_args={"complete_input_dict": data, "request_str": request_str},
    )
    ## EMBEDDING CALL
    try:
        response = client.invoke_endpoint(
            EndpointName=model,
            ContentType="application/json",
            Body=data,
            CustomAttributes="accept_eula=true",
        )
    except Exception as e:
        raise SagemakerError(status_code=500, message=f"{str(e)}")

    ## LOGGING - log the raw, undecoded boto3 response
    logging_obj.post_call(
        input=input,
        api_key="",
        additional_args={"complete_input_dict": data},
        original_response=response,
    )
    response = json.loads(response["Body"].read().decode("utf8"))

    ## LOGGING - log the decoded JSON response
    logging_obj.post_call(
        input=input,
        api_key="",
        original_response=response,
        additional_args={"complete_input_dict": data},
    )
    print_verbose(f"raw model_response: {response}")

    if "embedding" not in response:
        raise SagemakerError(status_code=500, message="embedding not found in response")
    embeddings = response['embedding']
    if not isinstance(embeddings, list):
        raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
    output_data = []
    for idx, embedding in enumerate(embeddings):
        output_data.append(
            {
                "object": "embedding",
                "index": idx,
                "embedding": embedding
            }
        )

    model_response["object"] = "list"
    model_response["data"] = output_data
    model_response["model"] = model

    input_tokens = 0
    for text in input:
        input_tokens += len(encoding.encode(text))

    model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
    return model_response
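
For orientation, here is a minimal standalone sketch of the request/response contract this function wraps, assuming a deployed JumpStart GPT-J 6B FP16 embedding endpoint; the endpoint name below is hypothetical:

import json
import boto3

# Plain boto3 call mirroring the invoke_endpoint call above.
client = boto3.client(service_name="sagemaker-runtime", region_name="us-west-2")
payload = json.dumps({"text_inputs": ["good morning from litellm"]}).encode("utf-8")
response = client.invoke_endpoint(
    EndpointName="my-gpt-j-6b-fp16-endpoint",  # hypothetical endpoint name
    ContentType="application/json",
    Body=payload,
    CustomAttributes="accept_eula=true",
)
result = json.loads(response["Body"].read().decode("utf8"))
# JumpStart embedding models respond with {"embedding": [[...], ...]},
# one vector per input string -- exactly what the parsing logic above expects.
print(len(result["embedding"]))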

litellm/main.py

@@ -1906,6 +1906,16 @@ def embedding(
            optional_params=kwargs,
            model_response=EmbeddingResponse()
        )
    elif custom_llm_provider == "sagemaker":
        response = sagemaker.embedding(
            model=model,
            input=input,
            encoding=encoding,
            logging_obj=logging,
            optional_params=kwargs,
            model_response=EmbeddingResponse(),
            print_verbose=print_verbose
        )
    else:
        args = locals()
        raise ValueError(f"No valid embedding model args passed in - {args}")
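
Because sagemaker.embedding pops aws_access_key_id, aws_secret_access_key, and aws_region_name out of the optional params before the boto3 call, credentials can be supplied per call instead of via environment variables. A sketch, with placeholder credentials and a hypothetical endpoint name:

import litellm

# These aws_* kwargs flow through optional_params=kwargs and are popped
# before the SageMaker runtime call. Credential values are placeholders.
response = litellm.embedding(
    model="sagemaker/my-gpt-j-6b-fp16-endpoint",  # hypothetical endpoint name
    input=["good morning from litellm"],
    aws_access_key_id="AKIA...",      # placeholder
    aws_secret_access_key="wJalr...", # placeholder
    aws_region_name="us-west-2",
)

If none of these are passed, boto3 falls back to its standard credential chain, with the region read from AWS_REGION_NAME and defaulting to us-west-2, as in the sagemaker.py change above.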

litellm/tests/test_embedding.py

@@ -214,7 +214,14 @@ def test_aembedding_azure():
# test_aembedding_azure()
# def test_custom_openai_embedding():
def test_sagemaker_embeddings():
    try:
        response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

test_sagemaker_embeddings()
# def local_proxy_embeddings():
#     litellm.set_verbose=True
#     response = embedding(
#         model="openai/custom_embedding",
@@ -222,4 +229,5 @@ def test_aembedding_azure():
#         api_base="http://0.0.0.0:8000/"
#     )
#     print(response)
# test_custom_openai_embedding()
# local_proxy_embeddings()
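
For reference, the response object the new test receives follows the OpenAI embedding response shape assembled in sagemaker.py above; a sketch with made-up values (usage shown as a plain dict for readability):

# Illustrative shape of the returned response (values are made up):
expected_shape = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.0123, -0.0456]},
        {"object": "embedding", "index": 1, "embedding": [0.0789, 0.0101]},
    ],
    "model": "berri-benchmarking-gpt-j-6b-fp16",
    "usage": {"prompt_tokens": 12, "completion_tokens": 0, "total_tokens": 12},
}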