(fix) unable to pass input_type parameter to Voyage AI embedding mode (#7276)

* VoyageEmbeddingConfig

* fix voyage logic to get params

* add voyage embedding transformation

* add get_provider_embedding_config

* use BaseEmbeddingConfig

* voyage clean up

* use llm http handler for embedding transformations

* test_voyage_ai_embedding_extra_params

* add voyage async

* test_voyage_ai_embedding_extra_params

* add async for llm http handler

* update BaseLLMEmbeddingTest

* test_voyage_ai_embedding_extra_params

* fix linting

* fix get_provider_embedding_config

* fix anthropic text test

* update location of base/chat/transformation

* fix import path

* fix IBMWatsonXAIConfig
Ishaan Jaff 2024-12-17 19:23:49 -08:00 committed by GitHub
parent 63172e67f2
commit c7b288ce30
52 changed files with 535 additions and 66 deletions
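
Usage note: with this change, provider-specific embedding params such as input_type are forwarded to the Voyage AI /embeddings endpoint, and the OpenAI-style dimensions param is mapped to Voyage's output_dimension. A minimal usage sketch based on the test added in this commit (assumes VOYAGE_API_KEY is set in the environment; the input text and the printed field are illustrative):

import litellm

# Requires VOYAGE_API_KEY in the environment; this issues a real API call.
# "dimensions" is translated to Voyage's "output_dimension"; "input_type" is
# passed through to the request body unchanged.
response = litellm.embedding(
    model="voyage/voyage-3-lite",
    input=["hello world"],
    dimensions=512,
    input_type="document",
)
print(response.usage)  # token usage reported by Voyage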

View file

@@ -1083,6 +1083,7 @@ from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
from .llms.deepinfra.chat.transformation import DeepInfraConfig
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
+from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.openai.chat.o1_transformation import (

View file

@@ -536,14 +536,6 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
        ) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
            api_base, api_key
        )
-    elif custom_llm_provider == "voyage":
-        # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
-        api_base = (
-            api_base
-            or get_secret_str("VOYAGE_API_BASE")
-            or "https://api.voyageai.com/v1"
-        )  # type: ignore
-        dynamic_api_key = api_key or get_secret_str("VOYAGE_API_KEY")
    elif custom_llm_provider == "together_ai":
        api_base = (
            api_base

View file

@@ -34,6 +34,8 @@ def get_supported_openai_params( # noqa: PLR0915
        return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "anthropic":
        return litellm.AnthropicConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "anthropic_text":
+        return litellm.AnthropicTextConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "fireworks_ai":
        if request_type == "embeddings":
            return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
@@ -168,10 +170,17 @@
        ]
    elif custom_llm_provider == "watsonx":
        return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
-    elif custom_llm_provider == "custom_openai" or "text-completion-openai":
+    elif custom_llm_provider == "watsonx_text":
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params(model=model)
+    elif (
+        custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "text-completion-openai"
+    ):
        return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
            model=model
        )
    elif custom_llm_provider == "predibase":
        return litellm.PredibaseConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "voyage":
+        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
    return None

View file

@@ -21,7 +21,7 @@ import litellm
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.anthropic import (
    AllAnthropicToolsValues,
    AnthropicComputerTool,

View file

@@ -6,7 +6,7 @@ from typing import Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class AnthropicError(BaseLLMException):

View file

@@ -11,13 +11,16 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import httpx
import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,

View file

@@ -7,7 +7,7 @@ import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
    convert_to_azure_openai_messages,
)
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import ModelResponse
from ....exceptions import UnsupportedParamsError
@@ -18,7 +18,7 @@ from ....types.llms.openai import (
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
)
-from ...base_llm.transformation import BaseConfig
+from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import AzureOpenAIError
if TYPE_CHECKING:

View file

@@ -3,7 +3,7 @@ from typing import Callable, Optional, Union
import httpx
from litellm._logging import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str

View file

@@ -0,0 +1,94 @@
import types
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    TypedDict,
    Union,
)

import httpx

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse, ModelResponse

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class BaseEmbeddingConfig(BaseConfig, ABC):
    @abstractmethod
    def transform_embedding_request(
        self,
        model: str,
        input: Union[str, List[str], List[float], List[List[float]]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {}

    @abstractmethod
    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        return model_response

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError(
            "EmbeddingConfig does not need a request transformation for chat models"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "EmbeddingConfig does not need a response transformation for chat models"
        )

View file

@@ -9,7 +9,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,

View file

@@ -1,6 +1,6 @@
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class ClarifaiError(BaseLLMException):

View file

@@ -6,7 +6,7 @@ import httpx
import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,

View file

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

View file

@@ -1,7 +1,7 @@
import json
from typing import List, Optional
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,

View file

@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,

View file

@@ -21,13 +21,15 @@ import litellm.types
import litellm.types.utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
+from litellm.types.utils import EmbeddingResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
@@ -403,6 +405,139 @@ class BaseLLMHTTPHandler:
        return completion_stream, response.headers
+
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        timeout: float,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: Optional[str],
+        optional_params: dict,
+        model_response: EmbeddingResponse,
+        api_key: Optional[str] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        aembedding: bool = False,
+        headers={},
+    ) -> EmbeddingResponse:
+        provider_config = ProviderConfigManager.get_provider_embedding_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+            messages=[],
+            optional_params=optional_params,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
+        data = provider_config.transform_embedding_request(
+            model=model,
+            input=input,
+            optional_params=optional_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.aembedding(  # type: ignore
+                request_data=data,
+                api_base=api_base,
+                headers=headers,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                timeout=timeout,
+                client=client,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        try:
+            response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(
+                e=e,
+                provider_config=provider_config,
+            )
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+        )
+
+    async def aembedding(
+        self,
+        request_data: dict,
+        api_base: str,
+        headers: dict,
+        model: str,
+        custom_llm_provider: str,
+        provider_config: BaseEmbeddingConfig,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> EmbeddingResponse:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
+        try:
+            response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(request_data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+        )
+
    def _handle_error(self, e: Exception, provider_config: BaseConfig):
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
@@ -421,6 +556,3 @@ class BaseLLMHTTPHandler:
            status_code=status_code,
            headers=error_headers,
        )
-
-    def embedding(self):
-        pass

View file

@@ -16,7 +16,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
    prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Literal, Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class HuggingfaceError(BaseLLMException):

View file

@@ -9,7 +9,7 @@ from typing import Any, Callable, List, Optional, Union
from httpx._models import Headers
import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.utils import Choices, Message, ModelResponse, Usage

View file

@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Union
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,

View file

@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class NLPCloudError(BaseLLMException):

View file

@@ -2,7 +2,7 @@ from typing import Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class OllamaError(BaseLLMException):

View file

@@ -13,7 +13,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
    ollama_pt,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,

View file

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class OobaboogaError(BaseLLMException):

View file

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.types.utils import ModelResponse

View file

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Union
import httpx
import openai
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class OpenAIError(BaseLLMException):

View file

@@ -6,7 +6,10 @@ import types
from typing import List, Optional, Union, cast
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import (
    AllMessageValues,
    AllPromptValues,
@@ -14,7 +17,6 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from ..chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import OpenAIError
from .utils import is_tokens_or_list_of_tokens

View file

@@ -32,7 +32,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import (

View file

@@ -2,7 +2,7 @@ from typing import Union
from httpx import Headers
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class PetalsError(BaseLLMException):

View file

@@ -4,7 +4,7 @@ from typing import Any, List, Optional, Union
from httpx import Headers, Response
import litellm
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,

View file

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
from httpx import Headers, Response
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class PredibaseError(BaseLLMException):

View file

@@ -4,9 +4,14 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class ReplicateError(BaseLLMException):

View file

@@ -11,7 +11,7 @@ from typing import Union
from httpx._models import Headers
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import SagemakerError

View file

@@ -4,7 +4,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
import httpx
from litellm import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk

View file

@@ -17,7 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
class TritonError(BaseLLMException):

View file

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
from httpx import Headers, Response
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,

View file

@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
from litellm import supports_response_schema, supports_system_messages, verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType

View file

@@ -34,7 +34,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
    convert_generic_image_chunk_to_openai_image_obj,
    convert_to_anthropic_image_obj,
)
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,

View file

@@ -0,0 +1,141 @@
import json
from typing import Any, List, Optional, Tuple, Union

import httpx

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse, ModelResponse, Usage


class VoyageError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Union[dict, httpx.Headers] = {},
    ):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://api.voyageai.com/v1/embeddings"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
        )


class VoyageEmbeddingConfig(BaseEmbeddingConfig):
    """
    Reference: https://docs.voyageai.com/reference/embeddings-api
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        if api_base:
            if not api_base.endswith("/embeddings"):
                api_base = f"{api_base}/embeddings"
            return api_base
        return "https://api.voyageai.com/v1/embeddings"

    def get_supported_openai_params(self, model: str) -> list:
        return [
            "encoding_format",
            "dimensions",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Voyage params

        Reference: https://docs.voyageai.com/reference/embeddings-api
        """
        if "encoding_format" in non_default_params:
            optional_params["encoding_format"] = non_default_params["encoding_format"]
        if "dimensions" in non_default_params:
            optional_params["output_dimension"] = non_default_params["dimensions"]
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            api_key = (
                get_secret_str("VOYAGE_API_KEY")
                or get_secret_str("VOYAGE_AI_API_KEY")
                or get_secret_str("VOYAGE_AI_TOKEN")
            )
        return {
            "Authorization": f"Bearer {api_key}",
        }

    def transform_embedding_request(
        self,
        model: str,
        input: Union[str, List[str], List[float], List[List[float]]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {
            "input": input,
            "model": model,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise VoyageError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        # model_response.usage
        model_response.model = raw_response_json.get("model")
        model_response.data = raw_response_json.get("data")
        model_response.object = raw_response_json.get("object")

        usage = Usage(
            prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
            total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
        )
        model_response.usage = usage
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return VoyageError(
            message=error_message, status_code=status_code, headers=headers
        )

View file

@@ -5,7 +5,7 @@ import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams

View file

@@ -24,7 +24,8 @@ from typing import (
import httpx
import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
@@ -35,8 +36,7 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM
-from ...base_llm.transformation import BaseConfig
-from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
if TYPE_CHECKING:

View file

@@ -3058,6 +3058,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
        or custom_llm_provider == "together_ai"
        or custom_llm_provider == "openai_like"
        or custom_llm_provider == "jina_ai"
+        or custom_llm_provider == "voyage"
    ):  # currently implemented aiohttp calls for just azure and openai, soon all.
        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
@@ -3632,10 +3633,10 @@ def embedding( # noqa: PLR0915
            aembedding=aembedding,
        )
    elif custom_llm_provider == "voyage":
-        api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
-        response = openai_chat_completions.embedding(
+        response = base_llm_http_handler.embedding(
            model=model,
            input=input,
+            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
            logging_obj=logging,

View file

@@ -168,7 +168,8 @@ from typing import (
from openai import OpenAIError as OriginalError
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from ._logging import verbose_logger
from .caching.caching import (
@@ -2377,6 +2378,21 @@ def get_optional_params_embeddings( # noqa: PLR0915
        )
        final_params = {**optional_params, **kwargs}
        return final_params
+    elif custom_llm_provider == "voyage":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="voyage",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params={},
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
    elif custom_llm_provider == "fireworks_ai":
        supported_params = get_supported_openai_params(
            model=model,
@@ -6200,6 +6216,15 @@ class ProviderConfigManager:
            return litellm.PetalsConfig()
        return litellm.OpenAIGPTConfig()

+    @staticmethod
+    def get_provider_embedding_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> BaseEmbeddingConfig:
+        if litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageEmbeddingConfig()
+        raise ValueError(f"Provider {provider} does not support embedding config")
+
def get_end_user_id_for_cost_tracking(
    litellm_params: dict,

View file

@@ -59,6 +59,10 @@ class BaseLLMEmbeddingTest(ABC):
        print("async embedding response: ", response)
+
+        from openai.types.create_embedding_response import CreateEmbeddingResponse
+
+        CreateEmbeddingResponse.model_validate(response.model_dump())
    def test_embedding_optional_params_max_retries(self):
        embedding_call_args = self.get_base_embedding_call_args()
        optional_params = get_optional_params_embeddings(

View file

@@ -0,0 +1,56 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

from base_embedding_unit_tests import BaseLLMEmbeddingTest
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock


class TestVoyageAI(BaseLLMEmbeddingTest):
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        return litellm.LlmProviders.VOYAGE

    def get_base_embedding_call_args(self) -> dict:
        return {
            "model": "voyage/voyage-3-lite",
        }


def test_voyage_ai_embedding_extra_params():
    try:
        client = HTTPHandler()
        litellm.set_verbose = True

        with patch.object(client, "post") as mock_client:
            response = litellm.embedding(
                model="voyage/voyage-3-lite",
                input=["a"],
                dimensions=512,
                input_type="document",
                client=client,
            )

            mock_client.assert_called_once()
            json_data = json.loads(mock_client.call_args.kwargs["data"])

            print("request data to voyage ai", json.dumps(json_data, indent=4))

            # Assert the request parameters
            assert json_data["input"] == ["a"]
            assert json_data["model"] == "voyage-3-lite"
            assert json_data["output_dimension"] == 512
            assert json_data["input_type"] == "document"

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -348,7 +348,7 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
from litellm.utils import ProviderConfigManager
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
def _check_provider_config(config: BaseConfig, provider: LlmProviders):