(fix) unable to pass input_type parameter to Voyage AI embedding model (#7276)

* VoyageEmbeddingConfig

* fix voyage logic to get params

* add voyage embedding transformation

* add get_provider_embedding_config

* use BaseEmbeddingConfig

* voyage clean up

* use llm http handler for embedding transformations

* test_voyage_ai_embedding_extra_params

* add voyage async

* test_voyage_ai_embedding_extra_params

* add async for llm http handler

* update BaseLLMEmbeddingTest

* test_voyage_ai_embedding_extra_params

* fix linting

* fix get_provider_embedding_config

* fix anthropic text test

* update location of base/chat/transformation

* fix import path

* fix IBMWatsonXAIConfig
Ishaan Jaff 2024-12-17 19:23:49 -08:00 committed by GitHub
parent 63172e67f2
commit c7b288ce30
52 changed files with 535 additions and 66 deletions
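For context, a minimal usage sketch of the behavior this commit enables (mirroring the new test_voyage_ai_embedding_extra_params added below): Voyage-specific parameters such as `input_type` are now forwarded to the Voyage AI /embeddings request body.

```python
import litellm

# Sketch based on the unit test in this commit: `dimensions` is remapped to
# Voyage's `output_dimension`, and `input_type` is passed through unchanged.
response = litellm.embedding(
    model="voyage/voyage-3-lite",
    input=["a"],
    dimensions=512,
    input_type="document",
)
```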

View file

@@ -1083,6 +1083,7 @@ from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
from .llms.deepinfra.chat.transformation import DeepInfraConfig
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from .llms.groq.chat.transformation import GroqChatConfig
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.openai.chat.o1_transformation import (

View file

@@ -536,14 +536,6 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
api_base, api_key
)
elif custom_llm_provider == "voyage":
# voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
api_base = (
api_base
or get_secret_str("VOYAGE_API_BASE")
or "https://api.voyageai.com/v1"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("VOYAGE_API_KEY")
elif custom_llm_provider == "together_ai":
api_base = (
api_base

View file

@@ -34,6 +34,8 @@ def get_supported_openai_params( # noqa: PLR0915
return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "anthropic":
return litellm.AnthropicConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "anthropic_text":
return litellm.AnthropicTextConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "fireworks_ai":
if request_type == "embeddings":
return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
@@ -168,10 +170,17 @@
]
elif custom_llm_provider == "watsonx":
return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "custom_openai" or "text-completion-openai":
elif custom_llm_provider == "watsonx_text":
return litellm.IBMWatsonXAIConfig().get_supported_openai_params(model=model)
elif (
custom_llm_provider == "custom_openai"
or custom_llm_provider == "text-completion-openai"
):
return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
model=model
)
elif custom_llm_provider == "predibase":
return litellm.PredibaseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "voyage":
return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
return None

View file

@@ -21,7 +21,7 @@ import litellm
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.anthropic import (
AllAnthropicToolsValues,
AnthropicComputerTool,

View file

@@ -6,7 +6,7 @@ from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class AnthropicError(BaseLLMException):

View file

@@ -11,13 +11,16 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.transformation import (
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,

View file

@@ -7,7 +7,7 @@ import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_to_azure_openai_messages,
)
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import ModelResponse
from ....exceptions import UnsupportedParamsError
@@ -18,7 +18,7 @@ from ....types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from ...base_llm.transformation import BaseConfig
from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import AzureOpenAIError
if TYPE_CHECKING:

View file

@@ -3,7 +3,7 @@ from typing import Callable, Optional, Union
import httpx
from litellm._logging import verbose_logger
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str

View file

@@ -0,0 +1,94 @@
import types
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
TypedDict,
Union,
)
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseEmbeddingConfig(BaseConfig, ABC):
@abstractmethod
def transform_embedding_request(
self,
model: str,
input: Union[str, List[str], List[float], List[List[float]]],
optional_params: dict,
headers: dict,
) -> dict:
return {}
@abstractmethod
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> EmbeddingResponse:
return model_response
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base or ""
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
raise NotImplementedError(
"EmbeddingConfig does not need a request transformation for chat models"
)
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"EmbeddingConfig does not need a response transformation for chat models"
)

View file

@@ -9,7 +9,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,

View file

@@ -1,6 +1,6 @@
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class ClarifaiError(BaseLLMException):

View file

@@ -6,7 +6,7 @@ import httpx
import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.transformation import (
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,

View file

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

View file

@@ -1,7 +1,7 @@
import json
from typing import List, Optional
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,

View file

@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,

View file

@@ -21,13 +21,15 @@ import litellm.types
import litellm.types.utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
@@ -403,6 +405,139 @@ class BaseLLMHTTPHandler:
return completion_stream, response.headers
def embedding(
self,
model: str,
input: list,
timeout: float,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: Optional[str],
optional_params: dict,
model_response: EmbeddingResponse,
api_key: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
aembedding: bool = False,
headers={},
) -> EmbeddingResponse:
provider_config = ProviderConfigManager.get_provider_embedding_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=[],
optional_params=optional_params,
)
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
)
data = provider_config.transform_embedding_request(
model=model,
input=input,
optional_params=optional_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
if aembedding is True:
return self.aembedding( # type: ignore
request_data=data,
api_base=api_base,
headers=headers,
model=model,
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
timeout=timeout,
client=client,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_embedding_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
)
async def aembedding(
self,
request_data: dict,
api_base: str,
headers: dict,
model: str,
custom_llm_provider: str,
provider_config: BaseEmbeddingConfig,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> EmbeddingResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
)
else:
async_httpx_client = client
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(request_data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
return provider_config.transform_embedding_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=request_data,
)
def _handle_error(self, e: Exception, provider_config: BaseConfig):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
@@ -421,6 +556,3 @@
status_code=status_code,
headers=error_headers,
)
def embedding(self):
pass

View file

@@ -16,7 +16,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
prompt_factory,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Literal, Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class HuggingfaceError(BaseLLMException):

View file

@@ -9,7 +9,7 @@ from typing import Any, Callable, List, Optional, Union
from httpx._models import Headers
import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.utils import Choices, Message, ModelResponse, Usage

View file

@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,

View file

@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class NLPCloudError(BaseLLMException):

View file

@@ -2,7 +2,7 @@ from typing import Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class OllamaError(BaseLLMException):

View file

@@ -13,7 +13,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
ollama_pt,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,

View file

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class OobaboogaError(BaseLLMException):

View file

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.types.utils import ModelResponse

View file

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Union
import httpx
import openai
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class OpenAIError(BaseLLMException):

View file

@@ -6,7 +6,10 @@ import types
from typing import List, Optional, Union, cast
import litellm
from litellm.llms.base_llm.transformation import BaseConfig
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import (
AllMessageValues,
AllPromptValues,
@@ -14,7 +17,6 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from ..chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import OpenAIError
from .utils import is_tokens_or_list_of_tokens

View file

@@ -32,7 +32,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import (

View file

@@ -2,7 +2,7 @@ from typing import Union
from httpx import Headers
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class PetalsError(BaseLLMException):

View file

@@ -4,7 +4,7 @@ from typing import Any, List, Optional, Union
from httpx import Headers, Response
import litellm
from litellm.llms.base_llm.transformation import (
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,

View file

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
from httpx import Headers, Response
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class PredibaseError(BaseLLMException):

View file

@@ -4,9 +4,14 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from litellm.utils import token_counter

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class ReplicateError(BaseLLMException):

View file

@@ -11,7 +11,7 @@ from typing import Union
from httpx._models import Headers
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import SagemakerError

View file

@@ -4,7 +4,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
import httpx
from litellm import verbose_logger
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk

View file

@@ -17,7 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage

View file

@@ -2,7 +2,7 @@ from typing import Optional, Union
import httpx
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class TritonError(BaseLLMException):

View file

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
from httpx import Headers, Response
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.llms.base_llm.transformation import (
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,

View file

@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
import httpx
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.vertex_ai import PartType

View file

@@ -34,7 +34,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
convert_generic_image_chunk_to_openai_image_obj,
convert_to_anthropic_image_obj,
)
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,

View file

@@ -0,0 +1,141 @@
import json
from typing import Any, List, Optional, Tuple, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import EmbeddingResponse, ModelResponse, Usage
class VoyageError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Union[dict, httpx.Headers] = {},
):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.voyageai.com/v1/embeddings"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
headers=headers,
)
class VoyageEmbeddingConfig(BaseEmbeddingConfig):
"""
Reference: https://docs.voyageai.com/reference/embeddings-api
"""
def __init__(self) -> None:
pass
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base:
if not api_base.endswith("/embeddings"):
api_base = f"{api_base}/embeddings"
return api_base
return "https://api.voyageai.com/v1/embeddings"
def get_supported_openai_params(self, model: str) -> list:
return [
"encoding_format",
"dimensions",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map OpenAI params to Voyage params
Reference: https://docs.voyageai.com/reference/embeddings-api
"""
if "encoding_format" in non_default_params:
optional_params["encoding_format"] = non_default_params["encoding_format"]
if "dimensions" in non_default_params:
optional_params["output_dimension"] = non_default_params["dimensions"]
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
api_key = (
get_secret_str("VOYAGE_API_KEY")
or get_secret_str("VOYAGE_AI_API_KEY")
or get_secret_str("VOYAGE_AI_TOKEN")
)
return {
"Authorization": f"Bearer {api_key}",
}
def transform_embedding_request(
self,
model: str,
input: Union[str, List[str], List[float], List[List[float]]],
optional_params: dict,
headers: dict,
) -> dict:
return {
"input": input,
"model": model,
**optional_params,
}
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> EmbeddingResponse:
try:
raw_response_json = raw_response.json()
except Exception:
raise VoyageError(
message=raw_response.text, status_code=raw_response.status_code
)
# model_response.usage
model_response.model = raw_response_json.get("model")
model_response.data = raw_response_json.get("data")
model_response.object = raw_response_json.get("object")
usage = Usage(
prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
)
model_response.usage = usage
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return VoyageError(
message=error_message, status_code=status_code, headers=headers
)
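A quick sketch of the transformation above (assuming the VoyageEmbeddingConfig defined in this commit; the expected payload matches the assertions in the new unit test further below):

```python
from litellm.llms.voyage.embedding.transformation import VoyageEmbeddingConfig

# OpenAI-style `dimensions` is renamed to Voyage's `output_dimension`; any extra
# provider-specific params (e.g. `input_type`) are forwarded untouched.
config = VoyageEmbeddingConfig()

optional_params = config.map_openai_params(
    non_default_params={"dimensions": 512},
    optional_params={},
    model="voyage-3-lite",
    drop_params=False,
)
optional_params["input_type"] = "document"  # provider-specific extra param

request_body = config.transform_embedding_request(
    model="voyage-3-lite",
    input=["a"],
    optional_params=optional_params,
    headers={},
)
# request_body == {"input": ["a"], "model": "voyage-3-lite",
#                  "output_dimension": 512, "input_type": "document"}
```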

View file

@@ -5,7 +5,7 @@ import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams

View file

@@ -24,7 +24,8 @@ from typing import (
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseLLMException
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
@@ -35,8 +36,7 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM
from ...base_llm.transformation import BaseConfig
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
if TYPE_CHECKING:

View file

@@ -3058,6 +3058,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "together_ai"
or custom_llm_provider == "openai_like"
or custom_llm_provider == "jina_ai"
or custom_llm_provider == "voyage"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
@@ -3632,10 +3633,10 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
)
elif custom_llm_provider == "voyage":
api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
response = openai_chat_completions.embedding(
response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,

View file

@@ -168,7 +168,8 @@ from typing import (
from openai import OpenAIError as OriginalError
from litellm.llms.base_llm.transformation import BaseConfig
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from ._logging import verbose_logger
from .caching.caching import (
@@ -2377,6 +2378,21 @@ def get_optional_params_embeddings( # noqa: PLR0915
)
final_params = {**optional_params, **kwargs}
return final_params
elif custom_llm_provider == "voyage":
supported_params = get_supported_openai_params(
model=model,
custom_llm_provider="voyage",
request_type="embeddings",
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
non_default_params=non_default_params,
optional_params={},
model=model,
drop_params=drop_params if drop_params is not None else False,
)
final_params = {**optional_params, **kwargs}
return final_params
elif custom_llm_provider == "fireworks_ai":
supported_params = get_supported_openai_params(
model=model,
@@ -6200,6 +6216,15 @@ class ProviderConfigManager:
return litellm.PetalsConfig()
return litellm.OpenAIGPTConfig()
@staticmethod
def get_provider_embedding_config(
model: str,
provider: LlmProviders,
) -> BaseEmbeddingConfig:
if litellm.LlmProviders.VOYAGE == provider:
return litellm.VoyageEmbeddingConfig()
raise ValueError(f"Provider {provider} does not support embedding config")
def get_end_user_id_for_cost_tracking(
litellm_params: dict,

View file

@@ -59,6 +59,10 @@ class BaseLLMEmbeddingTest(ABC):
print("async embedding response: ", response)
from openai.types.create_embedding_response import CreateEmbeddingResponse
CreateEmbeddingResponse.model_validate(response.model_dump())
def test_embedding_optional_params_max_retries(self):
embedding_call_args = self.get_base_embedding_call_args()
optional_params = get_optional_params_embeddings(

View file

@@ -0,0 +1,56 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from base_embedding_unit_tests import BaseLLMEmbeddingTest
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, MagicMock
class TestVoyageAI(BaseLLMEmbeddingTest):
def get_custom_llm_provider(self) -> litellm.LlmProviders:
return litellm.LlmProviders.VOYAGE
def get_base_embedding_call_args(self) -> dict:
return {
"model": "voyage/voyage-3-lite",
}
def test_voyage_ai_embedding_extra_params():
try:
client = HTTPHandler()
litellm.set_verbose = True
with patch.object(client, "post") as mock_client:
response = litellm.embedding(
model="voyage/voyage-3-lite",
input=["a"],
dimensions=512,
input_type="document",
client=client,
)
mock_client.assert_called_once()
json_data = json.loads(mock_client.call_args.kwargs["data"])
print("request data to voyage ai", json.dumps(json_data, indent=4))
# Assert the request parameters
assert json_data["input"] == ["a"]
assert json_data["model"] == "voyage-3-lite"
assert json_data["output_dimension"] == 512
assert json_data["input_type"] == "document"
except Exception as e:
pytest.fail(f"Error occurred: {e}")

View file

@@ -348,7 +348,7 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
from litellm.utils import ProviderConfigManager
from litellm.llms.base_llm.transformation import BaseConfig
from litellm.llms.base_llm.chat.transformation import BaseConfig
def _check_provider_config(config: BaseConfig, provider: LlmProviders):