mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(huggingface_restapi.py): fix embeddings for sentence-transformer models
This commit is contained in:
parent
435f0809b2
commit
80cb421e02
2 changed files with 50 additions and 17 deletions
|
@ -375,6 +375,16 @@ def embedding(
|
|||
else:
|
||||
embed_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||
|
||||
if "sentence-transformers" in model:
|
||||
if len(input) == 0:
|
||||
raise HuggingfaceError(status_code=400, message="sentence transformers requires 2+ sentences")
|
||||
data = {
|
||||
"inputs": {
|
||||
"source_sentence": input[0],
|
||||
"sentences": [ "That is a happy dog", "That is a very happy person", "Today is a sunny day" ]
|
||||
}
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"inputs": input
|
||||
}
|
||||
|
@ -402,8 +412,31 @@ def embedding(
|
|||
|
||||
embeddings = response.json()
|
||||
|
||||
if "error" in embeddings:
|
||||
raise HuggingfaceError(status_code=500, message=embeddings['error'])
|
||||
|
||||
output_data = []
|
||||
print(f"embeddings: {embeddings}")
|
||||
if "similarities" in embeddings:
|
||||
for idx, embedding in embeddings["similarities"]:
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
if isinstance(embedding, float):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
|
|
|
@ -89,17 +89,17 @@ def test_bedrock_embedding():
|
|||
# test_bedrock_embedding()
|
||||
|
||||
# comment out hf tests - since hf endpoints are unstable
|
||||
# def test_hf_embedding():
|
||||
# try:
|
||||
# # huggingface/microsoft/codebert-base
|
||||
# # huggingface/facebook/bart-large
|
||||
# response = embedding(
|
||||
# model="huggingface/BAAI/bge-large-zh", input=["good morning from litellm", "this is another item"]
|
||||
# )
|
||||
# print(f"response:", response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test_hf_embedding()
|
||||
def test_hf_embedding():
|
||||
try:
|
||||
# huggingface/microsoft/codebert-base
|
||||
# huggingface/facebook/bart-large
|
||||
response = embedding(
|
||||
model="huggingface/sentence-transformers/all-MiniLM-L6-v2", input=["good morning from litellm", "this is another item"]
|
||||
)
|
||||
print(f"response:", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
test_hf_embedding()
|
||||
|
||||
# test async embeddings
|
||||
def test_aembedding():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue