Merge pull request #2675 from onukura/ollama-embedding

Fix Ollama embedding
This commit is contained in:
Krish Dholakia 2024-03-26 16:08:28 -07:00 committed by GitHub
commit 7eb2c7942c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 0 deletions

View file

@@ -396,6 +396,7 @@ async def ollama_aembeddings(
response_json = await response.json()
embeddings = response_json["embedding"]
embeddings = [embeddings]  # Ollama currently does not support batch embedding
## RESPONSE OBJECT
output_data = []
for idx, embedding in enumerate(embeddings):

View file

@@ -2790,6 +2790,12 @@ def embedding(
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "ollama":
api_base = (
litellm.api_base
or api_base
or get_secret("OLLAMA_API_BASE")
or "http://localhost:11434"
)
ollama_input = None
if isinstance(input, list) and len(input) > 1:
raise litellm.BadRequestError(
@@ -2810,6 +2816,7 @@ def embedding(
if aembedding == True:
response = ollama.ollama_aembeddings(
api_base=api_base,
model=model,
prompt=ollama_input,
encoding=encoding,