diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 3611ccd8b..05a3134fc 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -396,6 +396,7 @@ async def ollama_aembeddings(
         response_json = await response.json()
         embeddings = response_json["embedding"]
+        embeddings = [embeddings]  # Ollama currently does not support batch embedding
         ## RESPONSE OBJECT
         output_data = []
         for idx, embedding in enumerate(embeddings):
diff --git a/litellm/main.py b/litellm/main.py
index 3e875815e..b2c7da612 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2790,6 +2790,12 @@ def embedding(
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "ollama":
+        api_base = (
+            litellm.api_base
+            or api_base
+            or get_secret("OLLAMA_API_BASE")
+            or "http://localhost:11434"
+        )
         ollama_input = None
         if isinstance(input, list) and len(input) > 1:
             raise litellm.BadRequestError(
@@ -2810,6 +2816,7 @@ def embedding(
         if aembedding == True:
             response = ollama.ollama_aembeddings(
+                api_base=api_base,
                 model=model,
                 prompt=ollama_input,
                 encoding=encoding,
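
For reviewers, a quick usage sketch of what this change enables (the model name and host below are illustrative assumptions, not taken from the diff): the resolved `api_base` now falls back through `litellm.api_base`, the `api_base` argument, the `OLLAMA_API_BASE` secret, and finally `http://localhost:11434`, and is actually forwarded to `ollama_aembeddings`, so a non-default Ollama host is honored for async embedding calls.

```python
# Hypothetical smoke test for this change; assumes an Ollama server with an
# embedding-capable model is reachable at the host below.
import asyncio
import litellm

async def main():
    response = await litellm.aembedding(
        model="ollama/nomic-embed-text",         # illustrative model name
        input=["Why is the sky blue?"],          # single item: Ollama has no batch embedding
        api_base="http://my-ollama-host:11434",  # previously ignored, now respected
    )
    # With the [embeddings] wrap in ollama_aembeddings, data holds one entry per input.
    print(len(response.data), len(response.data[0]["embedding"]))

asyncio.run(main())
```

The `[embeddings]` wrap matters because Ollama's `/api/embeddings` returns a single flat vector; without it, the `enumerate(embeddings)` loop would iterate over individual floats instead of one embedding per input.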