mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(huggingface_restapi.py): fix embeddings for sentence-transformer models
This commit is contained in:
parent
ab0a29e160
commit
2c4cb76ce5
2 changed files with 50 additions and 17 deletions
|
@ -375,9 +375,19 @@ def embedding(
|
|||
else:
|
||||
embed_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||
|
||||
data = {
|
||||
"inputs": input
|
||||
}
|
||||
if "sentence-transformers" in model:
|
||||
if len(input) == 0:
|
||||
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
|
||||
data = {
|
||||
"inputs": {
|
||||
"source_sentence": input[0],
|
||||
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
|
||||
}
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"inputs": input
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -402,15 +412,38 @@ def embedding(
|
|||
|
||||
embeddings = response.json()
|
||||
|
||||
if "error" in embeddings:
|
||||
raise HuggingfaceError(status_code=500, message=embeddings['error'])
|
||||
|
||||
output_data = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
output_data.append(
|
||||
print(f"embeddings: {embeddings}")
|
||||
if "similarities" in embeddings:
|
||||
for idx, embedding in embeddings["similarities"]:
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding[0][0] # flatten list returned from hf
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
if isinstance(embedding, float):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding[0][0] # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
model_response["object"] = "list"
|
||||
model_response["data"] = output_data
|
||||
model_response["model"] = model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue