fix(huggingface_restapi.py): fix embeddings for sentence-transformer models

This commit is contained in:
Krrish Dholakia 2023-11-01 16:36:46 -07:00
parent ab0a29e160
commit 2c4cb76ce5
2 changed files with 50 additions and 17 deletions

View file

@@ -375,9 +375,19 @@ def embedding(
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
data = {
"inputs": input
}
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
data = {
"inputs": {
"source_sentence": input[0],
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
}
}
else:
data = {
"inputs": input
}
## LOGGING
logging_obj.pre_call(
@@ -402,15 +412,38 @@ def embedding(
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings['error'])
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
print(f"embeddings: {embeddings}")
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][0] # flatten list returned from hf
"embedding": embedding # flatten list returned from hf
}
)
else:
for idx, embedding in enumerate(embeddings):
if isinstance(embedding, float):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding # flatten list returned from hf
}
)
else:
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding[0][0] # flatten list returned from hf
}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model