From a1aeb7b404bb02e80ef71871b36d59f05f3217bc Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 16 Feb 2024 09:56:59 -0800
Subject: [PATCH] fix(main.py): map list input to ollama prompt input format

---
 litellm/exceptions.py | 11 ++++++++++-
 litellm/main.py       | 20 +++++++++++++++++++-
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/litellm/exceptions.py b/litellm/exceptions.py
index 09b3758112..a7bf394f6d 100644
--- a/litellm/exceptions.py
+++ b/litellm/exceptions.py
@@ -24,6 +24,7 @@ from openai import (
     PermissionDeniedError,
 )
 import httpx
+from typing import Optional
 
 
 class AuthenticationError(AuthenticationError):  # type: ignore
@@ -50,11 +51,19 @@ class NotFoundError(NotFoundError):  # type: ignore
 
 
 class BadRequestError(BadRequestError):  # type: ignore
-    def __init__(self, message, model, llm_provider, response: httpx.Response):
+    def __init__(
+        self, message, model, llm_provider, response: Optional[httpx.Response] = None
+    ):
         self.status_code = 400
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
+        response = response or httpx.Response(
+            status_code=self.status_code,
+            request=httpx.Request(
+                method="GET", url="https://litellm.ai"
+            ),  # mock request object
+        )
         super().__init__(
             self.message, response=response, body=None
         )  # Call the base class constructor with the parameters it needs
diff --git a/litellm/main.py b/litellm/main.py
index 93ea3c6441..2539039cd7 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2590,10 +2590,28 @@ def embedding(
             model_response=EmbeddingResponse(),
         )
     elif custom_llm_provider == "ollama":
+        ollama_input = None
+        if isinstance(input, list) and len(input) > 1:
+            raise litellm.BadRequestError(
+                message=f"Ollama Embeddings don't support batch embeddings",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
+        if isinstance(input, list) and len(input) == 1:
+            ollama_input = "".join(input[0])
+        elif isinstance(input, str):
+            ollama_input = input
+        else:
+            raise litellm.BadRequestError(
+                message=f"Invalid input for ollama embeddings. input={input}",
+                model=model,  # type: ignore
+                llm_provider="ollama",  # type: ignore
+            )
+
         if aembedding == True:
             response = ollama.ollama_aembeddings(
                 model=model,
-                prompt=input,
+                prompt=ollama_input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
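
Illustrative usage sketch (not part of the patch): per the main.py hunk above, the
ollama embedding path now accepts a single-item list and maps it to the string prompt
that ollama_aembeddings expects, while multi-item lists raise litellm.BadRequestError,
which the exceptions.py hunk allows to be raised without an httpx.Response. The model
name below is hypothetical.

    import litellm

    # single-item list -> mapped to ollama_input = "hello world"
    resp = litellm.embedding(
        model="ollama/nomic-embed-text", input=["hello world"]
    )

    # multi-item list -> raises litellm.BadRequestError
    # ("Ollama Embeddings don't support batch embeddings")
    litellm.embedding(model="ollama/nomic-embed-text", input=["a", "b"])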