Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
feat(sagemaker.py): support huggingface embedding models

commit 94f065f83c (parent aefa4f36f9)
3 changed files with 150 additions and 5 deletions
litellm/llms/sagemaker.py

@@ -196,6 +196,133 @@ def completion(
     model_response.usage = usage
     return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+def embedding(model: str,
+              input: list,
+              model_response: ModelResponse,
+              print_verbose: Callable,
+              encoding,
+              logging_obj,
+              custom_prompt_dict={},
+              optional_params=None,
+              litellm_params=None,
+              logger_fn=None):
+    """
+    Supports Huggingface Jumpstart embeddings like GPT-J 6B
+    """
+    ### BOTO3 INIT
+    import boto3
+    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
+    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+    aws_region_name = optional_params.pop("aws_region_name", None)
+
+    if aws_access_key_id is not None:
+        # aws_access_key_id is not None, so assume the user is trying to auth
+        # with params passed to litellm.completion
+        client = boto3.client(
+            service_name="sagemaker-runtime",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=aws_region_name,
+        )
+    else:
+        # aws_access_key_id is None, so assume the user is trying to auth using env variables
+        # boto3 automatically reads env variables
+
+        # we need to read the region name from env
+        # I assume the majority of users use .env for auth
+        region_name = (
+            get_secret("AWS_REGION_NAME") or
+            "us-west-2"  # default to us-west-2 if the user has not specified one
+        )
+        client = boto3.client(
+            service_name="sagemaker-runtime",
+            region_name=region_name,
+        )
+
+    # pop streaming if it's in the optional params, as 'stream' raises an error with sagemaker
+    inference_params = deepcopy(optional_params)
+    inference_params.pop("stream", None)
+
+    ## Load Config
+    config = litellm.SagemakerConfig.get_config()
+    for k, v in config.items():
+        if k not in inference_params:  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
+            inference_params[k] = v
+
+    #### HF EMBEDDING LOGIC
+    data = json.dumps({
+        "text_inputs": input
+    }).encode('utf-8')
+
+    ## LOGGING
+    request_str = f"""
+    response = client.invoke_endpoint(
+        EndpointName={model},
+        ContentType="application/json",
+        Body={data},
+        CustomAttributes="accept_eula=true",
+    )"""
+    logging_obj.pre_call(
+        input=input,
+        api_key="",
+        additional_args={"complete_input_dict": data, "request_str": request_str},
+    )
+    ## EMBEDDING CALL
+    try:
+        response = client.invoke_endpoint(
+            EndpointName=model,
+            ContentType="application/json",
+            Body=data,
+            CustomAttributes="accept_eula=true",
+        )
+    except Exception as e:
+        raise SagemakerError(status_code=500, message=f"{str(e)}")
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key="",
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    response = json.loads(response["Body"].read().decode("utf8"))
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key="",
+        original_response=response,
+        additional_args={"complete_input_dict": data},
+    )
+    print_verbose(f"raw model_response: {response}")
+    if "embedding" not in response:
+        raise SagemakerError(status_code=500, message="embedding not found in response")
+    embeddings = response['embedding']
+
+    if not isinstance(embeddings, list):
+        raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
+
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding
+            }
+        )
+
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
+
+    return model_response
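The new embedding() relies on names already imported at module level in sagemaker.py (json, deepcopy, litellm, get_secret, SagemakerError, Usage), so the hunk reads complete in context. The wire contract it implements is small enough to show standalone. Below is a minimal sketch of that request/response shape, assuming a deployed SageMaker JumpStart GPT-J 6B endpoint; the endpoint name "my-gptj-embedding" is hypothetical.

import json

import boto3

# Same client setup the new code uses when no explicit credentials are passed:
# boto3 reads them from the environment.
client = boto3.client(service_name="sagemaker-runtime", region_name="us-west-2")

# JumpStart text-embedding containers take a JSON body of {"text_inputs": [...]}.
data = json.dumps({"text_inputs": ["good morning from litellm"]}).encode("utf-8")

response = client.invoke_endpoint(
    EndpointName="my-gptj-embedding",  # hypothetical endpoint name
    ContentType="application/json",
    Body=data,
    CustomAttributes="accept_eula=true",
)

# The container replies with {"embedding": [[...], ...]}, one vector per input;
# this is the key sagemaker.embedding() validates before building output_data.
body = json.loads(response["Body"].read().decode("utf8"))
vectors = body["embedding"]
print(len(vectors), len(vectors[0]))  # GPT-J 6B vectors are 4096-dimensional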
litellm/main.py

@@ -1906,6 +1906,16 @@ def embedding(
             optional_params=kwargs,
             model_response=EmbeddingResponse()
         )
+    elif custom_llm_provider == "sagemaker":
+        response = sagemaker.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=kwargs,
+            model_response=EmbeddingResponse(),
+            print_verbose=print_verbose
+        )
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")
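From the caller's side, the new branch means any model prefixed with "sagemaker/" routes to sagemaker.embedding(). A hedged usage sketch follows: the model name is the one used in the tests below, and the aws_* kwargs (all placeholder values here) are exactly the ones the new code pops off optional_params before building its boto3 client.

import litellm

response = litellm.embedding(
    model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
    input=["good morning from litellm", "this is another item"],
    # Optional explicit auth; omit these to fall back to boto3's env-variable
    # auth, with the region read from AWS_REGION_NAME (default "us-west-2").
    aws_access_key_id="my-access-key-id",    # placeholder
    aws_secret_access_key="my-secret-key",   # placeholder
    aws_region_name="us-west-2",
)
print(response["model"], len(response["data"]))  # one embedding dict per input string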
litellm/tests/test_embedding.py

@@ -214,7 +214,14 @@ def test_aembedding_azure():
 # test_aembedding_azure()
 
-# def test_custom_openai_embedding():
+def test_sagemaker_embeddings():
+    try:
+        response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_sagemaker_embeddings()
+# def local_proxy_embeddings():
 
 #     litellm.set_verbose=True
 #     response = embedding(
 #         model="openai/custom_embedding",

@@ -222,4 +229,5 @@ def test_aembedding_azure():
 #         api_base="http://0.0.0.0:8000/"
 #     )
 #     print(response)
-# test_custom_openai_embedding()
+
+# local_proxy_embeddings()
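One thing the test prints but does not assert is usage. sagemaker.embedding() counts prompt tokens by running each input string through the encoding argument and pins completion_tokens at 0. A small sketch of that accounting, assuming the tiktoken cl100k_base encoding that litellm's main module passes through:

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # assumed litellm default

inputs = ["good morning from litellm", "this is another item"]
prompt_tokens = sum(len(encoding.encode(text)) for text in inputs)

# Mirrors the Usage object built at the end of sagemaker.embedding().
print({"prompt_tokens": prompt_tokens, "completion_tokens": 0, "total_tokens": prompt_tokens})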