fix(main.py): map list input to ollama prompt input format

Krrish Dholakia 2024-02-16 09:56:59 -08:00
parent dca9103b09
commit 944afcb5d1
2 changed files with 29 additions and 2 deletions
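
For context: Ollama's embeddings endpoint takes a single prompt string per request, so the new branch normalizes the OpenAI-style `input` argument (string or one-element list) down to one string before dispatch. Below is a minimal standalone sketch of that mapping; the function name is hypothetical and the error type is simplified to ValueError for illustration (the committed code raises litellm.BadRequestError).

from typing import List, Union

def map_input_to_ollama_prompt(input: Union[str, List[str]]) -> str:
    # Ollama's embeddings API accepts one prompt per request, so a
    # multi-element batch cannot be forwarded as-is.
    if isinstance(input, list) and len(input) > 1:
        raise ValueError("Ollama Embeddings don't support batch embeddings")
    # A one-element list is collapsed to its single string; "".join
    # mirrors the committed code and is a no-op on a plain str element.
    if isinstance(input, list) and len(input) == 1:
        return "".join(input[0])
    # A bare string passes through unchanged.
    if isinstance(input, str):
        return input
    raise ValueError(f"Invalid input for ollama embeddings. input={input}")

With this mapping, map_input_to_ollama_prompt("hi") and map_input_to_ollama_prompt(["hi"]) both return "hi".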

@@ -2590,10 +2590,28 @@ def embedding(
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "ollama":
+        ollama_input = None
+        if isinstance(input, list) and len(input) > 1:
+            raise litellm.BadRequestError(
+                message=f"Ollama Embeddings don't support batch embeddings",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
+        if isinstance(input, list) and len(input) == 1:
+            ollama_input = "".join(input[0])
+        elif isinstance(input, str):
+            ollama_input = input
+        else:
+            raise litellm.BadRequestError(
+                message=f"Invalid input for ollama embeddings. input={input}",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
         if aembedding == True:
             response = ollama.ollama_aembeddings(
                 model=model,
-                prompt=input,
+                prompt=ollama_input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
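
A hedged sketch of the resulting caller-facing behavior; the model name is illustrative, and it assumes a local Ollama server with a matching embedding model pulled:

import litellm

# A single string and a one-element list now both reach Ollama as one prompt.
litellm.embedding(model="ollama/nomic-embed-text", input="hello world")
litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])

# A multi-element batch is rejected up front instead of failing downstream.
try:
    litellm.embedding(model="ollama/nomic-embed-text", input=["a", "b"])
except litellm.BadRequestError as err:
    print(err)  # Ollama Embeddings don't support batch embeddings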