fix(ollama.py): fix ollama embeddings - pass optional params

Fixes https://github.com/BerriAI/litellm/issues/5267
Krrish Dholakia 2024-08-19 08:45:26 -07:00
parent cc42f96d6a
commit 04d69464e2

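Context for the diff below: ollama_aembeddings previously built the request body from the model and prompt alone, so optional params supplied by the caller never reached the Ollama API. A minimal sketch of the kind of call this fix enables (the model name and parameter values are illustrative, and it assumes extra kwargs to litellm.embedding end up in optional_params, as the diff relies on):

import litellm

# Hypothetical example values. After this fix, "keep_alive" and "truncate" are
# forwarded at the top level of the Ollama request body, and any other key is
# nested under "options".
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
    keep_alive="5m",
    truncate=True,
)
print(len(response.data[0]["embedding"]))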

@@ -4,8 +4,9 @@ import time
 import traceback
 import types
 import uuid
+from copy import deepcopy
 from itertools import chain
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 import httpx  # type: ignore
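The new imports line up with the body of the change: deepcopy is used below to clone a shared request template for each prompt, and Dict and Any type the new input_data mapping.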
@@ -510,7 +511,7 @@ async def ollama_aembeddings(
     model: str,
     prompts: list,
     model_response: litellm.EmbeddingResponse,
-    optional_params=None,
+    optional_params: dict,
     logging_obj=None,
     encoding=None,
 ):
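Note the signature change: optional_params goes from an optional keyword defaulting to None to a required dict, which lets the new loop below iterate over it without a None check (assuming, as the diff implies, that litellm's internal call sites always pass it).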
@@ -527,15 +528,25 @@ async def ollama_aembeddings(
         ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
 
+    input_data: Dict[str, Any] = {"model": model}
+    special_optional_params = ["truncate", "options", "keep_alive"]
+
+    for k, v in optional_params.items():
+        if k in special_optional_params:
+            input_data[k] = v
+        else:
+            # Ensure "options" is a dictionary before updating it
+            input_data.setdefault("options", {})
+            if isinstance(input_data["options"], dict):
+                input_data["options"].update({k: v})
+
     total_input_tokens = 0
     output_data = []
     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
         for idx, prompt in enumerate(prompts):
-            data = {
-                "model": model,
-                "prompt": prompt,
-            }
+            data = deepcopy(input_data)
+            data["prompt"] = prompt
             ## LOGGING
             logging_obj.pre_call(
                 input=None,
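To see the net effect of the hunk above in isolation, here is a self-contained sketch of the new payload construction. build_payload is a hypothetical helper written for illustration, not a function in litellm:

from copy import deepcopy
from typing import Any, Dict


def build_payload(model: str, prompt: str, optional_params: Dict[str, Any]) -> Dict[str, Any]:
    input_data: Dict[str, Any] = {"model": model}
    # Keys the Ollama embeddings endpoint accepts at the top level of the body.
    special_optional_params = ["truncate", "options", "keep_alive"]
    for k, v in optional_params.items():
        if k in special_optional_params:
            input_data[k] = v
        else:
            # Anything else is treated as a model option and nested under "options".
            input_data.setdefault("options", {})
            if isinstance(input_data["options"], dict):
                input_data["options"].update({k: v})
    # Per-prompt copy, mirroring the loop in the diff above.
    data = deepcopy(input_data)
    data["prompt"] = prompt
    return data


print(build_payload("nomic-embed-text", "hi", {"keep_alive": "5m", "num_ctx": 2048}))
# -> {'model': 'nomic-embed-text', 'keep_alive': '5m', 'options': {'num_ctx': 2048}, 'prompt': 'hi'}

The deepcopy matters because input_data can hold a nested "options" dict; a shallow copy would let every prompt's request mutate the same inner dict.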