mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

feat(sagemaker.py): support huggingface embedding models

This commit is contained in:
parent aefa4f36f9
commit 94f065f83c

3 changed files with 150 additions and 5 deletions
@@ -196,6 +196,133 @@ def completion(
     model_response.usage = usage
     return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+def embedding(model: str,
+              input: list,
+              model_response: ModelResponse,
+              print_verbose: Callable,
+              encoding,
+              logging_obj,
+              custom_prompt_dict={},
+              optional_params=None,
+              litellm_params=None,
+              logger_fn=None):
+    """
+    Supports Huggingface Jumpstart embeddings like GPT-J 6B
+    """
+    ### BOTO3 INIT
+    import boto3
+    # pop aws_secret_access_key, aws_access_key_id, aws_region_name from optional_params,
+    # since the endpoint call fails if they are passed through as inference params
+    aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
+    aws_access_key_id = optional_params.pop("aws_access_key_id", None)
+    aws_region_name = optional_params.pop("aws_region_name", None)
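These pops matter because anything left in optional_params is forwarded to the endpoint as an inference parameter. A minimal, runnable sketch of the behavior, with made-up values:

# illustration only -- placeholder credentials, not real ones
optional_params = {
    "aws_access_key_id": "AKIA-PLACEHOLDER",
    "aws_secret_access_key": "SECRET-PLACEHOLDER",
    "aws_region_name": "us-west-2",
    "top_k": 3,  # a genuine inference param stays behind
}

aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)

assert optional_params == {"top_k": 3}  # only inference params remain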
+
+    if aws_access_key_id is not None:
+        # use the auth params passed to completion
+        # aws_access_key_id is not None, so assume the user is authenticating via litellm.completion
+        client = boto3.client(
+            service_name="sagemaker-runtime",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=aws_region_name,
+        )
+    else:
+        # aws_access_key_id is None, so assume the user is authenticating via env variables;
+        # boto3 reads env variables automatically
+
+        # we still need to read the region name from the env,
+        # since the majority of users are assumed to use .env for auth
+        region_name = (
+            get_secret("AWS_REGION_NAME") or
+            "us-west-2"  # default to us-west-2 if the user did not specify a region
+        )
+        client = boto3.client(
+            service_name="sagemaker-runtime",
+            region_name=region_name,
+        )
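For reference, here are the same two auth paths in an isolated, runnable form; make_sagemaker_client is a hypothetical helper name, and the us-west-2 fallback mirrors the default above:

import os
import boto3

def make_sagemaker_client(aws_access_key_id=None,
                          aws_secret_access_key=None,
                          aws_region_name=None):
    # hypothetical helper mirroring the branch above
    if aws_access_key_id is not None:
        # explicit keys passed by the caller win
        return boto3.client(
            service_name="sagemaker-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region_name,
        )
    # otherwise boto3 reads credentials from the env / ~/.aws;
    # only the region still needs to be resolved by hand
    region_name = os.environ.get("AWS_REGION_NAME") or "us-west-2"
    return boto3.client(service_name="sagemaker-runtime", region_name=region_name)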
+
+    # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
+    inference_params = deepcopy(optional_params)
+    inference_params.pop("stream", None)
+
+    ## Load Config
+    config = litellm.SagemakerConfig.get_config()
+    for k, v in config.items():
+        if k not in inference_params:  # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in
+            inference_params[k] = v
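The merge rule gives explicitly passed params precedence over SagemakerConfig defaults, since a key already present in inference_params is never overwritten. A self-contained illustration with made-up values:

config = {"top_k": 3, "max_new_tokens": 50}   # made-up SagemakerConfig defaults
inference_params = {"top_k": 7}               # caller passed top_k explicitly

for k, v in config.items():
    if k not in inference_params:  # caller's value wins
        inference_params[k] = v

assert inference_params == {"top_k": 7, "max_new_tokens": 50}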
+
+    #### HF EMBEDDING LOGIC
+    data = json.dumps({
+        "text_inputs": input
+    }).encode('utf-8')
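HuggingFace Jumpstart embedding containers expect a JSON body whose text_inputs key holds the batch of strings; for a two-item input the serialized payload looks like this (example strings made up):

import json

data = json.dumps({"text_inputs": ["good morning", "another item"]}).encode("utf-8")
print(data)  # b'{"text_inputs": ["good morning", "another item"]}'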
+
+    ## LOGGING
+    request_str = f"""
+    response = client.invoke_endpoint(
+        EndpointName={model},
+        ContentType="application/json",
+        Body={data},
+        CustomAttributes="accept_eula=true",
+    )"""
+    logging_obj.pre_call(
+        input=input,
+        api_key="",
+        additional_args={"complete_input_dict": data, "request_str": request_str},
+    )
+    ## EMBEDDING CALL
+    try:
+        response = client.invoke_endpoint(
+            EndpointName=model,
+            ContentType="application/json",
+            Body=data,
+            CustomAttributes="accept_eula=true",
+        )
+    except Exception as e:
+        raise SagemakerError(status_code=500, message=f"{str(e)}")
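SagemakerError is defined earlier in sagemaker.py; a minimal stand-in consistent with how it is constructed here, should you want to run the surrounding logic in isolation:

class SagemakerError(Exception):
    # minimal stand-in for the exception class defined elsewhere in sagemaker.py
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(self.message)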
+
+    ## LOGGING (raw response)
+    logging_obj.post_call(
+        input=input,
+        api_key="",
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    response = json.loads(response["Body"].read().decode("utf8"))
+    ## LOGGING (decoded response)
+    logging_obj.post_call(
+        input=input,
+        api_key="",
+        original_response=response,
+        additional_args={"complete_input_dict": data},
+    )
+    print_verbose(f"raw model_response: {response}")
+    if "embedding" not in response:
+        raise SagemakerError(status_code=500, message="embedding not found in response")
+    embeddings = response['embedding']
+
+    if not isinstance(embeddings, list):
+        raise SagemakerError(status_code=422, message=f"Response not in expected format - {embeddings}")
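The endpoint responds with a JSON object whose embedding key holds one vector per input string; a sketch of the decode-and-validate step against a canned body (the vectors are made up):

import json

raw_body = b'{"embedding": [[0.1, 0.2], [0.3, 0.4]]}'  # stand-in for response["Body"].read()
decoded = json.loads(raw_body.decode("utf8"))

assert "embedding" in decoded
embeddings = decoded["embedding"]
assert isinstance(embeddings, list)
assert len(embeddings) == 2  # one vector per input string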
+
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding
+            }
+        )
+
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    model_response["usage"] = Usage(prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens)
+
+    return model_response
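Prompt tokens are summed over every input string with the caller-supplied encoding; a sketch with tiktoken as a stand-in encoder (litellm passes its own encoding object in practice):

import tiktoken  # stand-in encoder; install with pip install tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
inputs = ["good morning from litellm", "this is another item"]

input_tokens = sum(len(encoding.encode(text)) for text in inputs)
# usage: prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens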

@@ -1906,6 +1906,16 @@ def embedding(
             optional_params=kwargs,
             model_response= EmbeddingResponse()
         )
+    elif custom_llm_provider == "sagemaker":
+        response = sagemaker.embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=kwargs,
+            model_response=EmbeddingResponse(),
+            print_verbose=print_verbose
+        )
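With this branch in the dispatcher, callers reach the new SageMaker path by prefixing the endpoint name with sagemaker/. A usage sketch (the endpoint name is taken from the tests below; valid AWS credentials in the env are assumed, and dict-style access mirrors how the diff itself populates the response):

import litellm

response = litellm.embedding(
    model="sagemaker/berri-benchmarking-gpt-j-6b-fp16",
    input=["good morning from litellm", "this is another item"],
)
print(len(response["data"]))  # one embedding per input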
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")

@@ -214,7 +214,14 @@ def test_aembedding_azure():
 # test_aembedding_azure()
 
-# def test_custom_openai_embedding():
+def test_sagemaker_embeddings():
+    try:
+        response = litellm.embedding(model="sagemaker/berri-benchmarking-gpt-j-6b-fp16", input=["good morning from litellm", "this is another item"])
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_sagemaker_embeddings()
+
+# def local_proxy_embeddings():
 #     litellm.set_verbose=True
 #     response = embedding(
 #         model="openai/custom_embedding",
@@ -222,4 +229,5 @@ def test_aembedding_azure():
 #         api_base="http://0.0.0.0:8000/"
 #     )
 #     print(response)
-# test_custom_openai_embedding()
+
+# local_proxy_embeddings()