fix(main.py): map list input to ollama prompt input format

Krrish Dholakia 2024-02-16 09:56:59 -08:00 committed by ishaan-jaff
parent 3fffe96f97
commit b3d48da640
2 changed files with 29 additions and 2 deletions


@@ -24,6 +24,7 @@ from openai import (
     PermissionDeniedError,
 )
 import httpx
+from typing import Optional


 class AuthenticationError(AuthenticationError):  # type: ignore
@@ -50,11 +51,19 @@ class NotFoundError(NotFoundError):  # type: ignore


 class BadRequestError(BadRequestError):  # type: ignore
-    def __init__(self, message, model, llm_provider, response: httpx.Response):
+    def __init__(
+        self, message, model, llm_provider, response: Optional[httpx.Response] = None
+    ):
         self.status_code = 400
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
+        response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs
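
For context on the exceptions change above: response is now optional and falls back to a mock httpx.Response, which is what lets the embedding code below raise litellm.BadRequestError without supplying a response object. A minimal usage sketch of the resulting behavior (the model name here is illustrative, not taken from the commit):

import litellm

try:
    # No `response` argument: the exception constructs a mock httpx.Response
    # (status 400, GET https://litellm.ai) before calling the OpenAI base class.
    raise litellm.BadRequestError(
        message="Ollama Embeddings don't support batch embeddings",
        model="nomic-embed-text",  # illustrative model name
        llm_provider="ollama",
    )
except litellm.BadRequestError as e:
    print(e.status_code)           # 400
    print(e.response.request.url)  # https://litellm.ai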


@@ -2590,10 +2590,28 @@ def embedding(
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "ollama":
+        ollama_input = None
+        if isinstance(input, list) and len(input) > 1:
+            raise litellm.BadRequestError(
+                message=f"Ollama Embeddings don't support batch embeddings",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
+        if isinstance(input, list) and len(input) == 1:
+            ollama_input = "".join(input[0])
+        elif isinstance(input, str):
+            ollama_input = input
+        else:
+            raise litellm.BadRequestError(
+                message=f"Invalid input for ollama embeddings. input={input}",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
         if aembedding == True:
             response = ollama.ollama_aembeddings(
                 model=model,
-                prompt=input,
+                prompt=ollama_input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
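
For reference, the list-to-prompt mapping the hunk above introduces behaves roughly as in this standalone sketch; the helper name and the asserts are illustrative, not part of the commit:

def map_ollama_embedding_input(input):
    # Mirrors the hunk above: a single-element list collapses to one prompt
    # string; batch lists and other types are rejected (the commit raises
    # litellm.BadRequestError, ValueError stands in here).
    if isinstance(input, list) and len(input) > 1:
        raise ValueError("Ollama Embeddings don't support batch embeddings")
    if isinstance(input, list) and len(input) == 1:
        return "".join(input[0])  # input[0] may be a string or a list of strings
    elif isinstance(input, str):
        return input
    raise ValueError(f"Invalid input for ollama embeddings. input={input}")

assert map_ollama_embedding_input("hello world") == "hello world"
assert map_ollama_embedding_input(["hello world"]) == "hello world"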