From 94f065f83c20f5ca2f9fb41c16f62a23d75c024d Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 6 Dec 2023 11:41:00 -0800
Subject: [PATCH] feat(sagemaker.py): support huggingface embedding models

---
 litellm/llms/sagemaker.py       | 124 +++++++++++++++++++++++++++++++-
 litellm/main.py                 |  10 +++
 litellm/tests/test_embedding.py |  12 ++-
 3 files changed, 141 insertions(+), 5 deletions(-)

diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index cb5b56bddf..36324286b3 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -196,6 +196,124 @@ def completion(
     model_response.usage = usage
     return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+def embedding(model: str,
+              input: list,
+              model_response: ModelResponse,
+              print_verbose: Callable,
+              encoding,
+              logging_obj,
+              custom_prompt_dict={},
+              optional_params=None,
+              litellm_params=None,
+              logger_fn=None):
+    """
+    Supports Hugging Face JumpStart embedding models, e.g. GPT-J 6B
+    """
+    ### BOTO3 INIT
+    import boto3
+    # pop aws_secret_access_key, aws_access_key_id and aws_region_name from optional_params, since the endpoint call fails if they are passed through
+    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+    aws_region_name = optional_params.pop("aws_region_name", None)
+
+    if aws_access_key_id is not None:
+        # aws_access_key_id is not None, assume the user passed auth params
+        # directly to litellm.embedding
+        client = boto3.client(
+            service_name="sagemaker-runtime",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=aws_region_name,
+        )
+    else:
+        # aws_access_key_id is None, assume the user is authenticating via env
+        # variables, which boto3 reads automatically
+
+        # we still need to read the region name from the env,
+        # since most users keep their auth config in .env
+        region_name = (
+            get_secret("AWS_REGION_NAME") or
+            "us-west-2"  # default to us-west-2 if the user did not set a region
+        )
+        client = boto3.client(
+            service_name="sagemaker-runtime",
+            region_name=region_name,
+        )
+
+    # pop 'stream' if it's in the optional params, since 'stream' raises an error with sagemaker
+    inference_params = deepcopy(optional_params)
+    inference_params.pop("stream", None)
+
+    ## Load Config
+    config = litellm.SagemakerConfig.get_config()
+    for k, v in config.items():
+        if k not in inference_params:  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
+            inference_params[k] = v
+
+    #### HF EMBEDDING LOGIC
+    data = json.dumps({
+        "text_inputs": input
+    }).encode('utf-8')
+
+    ## LOGGING
+    request_str = f"""
+    response = client.invoke_endpoint(
+        EndpointName={model},
+        ContentType="application/json",
+        Body={data},
+        CustomAttributes="accept_eula=true",
+    )"""
+    logging_obj.pre_call(
+        input=input,
+        api_key="",
+        additional_args={"complete_input_dict": data, "request_str": request_str},
+    )
+    ## EMBEDDING CALL
+    try:
+        response = client.invoke_endpoint(
+            EndpointName=model,
+            ContentType="application/json",
+            Body=data,
+            CustomAttributes="accept_eula=true",
+        )
+    except Exception as e:
+        raise SagemakerError(status_code=500, message=f"{str(e)}")
+
+    response = json.loads(response["Body"].read().decode("utf8"))
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key="",
+        original_response=response,
+        additional_args={"complete_input_dict": data},
+    )
+    print_verbose(f"raw model_response: {response}")
+
+    if "embedding" not in response:
+        raise SagemakerError(status_code=500, message="embedding not found in response")
+    embeddings = response['embedding']
+
+    if not isinstance(embeddings, list):
+        raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
+
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding
+            }
+        )
+
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
+
+    return model_response
diff --git a/litellm/main.py b/litellm/main.py
index f265d4653b..d78c8ceb9c 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1906,6 +1906,16 @@ def embedding(
             optional_params=kwargs,
             model_response= EmbeddingResponse()
         )
+    elif custom_llm_provider == "sagemaker":
+        response = sagemaker.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=kwargs,
+            model_response= EmbeddingResponse(),
+            print_verbose=print_verbose
+        )
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index 628d2ed047..f958d6cfc4 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -214,7 +214,14 @@ def test_aembedding_azure():
 
 # test_aembedding_azure()
 
-# def test_custom_openai_embedding():
+def test_sagemaker_embeddings():
+    try:
+        response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_sagemaker_embeddings()
+# def local_proxy_embeddings():
 #     litellm.set_verbose=True
 #     response = embedding(
 #         model="openai/custom_embedding",
@@ -222,4 +229,5 @@ def test_aembedding_azure():
 #         api_base="http://0.0.0.0:8000/"
 #     )
 #     print(response)
-# test_custom_openai_embedding()
+
+# local_proxy_embeddings()
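
Reviewer note - a minimal usage sketch, not part of the patch. It assumes a
text-embedding endpoint deployed through SageMaker JumpStart; the endpoint
name below is a placeholder. The credential/region kwargs are optional -
the new embedding() pops aws_access_key_id, aws_secret_access_key and
aws_region_name out of the optional params before the boto3 call, so they
can also come from env variables:

    import litellm

    response = litellm.embedding(
        model="sagemaker/my-jumpstart-embedding-endpoint",  # placeholder endpoint name
        input=["good morning from litellm"],
        aws_region_name="us-west-2",
    )
    # OpenAI-style response: one {"object", "index", "embedding"} dict per input
    print(len(response["data"][0]["embedding"]))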
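
The parsing logic in the new embedding() assumes the Hugging Face JumpStart
contract for text-embedding containers: the request body is JSON with a
"text_inputs" list, and the reply is a JSON object whose "embedding" key
holds one vector per input. A raw boto3 sketch of the same round trip, under
the same assumptions (endpoint name again a placeholder):

    import json

    import boto3

    client = boto3.client(service_name="sagemaker-runtime", region_name="us-west-2")
    body = json.dumps({"text_inputs": ["good morning from litellm"]}).encode("utf-8")
    raw = client.invoke_endpoint(
        EndpointName="my-jumpstart-embedding-endpoint",  # placeholder
        ContentType="application/json",
        Body=body,
        CustomAttributes="accept_eula=true",
    )
    parsed = json.loads(raw["Body"].read().decode("utf8"))
    vectors = parsed["embedding"]  # expected: one list of floats per input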