Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
fix(ollama.py): fix ollama embeddings - pass optional params
Fixes https://github.com/BerriAI/litellm/issues/5267
parent cc42f96d6a
commit 04d69464e2
1 changed file with 17 additions and 6 deletions
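The change makes Ollama-specific optional parameters survive into the /api/embeddings request body instead of being dropped. A minimal sketch of the kind of call the fix is aimed at, assuming a local Ollama model named nomic-embed-text; the parameter values are illustrative, and exactly which extra kwargs litellm forwards depends on its provider parameter mapping:

import litellm

# Illustrative call: with this fix, an Ollama-specific kwarg such as
# keep_alive can reach the embeddings request top-level, and a model
# option such as num_ctx can be nested under "options", rather than
# being discarded when the request body is built.
response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world"],
    keep_alive="5m",
    num_ctx=2048,
)
# OpenAI-style response shape
print(response.data[0]["embedding"][:5])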
@@ -4,8 +4,9 @@ import time
 import traceback
 import types
 import uuid
+from copy import deepcopy
 from itertools import chain
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 import httpx  # type: ignore
@@ -510,7 +511,7 @@ async def ollama_aembeddings(
     model: str,
     prompts: list,
     model_response: litellm.EmbeddingResponse,
-    optional_params=None,
+    optional_params: dict,
     logging_obj=None,
     encoding=None,
 ):
@@ -527,15 +528,25 @@ async def ollama_aembeddings(
         ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
+
+    input_data: Dict[str, Any] = {"model": model}
+    special_optional_params = ["truncate", "options", "keep_alive"]
+
+    for k, v in optional_params.items():
+        if k in special_optional_params:
+            input_data[k] = v
+        else:
+            # Ensure "options" is a dictionary before updating it
+            input_data.setdefault("options", {})
+            if isinstance(input_data["options"], dict):
+                input_data["options"].update({k: v})
     total_input_tokens = 0
     output_data = []
 
     timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
     async with aiohttp.ClientSession(timeout=timeout) as session:
         for idx, prompt in enumerate(prompts):
-            data = {
-                "model": model,
-                "prompt": prompt,
-            }
+            data = deepcopy(input_data)
+            data["prompt"] = prompt
             ## LOGGING
             logging_obj.pre_call(
                 input=None,
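The core of the fix is how the payload is assembled: params named in special_optional_params ("truncate", "options", "keep_alive") stay at the top level of the body, matching Ollama's /api/embeddings parameters, while any other optional param is folded into the nested "options" dict; each prompt then gets a deep copy of the base payload so one iteration's mutation of "options" cannot leak into the next. A standalone sketch of that merging behavior, runnable on its own (the function name build_ollama_embed_payload is illustrative, not part of litellm):

from copy import deepcopy
from typing import Any, Dict


def build_ollama_embed_payload(model: str, optional_params: dict) -> Dict[str, Any]:
    # Mirrors the merging in ollama_aembeddings: known top-level params
    # pass through unchanged; everything else nests under "options".
    payload: Dict[str, Any] = {"model": model}
    special_optional_params = ["truncate", "options", "keep_alive"]
    for k, v in optional_params.items():
        if k in special_optional_params:
            payload[k] = v
        else:
            # Ensure "options" is a dictionary before updating it
            payload.setdefault("options", {})
            if isinstance(payload["options"], dict):
                payload["options"].update({k: v})
    return payload


base = build_ollama_embed_payload(
    "nomic-embed-text", {"keep_alive": "5m", "num_ctx": 2048}
)
# base == {"model": "nomic-embed-text", "keep_alive": "5m",
#          "options": {"num_ctx": 2048}}

for prompt in ["first", "second"]:
    data = deepcopy(base)  # per-prompt copy, so "options" is never shared
    data["prompt"] = prompt

Using deepcopy (rather than a shallow dict copy) matters here because the nested "options" dict would otherwise be shared across every prompt's request body.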