Merge pull request #3470 from mbektas/fix-ollama-embeddings

support sync ollama embeddings
Ishaan Jaff 2024-05-07 19:21:37 -07:00 committed by GitHub
commit 2725a55e7a
4 changed files with 88 additions and 12 deletions
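
In practical terms, this PR adds a blocking embeddings path for the Ollama provider. A minimal sketch of the call it enables (assuming a local Ollama server on the default http://localhost:11434 with the nomic-embed-text model pulled; both names come from the tests below):

import litellm

# Synchronous embeddings call; before this change the ollama provider only
# wired up the async ollama_aembeddings path.
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
)
print(response)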


@@ -474,3 +474,23 @@ async def ollama_aembeddings(
         "total_tokens": total_input_tokens,
     }
     return model_response
+
+def ollama_embeddings(
+    api_base: str,
+    model: str,
+    prompts: list,
+    optional_params=None,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    return asyncio.run(
+        ollama_aembeddings(
+            api_base,
+            model,
+            prompts,
+            optional_params,
+            logging_obj,
+            model_response,
+            encoding)
+    )
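
The new ollama_embeddings above is a thin synchronous wrapper that drives the existing async implementation to completion with asyncio.run. A minimal sketch of that sync-over-async pattern (hypothetical names, not litellm code); note the standard caveat that asyncio.run raises RuntimeError if an event loop is already running in the calling thread:

import asyncio

async def _embed_async(prompt: str) -> list[float]:
    # Stand-in for the async HTTP round trip to an embeddings endpoint.
    await asyncio.sleep(0)
    return [0.0, 0.1, 0.2]

def embed(prompt: str) -> list[float]:
    # Synchronous facade: run the coroutine on a fresh event loop.
    # Safe from plain sync call sites, not from inside a running loop.
    return asyncio.run(_embed_async(prompt))

print(embed("hello world"))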


@@ -2950,16 +2950,16 @@ def embedding(
                 model=model,  # type: ignore
                 llm_provider="ollama",  # type: ignore
             )
-        if aembedding:
-            response = ollama.ollama_aembeddings(
+        ollama_embeddings_fn = ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
+        response = ollama_embeddings_fn(
             api_base=api_base,
             model=model,
             prompts=input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "sagemaker":
         response = sagemaker.embedding(
             model=model,
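
The dispatch above selects the callable first so both paths share a single keyword-argument call site; when aembedding is set, calling the async function returns a coroutine that is awaited further up the async path. A small self-contained sketch of the same pattern (hypothetical names):

import asyncio

def double_sync(x: int) -> int:
    return x * 2

async def double_async(x: int) -> int:
    return x * 2

def dispatch(x: int, use_async: bool):
    # One call site; the async branch hands back an awaitable coroutine.
    fn = double_async if use_async else double_sync
    return fn(x)

print(dispatch(3, use_async=False))              # 6
print(asyncio.run(dispatch(3, use_async=True)))  # 6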


@@ -1,3 +1,4 @@
+import asyncio
 import sys, os
 import traceback
 from dotenv import load_dotenv
@@ -10,10 +11,10 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
+from unittest import mock

 ## for ollama we can't test making the completion call
-from litellm.utils import get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider


 def test_get_ollama_params():
@@ -58,3 +59,50 @@ def test_ollama_json_mode():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


 # test_ollama_json_mode()
+
+
+mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_embeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_embeddings(mock_embeddings):
+    # assert that ollama_embeddings is called with the right parameters
+    try:
+        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        print(embeddings)
+        mock_embeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_ollama_embeddings()
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_aembeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_aembeddings(mock_aembeddings):
+    # assert that ollama_aembeddings is called with the right parameters
+    try:
+        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        print(embeddings)
+        mock_aembeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_ollama_aembeddings()
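
Both tests patch the provider-level function and pin down only the arguments that matter, using mock.ANY for the rest. A tiny standalone illustration of that assertion style (not litellm code):

from unittest import mock

# mock.ANY compares equal to any value, so the assertion can fix the
# interesting arguments and ignore incidental ones such as logging objects.
m = mock.MagicMock(return_value="ok")
m(api_base="http://localhost:11434", model="nomic-embed-text", logging_obj=object())
m.assert_called_once_with(
    api_base="http://localhost:11434",
    model="nomic-embed-text",
    logging_obj=mock.ANY,
)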


@@ -24,6 +24,14 @@
 # asyncio.run(test_ollama_aembeddings())
+
+# def test_ollama_embeddings():
+#     litellm.set_verbose = True
+#     input = "The food was delicious and the waiter..."
+#     response = litellm.embedding(model="ollama/mistral", input=input)
+#     print(response)
+
+# test_ollama_embeddings()

 # def test_ollama_streaming():
 #     try:
 #         litellm.set_verbose = False