From 1bd60287ba6ea849318aa3d6ebd34e2e9eb9e200 Mon Sep 17 00:00:00 2001
From: onukura <26293997+onukura@users.noreply.github.com>
Date: Wed, 27 Mar 2024 21:39:19 +0000
Subject: [PATCH] Add a feature to ollama aembedding to accept batch input

---
 litellm/llms/ollama.py | 94 +++++++++++++++++++++---------------------
 litellm/main.py        | 20 +++------
 2 files changed, 52 insertions(+), 62 deletions(-)

diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 05a3134fc3..c68445f430 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -344,9 +344,9 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
 
 
 async def ollama_aembeddings(
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?",
+    api_base: str,
+    model: str,
+    prompts: list[str],
     optional_params=None,
     logging_obj=None,
     model_response=None,
@@ -365,52 +365,52 @@ async def ollama_aembeddings(
         ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
-    data = {
-        "model": model,
-        "prompt": prompt,
-    }
-    ## LOGGING
-    logging_obj.pre_call(
-        input=None,
-        api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
-    )
+    total_input_tokens = 0
+    output_data = []
     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
-        response = await session.post(url, json=data)
-
-        if response.status != 200:
-            text = await response.text()
-            raise OllamaError(status_code=response.status, message=text)
-
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key="",
-            original_response=response.text,
-            additional_args={
-                "headers": None,
-                "api_base": api_base,
-            },
-        )
-
-        response_json = await response.json()
-        embeddings = response_json["embedding"]
-        embeddings = [embeddings]  # Ollama currently does not support batch embedding
-        ## RESPONSE OBJECT
-        output_data = []
-        for idx, embedding in enumerate(embeddings):
-            output_data.append(
-                {"object": "embedding", "index": idx, "embedding": embedding}
+        for idx, prompt in enumerate(prompts):
+            data = {
+                "model": model,
+                "prompt": prompt,
+            }
+            ## LOGGING
+            logging_obj.pre_call(
+                input=None,
+                api_key=None,
+                additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
             )
-        model_response["object"] = "list"
-        model_response["data"] = output_data
-        model_response["model"] = model
+
+            response = await session.post(url, json=data)
+            if response.status != 200:
+                text = await response.text()
+                raise OllamaError(status_code=response.status, message=text)
 
-        input_tokens = len(encoding.encode(prompt))
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key="",
+                original_response=response.text,
+                additional_args={
+                    "headers": None,
+                    "api_base": api_base,
+                },
+            )
 
-        model_response["usage"] = {
-            "prompt_tokens": input_tokens,
-            "total_tokens": input_tokens,
-        }
-        return model_response
+            response_json = await response.json()
+            embeddings: list[float] = response_json["embedding"]
+            output_data.append(
+                {"object": "embedding", "index": idx, "embedding": embeddings}
+            )
+
+            input_tokens = len(encoding.encode(prompt))
+            total_input_tokens += input_tokens
+
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+    model_response["usage"] = {
+        "prompt_tokens": total_input_tokens,
+        "total_tokens": total_input_tokens,
+    }
+    return model_response
 
diff --git a/litellm/main.py b/litellm/main.py
index 0a44e00972..6592e58e20 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2795,29 +2795,19 @@ def embedding(
                 or get_secret("OLLAMA_API_BASE")
                 or "http://localhost:11434"
             )
-            ollama_input = None
-            if isinstance(input, list) and len(input) > 1:
-                raise litellm.BadRequestError(
-                    message=f"Ollama Embeddings don't support batch embeddings",
-                    model=model,  # type: ignore
-                    llm_provider="ollama",  # type: ignore
-                )
-            if isinstance(input, list) and len(input) == 1:
-                ollama_input = "".join(input[0])
-            elif isinstance(input, str):
-                ollama_input = input
-            else:
+            if isinstance(input, str):
+                input = [input]
+            if not all(isinstance(item, str) for item in input):
                 raise litellm.BadRequestError(
                     message=f"Invalid input for ollama embeddings. input={input}",
                     model=model,  # type: ignore
                     llm_provider="ollama",  # type: ignore
                 )
-
-            if aembedding == True:
+            if aembedding:
                 response = ollama.ollama_aembeddings(
                     api_base=api_base,
                     model=model,
-                    prompt=ollama_input,
+                    prompts=input,
                     encoding=encoding,
                     logging_obj=logging,
                     optional_params=optional_params,
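
Usage sketch (not part of the patch, illustrative only): with this change applied, the async embedding path should accept a list of strings for an Ollama model and return one embedding per prompt, with prompt tokens summed across the batch in usage. The model name "ollama/nomic-embed-text", the api_base value, and the dict-style access to response.data below are assumptions for illustration, not guaranteed by this diff.

    import asyncio
    import litellm

    async def main():
        # Hypothetical call against a local Ollama server; model name and
        # api_base are placeholders.
        response = await litellm.aembedding(
            model="ollama/nomic-embed-text",
            input=["Why is the sky blue?", "Why is grass green?"],
            api_base="http://localhost:11434",
        )
        # One embedding per prompt, indexed in request order.
        for item in response.data:
            print(item["index"], len(item["embedding"]))
        # Usage now reports the total tokens across all prompts.
        print(response.usage)

    asyncio.run(main())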