fix(huggingface_restapi.py): fix embeddings for sentence-transformer models

Krrish Dholakia 2023-11-01 16:36:46 -07:00
parent 435f0809b2
commit 80cb421e02
2 changed files with 50 additions and 17 deletions

Changes to huggingface_restapi.py:

@@ -375,9 +375,19 @@ def embedding(
     else:
         embed_url = f"https://api-inference.huggingface.co/models/{model}"
-    data = {
-        "inputs": input
-    }
+    if "sentence-transformers" in model:
+        if len(input) == 0:
+            raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
+        data = {
+            "inputs": {
+                "source_sentence": input[0],
+                "sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
+            }
+        }
+    else:
+        data = {
+            "inputs": input
+        }
     ## LOGGING
     logging_obj.pre_call(
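
For sentence-transformers models, the new branch posts a sentence-similarity payload instead of the bare list of inputs. A minimal sketch of that request shape against the Inference API endpoint shown above (the model name, the token handling, and forwarding input[1:] as the comparison sentences are illustrative assumptions, not taken from the diff, which hardcodes example sentences):

    import os
    import requests

    # Illustrative setup; litellm derives the URL and auth headers from its own config.
    model = "sentence-transformers/all-MiniLM-L6-v2"
    embed_url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}  # assumed token variable

    input = ["good morning from litellm", "this is another item"]

    # Payload shape built by the new branch: the first item is the source sentence,
    # the rest are the sentences it is compared against.
    data = {
        "inputs": {
            "source_sentence": input[0],
            "sentences": input[1:],
        }
    }

    response = requests.post(embed_url, headers=headers, json=data)
    print(response.json())  # typically a list of similarity scores, one per comparison sentence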
@@ -402,15 +412,38 @@ def embedding(
     embeddings = response.json()
+    if "error" in embeddings:
+        raise HuggingfaceError(status_code=500, message=embeddings['error'])
     output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {
-                "object": "embedding",
-                "index": idx,
-                "embedding": embedding[0][0] # flatten list returned from hf
-            }
-        )
+    print(f"embeddings: {embeddings}")
+    if "similarities" in embeddings:
+        for idx, embedding in embeddings["similarities"]:
+            output_data.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": embedding # flatten list returned from hf
+                }
+            )
+    else:
+        for idx, embedding in enumerate(embeddings):
+            if isinstance(embedding, float):
+                output_data.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding # flatten list returned from hf
+                    }
+                )
+            else:
+                output_data.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding[0][0] # flatten list returned from hf
+                    }
+                )
     model_response["object"] = "list"
     model_response["data"] = output_data
     model_response["model"] = model

Changes to the embedding tests:

@@ -89,17 +89,17 @@ def test_bedrock_embedding():
 # test_bedrock_embedding()
 # comment out hf tests - since hf endpoints are unstable
-# def test_hf_embedding():
-#     try:
-#         # huggingface/microsoft/codebert-base
-#         # huggingface/facebook/bart-large
-#         response = embedding(
-#             model="huggingface/BAAI/bge-large-zh", input=["good morning from litellm", "this is another item"]
-#         )
-#         print(f"response:", response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# test_hf_embedding()
+def test_hf_embedding():
+    try:
+        # huggingface/microsoft/codebert-base
+        # huggingface/facebook/bart-large
+        response = embedding(
+            model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"]
+        )
+        print(f"response:", response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_hf_embedding()
 # test async embeddings
 def test_aembedding():
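
The re-enabled test exercises this path end to end through litellm's public embedding() call. A usage sketch, assuming a Hugging Face API key is configured in the environment (for example HUGGINGFACE_API_KEY) and relying on the OpenAI-style list response the handler builds above:

    import litellm

    # Model choice mirrors the updated test; the env var name above is an assumption.
    response = litellm.embedding(
        model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
        input=["good morning from litellm", "this is another item"],
    )

    print(response["object"])  # "list"
    print(response["model"])   # the model used for the request
    for item in response["data"]:
        print(item["index"], item["embedding"])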