mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(huggingface_restapi.py): fix embeddings for sentence-transformer models
This commit is contained in:
parent
435f0809b2
commit
80cb421e02
2 changed files with 50 additions and 17 deletions
|
@ -375,9 +375,19 @@ def embedding(
|
||||||
else:
|
else:
|
||||||
embed_url = f"https://api-inference.huggingface.co/models/{model}"
|
embed_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||||
|
|
||||||
data = {
|
if "sentence-transformers" in model:
|
||||||
"inputs": input
|
if len(input) == 0:
|
||||||
}
|
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
|
||||||
|
data = {
|
||||||
|
"inputs": {
|
||||||
|
"source_sentence": input[0],
|
||||||
|
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
data = {
|
||||||
|
"inputs": input
|
||||||
|
}
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
@ -402,15 +412,38 @@ def embedding(
|
||||||
|
|
||||||
embeddings = response.json()
|
embeddings = response.json()
|
||||||
|
|
||||||
|
if "error" in embeddings:
|
||||||
|
raise HuggingfaceError(status_code=500, message=embeddings['error'])
|
||||||
|
|
||||||
output_data = []
|
output_data = []
|
||||||
for idx, embedding in enumerate(embeddings):
|
print(f"embeddings: {embeddings}")
|
||||||
output_data.append(
|
if "similarities" in embeddings:
|
||||||
|
for idx, embedding in embeddings["similarities"]:
|
||||||
|
output_data.append(
|
||||||
{
|
{
|
||||||
"object": "embedding",
|
"object": "embedding",
|
||||||
"index": idx,
|
"index": idx,
|
||||||
"embedding": embedding[0][0] # flatten list returned from hf
|
"embedding": embedding # flatten list returned from hf
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
for idx, embedding in enumerate(embeddings):
|
||||||
|
if isinstance(embedding, float):
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_data.append(
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": idx,
|
||||||
|
"embedding": embedding[0][0] # flatten list returned from hf
|
||||||
|
}
|
||||||
|
)
|
||||||
model_response["object"] = "list"
|
model_response["object"] = "list"
|
||||||
model_response["data"] = output_data
|
model_response["data"] = output_data
|
||||||
model_response["model"] = model
|
model_response["model"] = model
|
||||||
|
|
|
@ -89,17 +89,17 @@ def test_bedrock_embedding():
|
||||||
# test_bedrock_embedding()
|
# test_bedrock_embedding()
|
||||||
|
|
||||||
# comment out hf tests - since hf endpoints are unstable
|
# comment out hf tests - since hf endpoints are unstable
|
||||||
# def test_hf_embedding():
|
def test_hf_embedding():
|
||||||
# try:
|
try:
|
||||||
# # huggingface/microsoft/codebert-base
|
# huggingface/microsoft/codebert-base
|
||||||
# # huggingface/facebook/bart-large
|
# huggingface/facebook/bart-large
|
||||||
# response = embedding(
|
response = embedding(
|
||||||
# model="huggingface/BAAI/bge-large-zh", input=["good morning from litellm", "this is another item"]
|
model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"]
|
||||||
# )
|
)
|
||||||
# print(f"response:", response)
|
print(f"response:", response)
|
||||||
# except Exception as e:
|
except Exception as e:
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
# test_hf_embedding()
|
test_hf_embedding()
|
||||||
|
|
||||||
# test async embeddings
|
# test async embeddings
|
||||||
def test_aembedding():
|
def test_aembedding():
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue