forked from phoenix/litellm-mirror
Merge pull request #3470 from mbektas/fix-ollama-embeddings
support sync ollama embeddings
commit 2725a55e7a
4 changed files with 88 additions and 12 deletions
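
Note: what the change enables, sketched below — litellm.embedding() can now serve Ollama models from plain synchronous code, where previously only the async path (ollama_aembeddings) existed. A minimal usage sketch, assuming a local Ollama server on the default port with the nomic-embed-text model pulled; the response-shape access is an assumption based on litellm's OpenAI-style EmbeddingResponse.

import litellm

# Sync call; "ollama/..." embedding models no longer require an event loop.
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
)
# Assuming the usual OpenAI-style shape, the vector lives in data[0]["embedding"].
print(len(response.data[0]["embedding"]))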
@@ -474,3 +474,23 @@ async def ollama_aembeddings(
         "total_tokens": total_input_tokens,
     }
     return model_response
+
+
+def ollama_embeddings(
+    api_base: str,
+    model: str,
+    prompts: list,
+    optional_params=None,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    return asyncio.run(
+        ollama_aembeddings(
+            api_base,
+            model,
+            prompts,
+            optional_params,
+            logging_obj,
+            model_response,
+            encoding)
+    )
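
Note: the hunk above is a sync facade over the existing async implementation — ollama_embeddings() just drives ollama_aembeddings() to completion with asyncio.run(). Below is a minimal sketch of the same pattern with hypothetical names, not litellm code; the caveat is real: asyncio.run() raises RuntimeError when called from a thread that already has a running event loop, so the sync wrapper is only for synchronous callers.

import asyncio

async def _aembed(prompt: str) -> list[float]:
    # stand-in for the real async HTTP call to the embeddings endpoint
    await asyncio.sleep(0)
    return [0.0] * 8

def embed(prompt: str) -> list[float]:
    # sync facade: start a fresh event loop, run the coroutine, tear it down
    return asyncio.run(_aembed(prompt))

print(len(embed("hello world")))  # 8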
@@ -2950,16 +2950,16 @@ def embedding(
             model=model,  # type: ignore
             llm_provider="ollama",  # type: ignore
         )
-        if aembedding:
-            response = ollama.ollama_aembeddings(
+        ollama_embeddings_fn = ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
+        response = ollama_embeddings_fn(
             api_base=api_base,
             model=model,
             prompts=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "sagemaker":
         response = sagemaker.embedding(
             model=model,
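
Note: the dispatch above keeps a single call site for both paths — when aembedding is set, ollama_embeddings_fn is the coroutine function and the call returns a coroutine for the async caller to await; otherwise it returns the finished response directly. The same shape with hypothetical names, not litellm code:

import asyncio

async def embed_async(text: str) -> str:
    return f"async:{text}"

def embed_sync(text: str) -> str:
    return f"sync:{text}"

def embedding(text: str, aembedding: bool = False):
    # select the implementation once; the call site below stays identical
    embeddings_fn = embed_async if aembedding else embed_sync
    return embeddings_fn(text)  # a coroutine when aembedding=True

print(embedding("hi"))                                # sync:hi
print(asyncio.run(embedding("hi", aembedding=True)))  # async:hi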
@@ -1,3 +1,4 @@
+import asyncio
 import sys, os
 import traceback
 from dotenv import load_dotenv
@@ -10,10 +11,10 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
-
+from unittest import mock
 
 ## for ollama we can't test making the completion call
-from litellm.utils import get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider
 
 
 def test_get_ollama_params():
@@ -58,3 +59,50 @@ def test_ollama_json_mode():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_ollama_json_mode()
+
+mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_embeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_embeddings(mock_embeddings):
+    # assert that ollama_embeddings is called with the right parameters
+    try:
+        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        print(embeddings)
+        mock_embeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_embeddings()
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_aembeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_aembeddings(mock_aembeddings):
+    # assert that ollama_aembeddings is called with the right parameters
+    try:
+        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        print(embeddings)
+        mock_aembeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_aembeddings()
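
Note: how the assertions in these tests work, in miniature — mock.patch swaps the target callable for a MagicMock for the duration of the test, so no Ollama server is needed, and mock.ANY matches any value for arguments the test does not want to pin down (here the logging object, encoding, and the fresh EmbeddingResponse). A self-contained sketch with hypothetical names:

from unittest import mock

def notify(sender):
    sender("hello", retries=3)

def test_notify():
    sender = mock.MagicMock()
    notify(sender)
    # exact match on the message, wildcard on the retry count
    sender.assert_called_once_with("hello", retries=mock.ANY)

test_notify()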
@@ -24,6 +24,14 @@
 
 # asyncio.run(test_ollama_aembeddings())
 
+# def test_ollama_embeddings():
+#     litellm.set_verbose = True
+#     input = "The food was delicious and the waiter..."
+#     response = litellm.embedding(model="ollama/mistral", input=input)
+#     print(response)
+
+# test_ollama_embeddings()
+
 # def test_ollama_streaming():
 #     try:
 #         litellm.set_verbose = False