diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index ad8d408584..12b60c2c59 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -4,8 +4,9 @@ import time
 import traceback
 import types
 import uuid
+from copy import deepcopy
 from itertools import chain
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 import httpx  # type: ignore
@@ -510,7 +511,7 @@ async def ollama_aembeddings(
     model: str,
     prompts: list,
     model_response: litellm.EmbeddingResponse,
-    optional_params=None,
+    optional_params: dict,
     logging_obj=None,
     encoding=None,
 ):
@@ -527,15 +528,25 @@ async def ollama_aembeddings(
         ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
 
+    input_data: Dict[str, Any] = {"model": model}
+    special_optional_params = ["truncate", "options", "keep_alive"]
+
+    for k, v in optional_params.items():
+        if k in special_optional_params:
+            input_data[k] = v
+        else:
+            # Ensure "options" is a dictionary before updating it
+            input_data.setdefault("options", {})
+            if isinstance(input_data["options"], dict):
+                input_data["options"].update({k: v})
     total_input_tokens = 0
     output_data = []
+
     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
        for idx, prompt in enumerate(prompts):
-            data = {
-                "model": model,
-                "prompt": prompt,
-            }
+            data = deepcopy(input_data)
+            data["prompt"] = prompt
             ## LOGGING
             logging_obj.pre_call(
                 input=None,
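
For context, the routing logic this patch adds can be sketched as a standalone function. The helper name `build_embedding_payload` and the example model and parameters below are illustrative only; in litellm the equivalent code runs inline inside `ollama_aembeddings`:

```python
# A minimal standalone sketch of the parameter routing introduced above,
# not the litellm implementation verbatim.
from copy import deepcopy
from typing import Any, Dict

# Keys the Ollama embeddings endpoint accepts at the top level of the
# request body; everything else is treated as a model option and nested
# under "options".
SPECIAL_OPTIONAL_PARAMS = ["truncate", "options", "keep_alive"]


def build_embedding_payload(
    model: str, optional_params: Dict[str, Any]
) -> Dict[str, Any]:
    input_data: Dict[str, Any] = {"model": model}
    for k, v in optional_params.items():
        if k in SPECIAL_OPTIONAL_PARAMS:
            input_data[k] = v
        else:
            # Ensure "options" is a dictionary before updating it
            input_data.setdefault("options", {})
            if isinstance(input_data["options"], dict):
                input_data["options"].update({k: v})
    return input_data


# Example (hypothetical model/params): the base payload is built once,
# then deep-copied per prompt so mutating the nested "options" dict for
# one request can't leak into the others.
base = build_embedding_payload(
    "nomic-embed-text",
    {"keep_alive": "5m", "num_ctx": 2048},
)
data = deepcopy(base)
data["prompt"] = "hello world"
# data == {"model": "nomic-embed-text", "keep_alive": "5m",
#          "options": {"num_ctx": 2048}, "prompt": "hello world"}
```

The split mirrors Ollama's request shape: `truncate`, `options`, and `keep_alive` are top-level request fields, while model parameters such as `num_ctx` belong inside `options`; the per-prompt `deepcopy` is what makes sharing the nested `options` dict across the prompt loop safe.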