Merge pull request #3470 from mbektas/fix-ollama-embeddings

support sync ollama embeddings
commit 2725a55e7a
Ishaan Jaff, 2024-05-07 19:21:37 -07:00, committed by GitHub
4 changed files with 88 additions and 12 deletions

File 1 of 4

@@ -474,3 +474,23 @@ async def ollama_aembeddings(
         "total_tokens": total_input_tokens,
     }
     return model_response
+
+def ollama_embeddings(
+    api_base: str,
+    model: str,
+    prompts: list,
+    optional_params=None,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    return asyncio.run(
+        ollama_aembeddings(
+            api_base,
+            model,
+            prompts,
+            optional_params,
+            logging_obj,
+            model_response,
+            encoding)
+    )
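The new ollama_embeddings wrapper simply drives the existing async implementation to completion with asyncio.run, giving callers a blocking entry point. A minimal sketch of how the sync path could be exercised, assuming a local Ollama server on the default port with the nomic-embed-text model pulled:

import litellm

# Plain synchronous call; no event loop or await required.
# "ollama/..." models are routed to the new ollama_embeddings wrapper.
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
)
# Assuming the OpenAI-style response shape litellm returns:
print(response.data[0]["embedding"][:5])  # first few vector components

Note that asyncio.run creates and tears down a fresh event loop per call and raises RuntimeError if invoked while a loop is already running, so callers in async contexts should keep using litellm.aembedding.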

File 2 of 4

@@ -2950,8 +2950,8 @@ def embedding(
                 model=model,  # type: ignore
                 llm_provider="ollama",  # type: ignore
             )
-        if aembedding:
-            response = ollama.ollama_aembeddings(
+        ollama_embeddings_fn = ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
+        response = ollama_embeddings_fn(
             api_base=api_base,
             model=model,
             prompts=input,
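The refactor collapses the if/else branch into a single dispatch: the sync or async callable is chosen once, then invoked with one shared argument list. A standalone sketch of the same idiom (the names here are hypothetical, not from the litellm codebase):

import asyncio

async def fetch_async(url: str) -> str:
    # Stand-in for a real async implementation.
    await asyncio.sleep(0)
    return f"payload from {url}"

def fetch_sync(url: str) -> str:
    # Sync facade: run the async implementation to completion.
    return asyncio.run(fetch_async(url))

def fetch(url: str, use_async: bool = False):
    # Select the callable once, then call it with shared arguments.
    # When use_async is True this returns a coroutine for the caller to await.
    fetch_fn = fetch_async if use_async else fetch_sync
    return fetch_fn(url)

print(fetch("http://example.com"))  # blocking call returns the string directly

Keeping the argument list in one place means a parameter added later cannot drift between the sync and async branches.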

File 3 of 4

@@ -1,3 +1,4 @@
+import asyncio
 import sys, os
 import traceback
 from dotenv import load_dotenv
@@ -10,10 +11,10 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
+from unittest import mock

 ## for ollama we can't test making the completion call
-from litellm.utils import get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider

 def test_get_ollama_params():
@@ -58,3 +59,50 @@ def test_ollama_json_mode():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_ollama_json_mode()
+
+mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_embeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_embeddings(mock_embeddings):
+    # assert that ollama_embeddings is called with the right parameters
+    try:
+        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        print(embeddings)
+        mock_embeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# test_ollama_embeddings()
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_aembeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_aembeddings(mock_aembeddings):
+    # assert that ollama_aembeddings is called with the right parameters
+    try:
+        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        print(embeddings)
+        mock_aembeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_aembeddings()
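Both tests patch the provider functions at their defining module (litellm.llms.ollama.*) and use mock.ANY for arguments whose concrete values are incidental, so the assertions pin down only the routing and the parameters that matter. A standalone sketch of the same patching pattern, with hypothetical names rather than litellm's:

from unittest import mock

class Client:
    @staticmethod
    def send(payload: dict) -> dict:
        raise RuntimeError("real network call; patch me in tests")

def publish(event: str) -> dict:
    # Attribute lookup happens at call time, so a patch on Client.send
    # is seen here even though publish was defined before the patch.
    return Client.send({"event": event, "attempt": 1})

@mock.patch.object(Client, "send", return_value={"ok": True})
def test_publish(mock_send):
    assert publish("ping") == {"ok": True}
    # mock.ANY compares equal to anything, so it matches fields we ignore.
    mock_send.assert_called_once_with({"event": "ping", "attempt": mock.ANY})

test_publish()  # runs standalone; pytest would also discover it by name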

File 4 of 4

@@ -24,6 +24,14 @@
 # asyncio.run(test_ollama_aembeddings())

+
+# def test_ollama_embeddings():
+#     litellm.set_verbose = True
+#     input = "The food was delicious and the waiter..."
+#     response = litellm.embedding(model="ollama/mistral", input=input)
+#     print(response)
+# test_ollama_embeddings()
+

 # def test_ollama_streaming():
 #     try:
 #         litellm.set_verbose = False