diff --git a/litellm/__init__.py b/litellm/__init__.py
index 70e2412a95..c1b0d86d12 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -1083,6 +1083,7 @@ from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
 from .llms.deepinfra.chat.transformation import DeepInfraConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
+from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.openai.chat.o1_transformation import (
diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 57ab1ec7ef..8039dfb289 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -536,14 +536,6 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
         ) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
             api_base, api_key
         )
-    elif custom_llm_provider == "voyage":
-        # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
-        api_base = (
-            api_base
-            or get_secret_str("VOYAGE_API_BASE")
-            or "https://api.voyageai.com/v1"
-        )  # type: ignore
-        dynamic_api_key = api_key or get_secret_str("VOYAGE_API_KEY")
     elif custom_llm_provider == "together_ai":
         api_base = (
             api_base
diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py
index 4e12d5ef82..38fca0b1c3 100644
--- a/litellm/litellm_core_utils/get_supported_openai_params.py
+++ b/litellm/litellm_core_utils/get_supported_openai_params.py
@@ -34,6 +34,8 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "anthropic":
         return litellm.AnthropicConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "anthropic_text":
+        return litellm.AnthropicTextConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "fireworks_ai":
         if request_type == "embeddings":
             return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
@@ -168,10 +170,17 @@ def get_supported_openai_params(  # noqa: PLR0915
         ]
     elif custom_llm_provider == "watsonx":
         return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
-    elif custom_llm_provider == "custom_openai" or "text-completion-openai":
+    elif custom_llm_provider == "watsonx_text":
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params(model=model)
+    elif (
+        custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "text-completion-openai"
+    ):
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
     elif custom_llm_provider == "predibase":
         return litellm.PredibaseConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "voyage":
+        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
     return None
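A note on the `get_supported_openai_params` change above: besides routing `voyage` to its new embedding config, it fixes the `custom_openai` branch, whose old `== "custom_openai" or "text-completion-openai"` condition was always truthy. A quick sketch of the intended behavior after this patch (import path taken from the file above):

```python
from litellm.litellm_core_utils.get_supported_openai_params import (
    get_supported_openai_params,
)

# Voyage only advertises the OpenAI params its API can honor.
params = get_supported_openai_params(
    model="voyage-3-lite",
    custom_llm_provider="voyage",
    request_type="embeddings",
)
print(params)  # expected: ["encoding_format", "dimensions"]
```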
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index 39582d1314..b3328e048e 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -21,7 +21,7 @@ import litellm
 from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.anthropic import (
     AllAnthropicToolsValues,
     AnthropicComputerTool,
diff --git a/litellm/llms/anthropic/common_utils.py b/litellm/llms/anthropic/common_utils.py
index 8ef79f9505..409bbe2d82 100644
--- a/litellm/llms/anthropic/common_utils.py
+++ b/litellm/llms/anthropic/common_utils.py
@@ -6,7 +6,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class AnthropicError(BaseLLMException):
diff --git a/litellm/llms/anthropic/completion/transformation.py b/litellm/llms/anthropic/completion/transformation.py
index 2435d7d4ad..df8064ddf4 100644
--- a/litellm/llms/anthropic/completion/transformation.py
+++ b/litellm/llms/anthropic/completion/transformation.py
@@ -11,13 +11,16 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 import httpx
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
 )
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
diff --git a/litellm/llms/azure/chat/gpt_transformation.py b/litellm/llms/azure/chat/gpt_transformation.py
index 5af1a675aa..d770803eb6 100644
--- a/litellm/llms/azure/chat/gpt_transformation.py
+++ b/litellm/llms/azure/chat/gpt_transformation.py
@@ -7,7 +7,7 @@ import litellm
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_azure_openai_messages,
 )
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import ModelResponse
 
 from ....exceptions import UnsupportedParamsError
@@ -18,7 +18,7 @@ from ....types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
 )
-from ...base_llm.transformation import BaseConfig
+from ...base_llm.chat.transformation import BaseConfig
 from ..common_utils import AzureOpenAIError
 
 if TYPE_CHECKING:
diff --git a/litellm/llms/azure/common_utils.py b/litellm/llms/azure/common_utils.py
index f54a5499c0..dfcb3d82b9 100644
--- a/litellm/llms/azure/common_utils.py
+++ b/litellm/llms/azure/common_utils.py
@@ -3,7 +3,7 @@ from typing import Callable, Optional, Union
 import httpx
 
 from litellm._logging import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
diff --git a/litellm/llms/base_llm/transformation.py b/litellm/llms/base_llm/chat/transformation.py
similarity index 100%
rename from litellm/llms/base_llm/transformation.py
rename to litellm/llms/base_llm/chat/transformation.py
diff --git a/litellm/llms/base_llm/embedding/transformation.py b/litellm/llms/base_llm/embedding/transformation.py
new file mode 100644
index 0000000000..7b2873b6d7
--- /dev/null
+++ b/litellm/llms/base_llm/embedding/transformation.py
@@ -0,0 +1,94 @@
+import types
+from abc import ABC, abstractmethod
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    TypedDict,
+    Union,
+)
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse, ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class BaseEmbeddingConfig(BaseConfig, ABC):
+
+    @abstractmethod
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: Union[str, List[str], List[float], List[List[float]]],
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {}
+
+    @abstractmethod
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        return model_response
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        return api_base or ""
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "EmbeddingConfig does not need a request transformation for chat models"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "EmbeddingConfig does not need a response transformation for chat models"
+        )
diff --git a/litellm/llms/clarifai/chat/transformation.py b/litellm/llms/clarifai/chat/transformation.py
index fac16f7ca6..5dc22c284e 100644
--- a/litellm/llms/clarifai/chat/transformation.py
+++ b/litellm/llms/clarifai/chat/transformation.py
@@ -9,7 +9,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     convert_content_list_to_str,
 )
 from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
diff --git a/litellm/llms/clarifai/common_utils.py b/litellm/llms/clarifai/common_utils.py
index 0f249a0720..9190e8554a 100644
--- a/litellm/llms/clarifai/common_utils.py
+++ b/litellm/llms/clarifai/common_utils.py
@@ -1,6 +1,6 @@
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class ClarifaiError(BaseLLMException):
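The new `BaseEmbeddingConfig` above is the contract every embedding provider now implements. A minimal sketch of a subclass; `AcmeEmbeddingConfig`, its endpoint, and its payload shape are hypothetical, purely to illustrate the two required transforms:

```python
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.types.utils import EmbeddingResponse, Usage


class AcmeEmbeddingConfig(BaseEmbeddingConfig):
    # NOTE: a real config must also implement BaseConfig's remaining
    # abstract methods (validate_environment, get_supported_openai_params,
    # map_openai_params, get_error_class) to be instantiable.

    def get_complete_url(self, api_base, model):
        return api_base or "https://api.acme.example/v1/embeddings"

    def transform_embedding_request(self, model, input, optional_params, headers):
        # Provider-specific request body; mapped optional params ride along.
        return {"model": model, "input": input, **optional_params}

    def transform_embedding_response(
        self, model, raw_response, model_response, logging_obj,
        api_key=None, request_data={}, optional_params={}, litellm_params={},
    ):
        payload = raw_response.json()
        model_response.model = payload.get("model")
        model_response.data = payload.get("data")
        model_response.usage = Usage(
            prompt_tokens=payload.get("usage", {}).get("prompt_tokens", 0),
            total_tokens=payload.get("usage", {}).get("total_tokens", 0),
        )
        return model_response
```

The `VoyageEmbeddingConfig` later in this patch is the first real implementation of this contract.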
diff --git a/litellm/llms/cloudflare/chat/transformation.py b/litellm/llms/cloudflare/chat/transformation.py
index 4906f7b44e..596875919a 100644
--- a/litellm/llms/cloudflare/chat/transformation.py
+++ b/litellm/llms/cloudflare/chat/transformation.py
@@ -6,7 +6,7 @@ import httpx
 
 import litellm
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py
index 8002bf914f..39df1e021f 100644
--- a/litellm/llms/cohere/chat/transformation.py
+++ b/litellm/llms/cohere/chat/transformation.py
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
diff --git a/litellm/llms/cohere/common_utils.py b/litellm/llms/cohere/common_utils.py
index 6aaad2b706..11ff73efc2 100644
--- a/litellm/llms/cohere/common_utils.py
+++ b/litellm/llms/cohere/common_utils.py
@@ -1,7 +1,7 @@
 import json
 from typing import List, Optional
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py
index 2240d4ea71..61d5ca5ad3 100644
--- a/litellm/llms/cohere/completion/transformation.py
+++ b/litellm/llms/cohere/completion/transformation.py
@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index 51e1dae6b0..01043a6d9b 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -21,13 +21,15 @@ import litellm.types
 import litellm.types.utils
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
+from litellm.types.utils import EmbeddingResponse
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
 
 if TYPE_CHECKING:
@@ -403,6 +405,139 @@ class BaseLLMHTTPHandler:
 
         return completion_stream, response.headers
 
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        timeout: float,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: Optional[str],
+        optional_params: dict,
+        model_response: EmbeddingResponse,
+        api_key: Optional[str] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        aembedding: bool = False,
+        headers={},
+    ) -> EmbeddingResponse:
+
+        provider_config = ProviderConfigManager.get_provider_embedding_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+            messages=[],
+            optional_params=optional_params,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
+        data = provider_config.transform_embedding_request(
+            model=model,
+            input=input,
+            optional_params=optional_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.aembedding(  # type: ignore
+                request_data=data,
+                api_base=api_base,
+                headers=headers,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                timeout=timeout,
+                client=client,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        try:
+            response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(
+                e=e,
+                provider_config=provider_config,
+            )
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+        )
+
+    async def aembedding(
+        self,
+        request_data: dict,
+        api_base: str,
+        headers: dict,
+        model: str,
+        custom_llm_provider: str,
+        provider_config: BaseEmbeddingConfig,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> EmbeddingResponse:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
+        try:
+            response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(request_data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+        )
+
     def _handle_error(self, e: Exception, provider_config: BaseConfig):
         status_code = getattr(e, "status_code", 500)
         error_headers = getattr(e, "headers", None)
@@ -421,6 +556,3 @@ class BaseLLMHTTPHandler:
             status_code=status_code,
             headers=error_headers,
         )
-
-    def embedding(self):
-        pass
diff --git a/litellm/llms/huggingface/chat/transformation.py b/litellm/llms/huggingface/chat/transformation.py
index 8238d1be41..c1bdc9ca67 100644
--- a/litellm/llms/huggingface/chat/transformation.py
+++ b/litellm/llms/huggingface/chat/transformation.py
@@ -16,7 +16,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     prompt_factory,
 )
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
diff --git a/litellm/llms/huggingface/common_utils.py b/litellm/llms/huggingface/common_utils.py
index c63a4a0d1d..d793b29874 100644
--- a/litellm/llms/huggingface/common_utils.py
+++ b/litellm/llms/huggingface/common_utils.py
@@ -2,7 +2,7 @@ from typing import Literal, Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class HuggingfaceError(BaseLLMException):
diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py
index 10df36394b..1c7c882fa2 100644
--- a/litellm/llms/maritalk.py
+++ b/litellm/llms/maritalk.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, List, Optional, Union
 from httpx._models import Headers
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.utils import Choices, Message, ModelResponse, Usage
diff --git a/litellm/llms/nlp_cloud/chat/handler.py b/litellm/llms/nlp_cloud/chat/handler.py
index e82086ebf3..959832ab88 100644
--- a/litellm/llms/nlp_cloud/chat/handler.py
+++ b/litellm/llms/nlp_cloud/chat/handler.py
@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Union
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
diff --git a/litellm/llms/nlp_cloud/chat/transformation.py b/litellm/llms/nlp_cloud/chat/transformation.py
index e547f38d22..42bef0f4e8 100644
--- a/litellm/llms/nlp_cloud/chat/transformation.py
+++ b/litellm/llms/nlp_cloud/chat/transformation.py
@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.utils import ModelResponse, Usage
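Stepping back from the mechanical import moves for a moment: the new `BaseLLMHTTPHandler.embedding()` above (which replaces the old `pass` stub) composes four provider-config calls — resolve the config, build headers, build the URL, build the body — before POSTing. A sketch of that composition done by hand, using the Voyage config and `ProviderConfigManager` hook added later in this patch (the api_key value is illustrative):

```python
import json

import litellm
from litellm.utils import ProviderConfigManager

# Step 1 of embedding(): resolve the provider's embedding config.
cfg = ProviderConfigManager.get_provider_embedding_config(
    model="voyage-3-lite", provider=litellm.LlmProviders.VOYAGE
)

# Steps 2-4: headers, URL, and request body, as the handler composes them.
headers = cfg.validate_environment(
    api_key="sk-test", headers={}, model="voyage-3-lite",
    messages=[], optional_params={},
)
url = cfg.get_complete_url(api_base=None, model="voyage-3-lite")
body = cfg.transform_embedding_request(
    model="voyage-3-lite", input=["hello"], optional_params={}, headers=headers,
)
print(url)              # https://api.voyageai.com/v1/embeddings
print(json.dumps(body))  # {"input": ["hello"], "model": "voyage-3-lite"}
```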
diff --git a/litellm/llms/nlp_cloud/common_utils.py b/litellm/llms/nlp_cloud/common_utils.py
index 5488a2fd7a..232f56c970 100644
--- a/litellm/llms/nlp_cloud/common_utils.py
+++ b/litellm/llms/nlp_cloud/common_utils.py
@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class NLPCloudError(BaseLLMException):
diff --git a/litellm/llms/ollama/common_utils.py b/litellm/llms/ollama/common_utils.py
index 38f82ee7dc..5cf213950c 100644
--- a/litellm/llms/ollama/common_utils.py
+++ b/litellm/llms/ollama/common_utils.py
@@ -2,7 +2,7 @@ from typing import Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class OllamaError(BaseLLMException):
diff --git a/litellm/llms/ollama/completion/transformation.py b/litellm/llms/ollama/completion/transformation.py
index c77fe7f028..3ba3d29587 100644
--- a/litellm/llms/ollama/completion/transformation.py
+++ b/litellm/llms/ollama/completion/transformation.py
@@ -13,7 +13,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     ollama_pt,
 )
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
diff --git a/litellm/llms/oobabooga/chat/transformation.py b/litellm/llms/oobabooga/chat/transformation.py
index 6780991bea..79ccca840c 100644
--- a/litellm/llms/oobabooga/chat/transformation.py
+++ b/litellm/llms/oobabooga/chat/transformation.py
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
diff --git a/litellm/llms/oobabooga/common_utils.py b/litellm/llms/oobabooga/common_utils.py
index 3612fed407..82f8cda951 100644
--- a/litellm/llms/oobabooga/common_utils.py
+++ b/litellm/llms/oobabooga/common_utils.py
@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class OobaboogaError(BaseLLMException):
diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py
index 87b66ddc69..c6e63edb8c 100644
--- a/litellm/llms/openai/chat/gpt_transformation.py
+++ b/litellm/llms/openai/chat/gpt_transformation.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
 from litellm.types.utils import ModelResponse
diff --git a/litellm/llms/openai/common_utils.py b/litellm/llms/openai/common_utils.py
index e5b926f6aa..87857f7ced 100644
--- a/litellm/llms/openai/common_utils.py
+++ b/litellm/llms/openai/common_utils.py
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 import openai
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class OpenAIError(BaseLLMException):
diff --git a/litellm/llms/openai/completion/transformation.py b/litellm/llms/openai/completion/transformation.py
index e7ff85d557..85a9115c74 100644
--- a/litellm/llms/openai/completion/transformation.py
+++ b/litellm/llms/openai/completion/transformation.py
@@ -6,7 +6,10 @@ import types
 from typing import List, Optional, Union, cast
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.types.llms.openai import (
     AllMessageValues,
     AllPromptValues,
@@ -14,7 +17,6 @@ from litellm.types.llms.openai import (
 )
 from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
 
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
 from ..chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import OpenAIError
 from .utils import is_tokens_or_list_of_tokens
diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py
index b2d14a3187..ffac461f38 100644
--- a/litellm/llms/openai/openai.py
+++ b/litellm/llms/openai/openai.py
@@ -32,7 +32,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import (
diff --git a/litellm/llms/petals/common_utils.py b/litellm/llms/petals/common_utils.py
index 9df4bad8eb..bffee338f2 100644
--- a/litellm/llms/petals/common_utils.py
+++ b/litellm/llms/petals/common_utils.py
@@ -2,7 +2,7 @@ from typing import Union
 
 from httpx import Headers
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class PetalsError(BaseLLMException):
diff --git a/litellm/llms/petals/completion/transformation.py b/litellm/llms/petals/completion/transformation.py
index 52b8cd178d..76b7df7235 100644
--- a/litellm/llms/petals/completion/transformation.py
+++ b/litellm/llms/petals/completion/transformation.py
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, Union
 from httpx import Headers, Response
 
 import litellm
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
diff --git a/litellm/llms/predibase/chat/transformation.py b/litellm/llms/predibase/chat/transformation.py
index c6b9451dd3..016b9e700f 100644
--- a/litellm/llms/predibase/chat/transformation.py
+++ b/litellm/llms/predibase/chat/transformation.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
 
 from httpx import Headers, Response
 
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse
diff --git a/litellm/llms/predibase/common_utils.py b/litellm/llms/predibase/common_utils.py
index f1506ce219..2dad586120 100644
--- a/litellm/llms/predibase/common_utils.py
+++ b/litellm/llms/predibase/common_utils.py
@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class PredibaseError(BaseLLMException):
diff --git a/litellm/llms/replicate/chat/transformation.py b/litellm/llms/replicate/chat/transformation.py
index 184a2cb809..b4d8b008d5 100644
--- a/litellm/llms/replicate/chat/transformation.py
+++ b/litellm/llms/replicate/chat/transformation.py
@@ -4,9 +4,14 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
 from litellm.utils import token_counter
diff --git a/litellm/llms/replicate/common_utils.py b/litellm/llms/replicate/common_utils.py
index 98a5936ccf..c52b47a46a 100644
--- a/litellm/llms/replicate/common_utils.py
+++ b/litellm/llms/replicate/common_utils.py
@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class ReplicateError(BaseLLMException):
diff --git a/litellm/llms/sagemaker/chat/transformation.py b/litellm/llms/sagemaker/chat/transformation.py
index f5df6e279d..42c7e0d5fc 100644
--- a/litellm/llms/sagemaker/chat/transformation.py
+++ b/litellm/llms/sagemaker/chat/transformation.py
@@ -11,7 +11,7 @@ from typing import Union
 
 from httpx._models import Headers
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import SagemakerError
diff --git a/litellm/llms/sagemaker/common_utils.py b/litellm/llms/sagemaker/common_utils.py
index 8fa450a8d5..49e4989ff1 100644
--- a/litellm/llms/sagemaker/common_utils.py
+++ b/litellm/llms/sagemaker/common_utils.py
@@ -4,7 +4,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
 import httpx
 
 from litellm import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import StreamingChatCompletionChunk
diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py
index 0a91819b7b..6e4d2ac9c5 100644
--- a/litellm/llms/sagemaker/completion/transformation.py
+++ b/litellm/llms/sagemaker/completion/transformation.py
@@ -17,7 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
diff --git a/litellm/llms/triton/common_utils.py b/litellm/llms/triton/common_utils.py
index 64ce011b95..d5372eee00 100644
--- a/litellm/llms/triton/common_utils.py
+++ b/litellm/llms/triton/common_utils.py
@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class TritonError(BaseLLMException):
diff --git a/litellm/llms/triton/completion/transformation.py b/litellm/llms/triton/completion/transformation.py
index fafdc027ea..381af4953a 100644
--- a/litellm/llms/triton/completion/transformation.py
+++ b/litellm/llms/triton/completion/transformation.py
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
 from httpx import Headers, Response
 
 from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
diff --git a/litellm/llms/vertex_ai/common_utils.py b/litellm/llms/vertex_ai/common_utils.py
index 5fef37d313..a412a1f0db 100644
--- a/litellm/llms/vertex_ai/common_utils.py
+++ b/litellm/llms/vertex_ai/common_utils.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 import httpx
 
 from litellm import supports_response_schema, supports_system_messages, verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.vertex_ai import PartType
diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
index 8c534a4fa7..c75cff1430 100644
--- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -34,7 +34,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_generic_image_chunk_to_openai_image_obj,
     convert_to_anthropic_image_obj,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
diff --git a/litellm/llms/voyage/embedding/transformation.py b/litellm/llms/voyage/embedding/transformation.py
new file mode 100644
index 0000000000..6d4fb89ddc
--- /dev/null
+++ b/litellm/llms/voyage/embedding/transformation.py
@@ -0,0 +1,141 @@
+import json
+from typing import Any, List, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse, ModelResponse, Usage
+
+
+class VoyageError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Union[dict, httpx.Headers] = {},
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
url="https://api.voyageai.com/v1/embeddings" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + status_code=status_code, + message=message, + headers=headers, + ) + + +class VoyageEmbeddingConfig(BaseEmbeddingConfig): + """ + Reference: https://docs.voyageai.com/reference/embeddings-api + """ + + def __init__(self) -> None: + pass + + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + if api_base: + if not api_base.endswith("/embeddings"): + api_base = f"{api_base}/embeddings" + return api_base + return "https://api.voyageai.com/v1/embeddings" + + def get_supported_openai_params(self, model: str) -> list: + return [ + "encoding_format", + "dimensions", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + """ + Map OpenAI params to Voyage params + + Reference: https://docs.voyageai.com/reference/embeddings-api + """ + if "encoding_format" in non_default_params: + optional_params["encoding_format"] = non_default_params["encoding_format"] + if "dimensions" in non_default_params: + optional_params["output_dimension"] = non_default_params["dimensions"] + return optional_params + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + ) -> dict: + if api_key is None: + api_key = ( + get_secret_str("VOYAGE_API_KEY") + or get_secret_str("VOYAGE_AI_API_KEY") + or get_secret_str("VOYAGE_AI_TOKEN") + ) + return { + "Authorization": f"Bearer {api_key}", + } + + def transform_embedding_request( + self, + model: str, + input: Union[str, List[str], List[float], List[List[float]]], + optional_params: dict, + headers: dict, + ) -> dict: + return { + "input": input, + "model": model, + **optional_params, + } + + def transform_embedding_response( + self, + model: str, + raw_response: httpx.Response, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str] = None, + request_data: dict = {}, + optional_params: dict = {}, + litellm_params: dict = {}, + ) -> EmbeddingResponse: + try: + raw_response_json = raw_response.json() + except Exception: + raise VoyageError( + message=raw_response.text, status_code=raw_response.status_code + ) + + # model_response.usage + model_response.model = raw_response_json.get("model") + model_response.data = raw_response_json.get("data") + model_response.object = raw_response_json.get("object") + + usage = Usage( + prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0), + total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0), + ) + model_response.usage = usage + return model_response + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return VoyageError( + message=error_message, status_code=status_code, headers=headers + ) diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py index b270f2d82b..82881fe796 100644 --- a/litellm/llms/watsonx/common_utils.py +++ b/litellm/llms/watsonx/common_utils.py @@ -5,7 +5,7 @@ import httpx import litellm from litellm import verbose_logger from litellm.caching import InMemoryCache -from litellm.llms.base_llm.transformation import BaseLLMException +from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.secret_managers.main import get_secret_str from 
diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py
index b270f2d82b..82881fe796 100644
--- a/litellm/llms/watsonx/common_utils.py
+++ b/litellm/llms/watsonx/common_utils.py
@@ -5,7 +5,7 @@ import httpx
 import litellm
 from litellm import verbose_logger
 from litellm.caching import InMemoryCache
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAPIParams
diff --git a/litellm/llms/watsonx/completion/transformation.py b/litellm/llms/watsonx/completion/transformation.py
index 9e68be930e..566b6ad2ce 100644
--- a/litellm/llms/watsonx/completion/transformation.py
+++ b/litellm/llms/watsonx/completion/transformation.py
@@ -24,7 +24,8 @@ from typing import (
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
@@ -35,8 +36,7 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
 
 from ...base import BaseLLM
-from ...base_llm.transformation import BaseConfig
-from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from ...base_llm.chat.transformation import BaseConfig
 from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
 
 if TYPE_CHECKING:
diff --git a/litellm/main.py b/litellm/main.py
index 6a80a48452..ba7e1303ac 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3058,6 +3058,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "together_ai"
         or custom_llm_provider == "openai_like"
         or custom_llm_provider == "jina_ai"
+        or custom_llm_provider == "voyage"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3632,10 +3633,10 @@ def embedding(  # noqa: PLR0915
             aembedding=aembedding,
         )
     elif custom_llm_provider == "voyage":
-        api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
-        response = openai_chat_completions.embedding(
+        response = base_llm_http_handler.embedding(
             model=model,
             input=input,
+            custom_llm_provider=custom_llm_provider,
             api_base=api_base,
             api_key=api_key,
             logging_obj=logging,
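With the dispatch above, Voyage embeddings flow through `base_llm_http_handler` end to end, sync and async (the `voyage` branch in `aembedding` is now awaited directly). A usage sketch of what this enables — assumes `VOYAGE_API_KEY` is set in the environment and makes live API calls:

```python
import asyncio

import litellm

# Sync: dimensions/encoding_format are translated by VoyageEmbeddingConfig
# via the get_optional_params_embeddings branch added in utils.py below.
resp = litellm.embedding(
    model="voyage/voyage-3-lite",
    input=["hello world"],
    dimensions=512,
)
print(resp.usage.total_tokens)


# Async: same call through litellm.aembedding.
async def main():
    return await litellm.aembedding(
        model="voyage/voyage-3-lite",
        input=["hello world"],
    )

print(asyncio.run(main()).model)
```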
diff --git a/litellm/utils.py b/litellm/utils.py
index 256329fabb..8baafe21ed 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -168,7 +168,8 @@ from typing import (
 
 from openai import OpenAIError as OriginalError
 
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 
 from ._logging import verbose_logger
 from .caching.caching import (
@@ -2377,6 +2378,21 @@ def get_optional_params_embeddings(  # noqa: PLR0915
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    elif custom_llm_provider == "voyage":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="voyage",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params={},
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
             model=model,
@@ -6200,6 +6216,15 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         return litellm.OpenAIGPTConfig()
 
+    @staticmethod
+    def get_provider_embedding_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> BaseEmbeddingConfig:
+        if litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageEmbeddingConfig()
+        raise ValueError(f"Provider {provider} does not support embedding config")
+
 
 def get_end_user_id_for_cost_tracking(
     litellm_params: dict,
diff --git a/tests/llm_translation/base_embedding_unit_tests.py b/tests/llm_translation/base_embedding_unit_tests.py
index 94edeccdf3..b06ca31e68 100644
--- a/tests/llm_translation/base_embedding_unit_tests.py
+++ b/tests/llm_translation/base_embedding_unit_tests.py
@@ -59,6 +59,10 @@ class BaseLLMEmbeddingTest(ABC):
 
         print("async embedding response: ", response)
 
+        from openai.types.create_embedding_response import CreateEmbeddingResponse
+
+        CreateEmbeddingResponse.model_validate(response.model_dump())
+
     def test_embedding_optional_params_max_retries(self):
         embedding_call_args = self.get_base_embedding_call_args()
         optional_params = get_optional_params_embeddings(
diff --git a/tests/llm_translation/test_voyage_ai.py b/tests/llm_translation/test_voyage_ai.py
new file mode 100644
index 0000000000..ca49d53c39
--- /dev/null
+++ b/tests/llm_translation/test_voyage_ai.py
@@ -0,0 +1,56 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+
+from base_embedding_unit_tests import BaseLLMEmbeddingTest
+import litellm
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from unittest.mock import patch, MagicMock
+
+
+class TestVoyageAI(BaseLLMEmbeddingTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.VOYAGE
+
+    def get_base_embedding_call_args(self) -> dict:
+        return {
+            "model": "voyage/voyage-3-lite",
+        }
+
+
+def test_voyage_ai_embedding_extra_params():
+    try:
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
+
+        with patch.object(client, "post") as mock_client:
+            response = litellm.embedding(
+                model="voyage/voyage-3-lite",
+                input=["a"],
+                dimensions=512,
+                input_type="document",
+                client=client,
+            )
+
+            mock_client.assert_called_once()
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+
+            print("request data to voyage ai", json.dumps(json_data, indent=4))
+
+            # Assert the request parameters
+            assert json_data["input"] == ["a"]
+            assert json_data["model"] == "voyage-3-lite"
+            assert json_data["output_dimension"] == 512
+            assert json_data["input_type"] == "document"
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
diff --git a/tests/local_testing/test_config.py b/tests/local_testing/test_config.py
index 213f5095ea..88ea633df7 100644
--- a/tests/local_testing/test_config.py
+++ b/tests/local_testing/test_config.py
@@ -348,7 +348,7 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
 
 from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
 from litellm.utils import ProviderConfigManager
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 
 
 def _check_provider_config(config: BaseConfig, provider: LlmProviders):
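The `base_embedding_unit_tests.py` addition above asserts that every embedding integration round-trips through OpenAI's embedding response schema. The same invariant, checked by hand against the new provider — a hedged sketch (live call; requires the `openai` package and a valid `VOYAGE_API_KEY`):

```python
import litellm
from openai.types.create_embedding_response import CreateEmbeddingResponse

resp = litellm.embedding(model="voyage/voyage-3-lite", input=["hello"])

# The test suite's invariant: a Voyage EmbeddingResponse validates as an
# OpenAI CreateEmbeddingResponse.
CreateEmbeddingResponse.model_validate(resp.model_dump())
```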