Merge pull request #3470 from mbektas/fix-ollama-embeddings

support sync ollama embeddings

Commit 2725a55e7a (4 changed files with 88 additions and 12 deletions)
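For context, this is the behavior the PR adds: in the old code (visible in the second hunk below), Ollama embeddings were only wired through the async branch, so the synchronous entry point had no working Ollama path. A minimal usage sketch of the sync call this change supports; the model name is taken from the tests below, and a locally running Ollama server is an assumption for illustration:

import litellm

# Synchronous embeddings against a local Ollama server; assumes Ollama is
# serving at its default address and the model has been pulled.
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
)
print(response)  # an EmbeddingResponse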
@@ -474,3 +474,23 @@ async def ollama_aembeddings(
         "total_tokens": total_input_tokens,
     }
     return model_response
+
+def ollama_embeddings(
+    api_base: str,
+    model: str,
+    prompts: list,
+    optional_params=None,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    return asyncio.run(
+        ollama_aembeddings(
+            api_base,
+            model,
+            prompts,
+            optional_params,
+            logging_obj,
+            model_response,
+            encoding)
+    )
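One property of the new wrapper worth keeping in mind: asyncio.run() creates and runs a fresh event loop, and it raises RuntimeError if the calling thread already has a loop running. The sync wrapper is therefore safe from plain synchronous code, but async callers should keep using ollama_aembeddings directly. A self-contained sketch of the pattern and its failure mode:

import asyncio

async def inner() -> str:
    return "embedding result"

def sync_wrapper() -> str:
    # Same shape as ollama_embeddings above: drive the coroutine to
    # completion on a brand-new event loop.
    return asyncio.run(inner())

print(sync_wrapper())  # works from synchronous code

async def async_caller():
    # asyncio.run() inside a running loop raises:
    # RuntimeError: asyncio.run() cannot be called from a running event loop
    return sync_wrapper()

# asyncio.run(async_caller())  # uncomment to observe the RuntimeError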
@@ -2950,16 +2950,16 @@ def embedding(
             model=model,  # type: ignore
             llm_provider="ollama",  # type: ignore
         )
-        if aembedding:
-            response = ollama.ollama_aembeddings(
-                api_base=api_base,
-                model=model,
-                prompts=input,
-                encoding=encoding,
-                logging_obj=logging,
-                optional_params=optional_params,
-                model_response=EmbeddingResponse(),
-            )
+        ollama_embeddings_fn = ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
+        response = ollama_embeddings_fn(
+            api_base=api_base,
+            model=model,
+            prompts=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+        )
     elif custom_llm_provider == "sagemaker":
         response = sagemaker.embedding(
             model=model,
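The hunk above collapses the async-only branch into a single dispatch: choose between the coroutine function and the sync wrapper once, then call whichever was picked with identical keyword arguments. When aembedding is true, the call returns a coroutine that the async caller awaits. A standalone illustration of the pattern (the names here are illustrative, not litellm's API):

import asyncio

async def embed_async(text: str) -> list[float]:
    return [0.1, 0.2, 0.3]  # stand-in for a real embedding call

def embed_sync(text: str) -> list[float]:
    return asyncio.run(embed_async(text))

def embed(text: str, use_async: bool = False):
    # Pick the implementation once; both accept the same arguments.
    embed_fn = embed_async if use_async else embed_sync
    return embed_fn(text)

print(embed("hello"))                     # list, computed synchronously
print(asyncio.run(embed("hello", True)))  # coroutine awaited by the caller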
@@ -1,3 +1,4 @@
+import asyncio
 import sys, os
 import traceback
 from dotenv import load_dotenv
@@ -10,10 +11,10 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
+from unittest import mock

 ## for ollama we can't test making the completion call
-from litellm.utils import get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider


 def test_get_ollama_params():
@@ -58,3 +59,50 @@ def test_ollama_json_mode():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 # test_ollama_json_mode()
+
+
+mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_embeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_embeddings(mock_embeddings):
+    # assert that ollama_embeddings is called with the right parameters
+    try:
+        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        print(embeddings)
+        mock_embeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_embeddings()
+
+
+@mock.patch(
+    "litellm.llms.ollama.ollama_aembeddings",
+    return_value=mock_ollama_embedding_response,
+)
+def test_ollama_aembeddings(mock_aembeddings):
+    # assert that ollama_aembeddings is called with the right parameters
+    try:
+        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        print(embeddings)
+        mock_aembeddings.assert_called_once_with(
+            api_base="http://localhost:11434",
+            model="nomic-embed-text",
+            prompts=["hello world"],
+            optional_params=mock.ANY,
+            logging_obj=mock.ANY,
+            model_response=mock.ANY,
+            encoding=mock.ANY,
+        )
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_ollama_aembeddings()
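Both new tests patch the provider function, so they run without a live Ollama server: mock.patch substitutes a canned EmbeddingResponse, and the assertions pin down only the arguments that matter, letting mock.ANY match the rest. A minimal self-contained illustration of how mock.ANY behaves inside assert_called_once_with:

from unittest import mock

m = mock.Mock()
m(api_base="http://localhost:11434", model="nomic-embed-text", prompts=["hello world"])

# mock.ANY compares equal to any value, so only the arguments spelled out
# concretely are actually checked.
m.assert_called_once_with(
    api_base="http://localhost:11434",
    model=mock.ANY,
    prompts=mock.ANY,
)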
@@ -24,6 +24,14 @@

 # asyncio.run(test_ollama_aembeddings())

+# def test_ollama_embeddings():
+#     litellm.set_verbose = True
+#     input = "The food was delicious and the waiter..."
+#     response = litellm.embedding(model="ollama/mistral", input=input)
+#     print(response)
+
+# test_ollama_embeddings()
+
 # def test_ollama_streaming():
 #     try:
 #         litellm.set_verbose = False