fix(ollama.py): fix ollama embeddings - pass optional params

Fixes https://github.com/BerriAI/litellm/issues/5267
Author: Krrish Dholakia
Date: 2024-08-19 08:45:26 -07:00
parent cc42f96d6a
commit 04d69464e2
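
With this change, embedding-specific kwargs actually reach Ollama's /api/embeddings request instead of being dropped. A minimal usage sketch (the model name and parameter values are illustrative, and it assumes litellm routes these kwargs into optional_params for the Ollama provider):

    import litellm

    # "keep_alive" (and "truncate"/"options") become top-level request fields;
    # any other kwarg, e.g. a model option like "num_ctx", is folded into "options".
    response = litellm.embedding(
        model="ollama/nomic-embed-text",  # illustrative model name
        input=["hello world"],
        keep_alive="5m",
        num_ctx=1024,
    )
    print(response.data[0]["embedding"][:5])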

ollama.py

@@ -4,8 +4,9 @@ import time
 import traceback
 import types
 import uuid
+from copy import deepcopy
 from itertools import chain
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 import httpx  # type: ignore
@@ -510,7 +511,7 @@ async def ollama_aembeddings(
     model: str,
     prompts: list,
     model_response: litellm.EmbeddingResponse,
-    optional_params=None,
+    optional_params: dict,
     logging_obj=None,
     encoding=None,
 ):
@@ -527,15 +528,25 @@ async def ollama_aembeddings(
     ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
         optional_params[k] = v
 
+    input_data: Dict[str, Any] = {"model": model}
+    special_optional_params = ["truncate", "options", "keep_alive"]
+
+    for k, v in optional_params.items():
+        if k in special_optional_params:
+            input_data[k] = v
+        else:
+            # Ensure "options" is a dictionary before updating it
+            input_data.setdefault("options", {})
+            if isinstance(input_data["options"], dict):
+                input_data["options"].update({k: v})
+
     total_input_tokens = 0
     output_data = []
     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
         for idx, prompt in enumerate(prompts):
-            data = {
-                "model": model,
-                "prompt": prompt,
-            }
+            data = deepcopy(input_data)
+            data["prompt"] = prompt
             ## LOGGING
             logging_obj.pre_call(
                 input=None,
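
For reference, the new payload-building logic as a standalone, runnable sketch; the helper name and the example values are hypothetical, but the body mirrors the diff:

    from copy import deepcopy
    from typing import Any, Dict, List

    def build_embedding_payloads(
        model: str, prompts: List[str], optional_params: dict
    ) -> List[Dict[str, Any]]:
        # Keys Ollama accepts at the top level of /api/embeddings;
        # everything else belongs inside the "options" dict.
        input_data: Dict[str, Any] = {"model": model}
        special_optional_params = ["truncate", "options", "keep_alive"]
        for k, v in optional_params.items():
            if k in special_optional_params:
                input_data[k] = v
            else:
                input_data.setdefault("options", {})
                if isinstance(input_data["options"], dict):
                    input_data["options"].update({k: v})
        # deepcopy per prompt so one request's "options" can't leak into the next
        return [{**deepcopy(input_data), "prompt": p} for p in prompts]

    # build_embedding_payloads("nomic-embed-text", ["hi"], {"keep_alive": "5m", "num_ctx": 1024})
    # -> [{"model": "nomic-embed-text", "keep_alive": "5m",
    #      "options": {"num_ctx": 1024}, "prompt": "hi"}]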