diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 09b375811..a7bf394f6 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -24,6 +24,7 @@ from openai import (
     PermissionDeniedError,
 )
 import httpx
+from typing import Optional
 
 
 class AuthenticationError(AuthenticationError):  # type: ignore
@@ -50,11 +51,19 @@ class NotFoundError(NotFoundError):  # type: ignore
 
 
 class BadRequestError(BadRequestError):  # type: ignore
-    def __init__(self, message, model, llm_provider, response: httpx.Response):
+    def __init__(
+        self, message, model, llm_provider, response: Optional[httpx.Response] = None
+    ):
         self.status_code = 400
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
+        response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs
diff --git a/litellm/main.py b/litellm/main.py
index 93ea3c644..2539039cd 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2590,10 +2590,28 @@ def embedding(
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "ollama":
+        ollama_input = None
+        if isinstance(input, list) and len(input) > 1:
+            raise litellm.BadRequestError(
+                message=f"Ollama Embeddings don't support batch embeddings",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
+        if isinstance(input, list) and len(input) == 1:
+            ollama_input = "".join(input[0])
+        elif isinstance(input, str):
+            ollama_input = input
+        else:
+            raise litellm.BadRequestError(
+                message=f"Invalid input for ollama embeddings. input={input}",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
+
         if aembedding == True:
             response = ollama.ollama_aembeddings(
                 model=model,
-                prompt=input,
+                prompt=ollama_input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
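
A minimal, hypothetical smoke test for the `exceptions.py` change (not part of the diff): it assumes a litellm checkout with this patch applied, and relies on the `openai.APIStatusError` base class exposing `.response` and `.status_code`. The model name is only an example.

```python
# Hypothetical check of the new default: BadRequestError can now be raised
# without an upstream httpx.Response (e.g. for pure input-validation errors).
import litellm

err = litellm.BadRequestError(
    message="Ollama Embeddings don't support batch embeddings",
    model="ollama/nomic-embed-text",  # example model name, not from the diff
    llm_provider="ollama",
)
assert err.status_code == 400
assert err.response.status_code == 400  # the mock response synthesized by the patch
```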
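The branching added to the ollama path in `litellm/main.py` can also be exercised in isolation. The sketch below is a hypothetical standalone reproduction: the helper name `normalize_ollama_input` and the plain `ValueError` stand in for the in-tree code and `litellm.BadRequestError`.

```python
from typing import List, Union


def normalize_ollama_input(input: Union[str, List]) -> str:
    """Mirrors the branching added to embedding(): Ollama embeddings take a
    single prompt string, so batch (multi-element) inputs are rejected."""
    if isinstance(input, list) and len(input) > 1:
        # litellm raises BadRequestError here; ValueError is a stand-in
        raise ValueError("Ollama Embeddings don't support batch embeddings")
    if isinstance(input, list) and len(input) == 1:
        # single-element list: join in case it holds a list of strings
        return "".join(input[0])
    elif isinstance(input, str):
        return input
    raise ValueError(f"Invalid input for ollama embeddings. input={input}")


assert normalize_ollama_input("hello") == "hello"
assert normalize_ollama_input(["hello"]) == "hello"
```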