Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
(fix) unable to pass input_type parameter to Voyage AI embedding model (#7276)
* VoyageEmbeddingConfig
* fix voyage logic to get params
* add voyage embedding transformation
* add get_provider_embedding_config
* use BaseEmbeddingConfig
* voyage clean up
* use llm http handler for embedding transformations
* test_voyage_ai_embedding_extra_params
* add voyage async
* test_voyage_ai_embedding_extra_params
* add async for llm http handler
* update BaseLLMEmbeddingTest
* test_voyage_ai_embedding_extra_params
* fix linting
* fix get_provider_embedding_config
* fix anthropic text test
* update location of base/chat/transformation
* fix import path
* fix IBMWatsonXAIConfig
This commit is contained in:
  parent: 63172e67f2
  commit: c7b288ce30

52 changed files with 535 additions and 66 deletions
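Before this change, the voyage provider was routed through the generic OpenAI-compatible path, so Voyage-specific request parameters such as `input_type` were silently dropped. After it, a call like the following (mirroring the new test in this diff; model and values are illustrative) forwards `input_type` verbatim and maps the OpenAI-style `dimensions` to Voyage's `output_dimension`:

import litellm

# `input_type` is Voyage-specific and now passes straight through to the request
# body; `dimensions` is translated to Voyage's `output_dimension` field.
response = litellm.embedding(
    model="voyage/voyage-3-lite",
    input=["hello world"],
    dimensions=512,
    input_type="document",
)
print(response.usage)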
@@ -1083,6 +1083,7 @@ from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
 from .llms.deepinfra.chat.transformation import DeepInfraConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
+from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.openai.chat.o1_transformation import (
@@ -536,14 +536,6 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
         ) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
             api_base, api_key
         )
-    elif custom_llm_provider == "voyage":
-        # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
-        api_base = (
-            api_base
-            or get_secret_str("VOYAGE_API_BASE")
-            or "https://api.voyageai.com/v1"
-        )  # type: ignore
-        dynamic_api_key = api_key or get_secret_str("VOYAGE_API_KEY")
     elif custom_llm_provider == "together_ai":
         api_base = (
             api_base
@@ -34,6 +34,8 @@ def get_supported_openai_params(  # noqa: PLR0915
         return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "anthropic":
         return litellm.AnthropicConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "anthropic_text":
+        return litellm.AnthropicTextConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "fireworks_ai":
         if request_type == "embeddings":
             return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
@@ -168,10 +170,17 @@ def get_supported_openai_params(  # noqa: PLR0915
         ]
     elif custom_llm_provider == "watsonx":
         return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
-    elif custom_llm_provider == "custom_openai" or "text-completion-openai":
+    elif custom_llm_provider == "watsonx_text":
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params(model=model)
+    elif (
+        custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "text-completion-openai"
+    ):
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
     elif custom_llm_provider == "predibase":
         return litellm.PredibaseConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "voyage":
+        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
     return None
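Note that the rewritten condition above also fixes a latent truthiness bug: `custom_llm_provider == "custom_openai" or "text-completion-openai"` evaluated the bare string literal, which is always truthy. A minimal illustration:

# Old form parsed as: (provider == "custom_openai") or "text-completion-openai".
# The non-empty string literal makes the expression truthy for any provider that
# reaches this elif, so the branch matched far more than intended.
provider = "predibase"
assert (provider == "custom_openai" or "text-completion-openai")  # always truthy
assert not (provider == "custom_openai" or provider == "text-completion-openai")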
@@ -21,7 +21,7 @@ import litellm
 from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.anthropic import (
     AllAnthropicToolsValues,
     AnthropicComputerTool,
@@ -6,7 +6,7 @@ from typing import Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class AnthropicError(BaseLLMException):
@@ -11,13 +11,16 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 import httpx

 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
 )
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
@@ -7,7 +7,7 @@ import litellm
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_azure_openai_messages,
 )
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import ModelResponse

 from ....exceptions import UnsupportedParamsError
@@ -18,7 +18,7 @@ from ....types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
 )
-from ...base_llm.transformation import BaseConfig
+from ...base_llm.chat.transformation import BaseConfig
 from ..common_utils import AzureOpenAIError

 if TYPE_CHECKING:
@@ -3,7 +3,7 @@ from typing import Callable, Optional, Union
 import httpx

 from litellm._logging import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
litellm/llms/base_llm/embedding/transformation.py (new file, 94 lines)
@@ -0,0 +1,94 @@
+import types
+from abc import ABC, abstractmethod
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    TypedDict,
+    Union,
+)
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse, ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class BaseEmbeddingConfig(BaseConfig, ABC):
+
+    @abstractmethod
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: Union[str, List[str], List[float], List[List[float]]],
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {}
+
+    @abstractmethod
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        return model_response
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        return api_base or ""
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "EmbeddingConfig does not need a request transformation for chat models"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "EmbeddingConfig does not need a response transformation for chat models"
+        )
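For orientation, a provider plugs into this base by implementing the two abstract hooks; the concrete pattern, including auth and URL handling, is the VoyageEmbeddingConfig added later in this diff. A hypothetical minimal sketch, not part of this PR, eliding the remaining abstract methods inherited from BaseConfig:

from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig


class ExampleEmbeddingConfig(BaseEmbeddingConfig):
    # Hypothetical provider config for illustration only.

    def transform_embedding_request(self, model, input, optional_params, headers) -> dict:
        # Shape the outgoing JSON body for the provider's /embeddings endpoint.
        return {"model": model, "input": input, **optional_params}

    def transform_embedding_response(
        self, model, raw_response, model_response, logging_obj,
        api_key=None, request_data={}, optional_params={}, litellm_params={},
    ):
        # Copy the provider's JSON payload onto LiteLLM's EmbeddingResponse.
        payload = raw_response.json()
        model_response.model = payload.get("model")
        model_response.data = payload.get("data")
        return model_response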
@@ -9,7 +9,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     convert_content_list_to_str,
 )
 from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
@@ -1,6 +1,6 @@
 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class ClarifaiError(BaseLLMException):
@@ -6,7 +6,7 @@ import httpx

 import litellm
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage

@@ -1,7 +1,7 @@
 import json
 from typing import List, Optional

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,
@@ -21,13 +21,15 @@ import litellm.types
 import litellm.types.utils
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
+from litellm.types.utils import EmbeddingResponse
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager

 if TYPE_CHECKING:
@@ -403,6 +405,139 @@ class BaseLLMHTTPHandler:

         return completion_stream, response.headers

+    def embedding(
+        self,
+        model: str,
+        input: list,
+        timeout: float,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: Optional[str],
+        optional_params: dict,
+        model_response: EmbeddingResponse,
+        api_key: Optional[str] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        aembedding: bool = False,
+        headers={},
+    ) -> EmbeddingResponse:
+
+        provider_config = ProviderConfigManager.get_provider_embedding_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+            messages=[],
+            optional_params=optional_params,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
+        data = provider_config.transform_embedding_request(
+            model=model,
+            input=input,
+            optional_params=optional_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.aembedding(  # type: ignore
+                request_data=data,
+                api_base=api_base,
+                headers=headers,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                timeout=timeout,
+                client=client,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        try:
+            response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(
+                e=e,
+                provider_config=provider_config,
+            )
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+        )
+
+    async def aembedding(
+        self,
+        request_data: dict,
+        api_base: str,
+        headers: dict,
+        model: str,
+        custom_llm_provider: str,
+        provider_config: BaseEmbeddingConfig,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> EmbeddingResponse:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
+        try:
+            response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(request_data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+        )
+
     def _handle_error(self, e: Exception, provider_config: BaseConfig):
         status_code = getattr(e, "status_code", 500)
         error_headers = getattr(e, "headers", None)
@@ -421,6 +556,3 @@ class BaseLLMHTTPHandler:
             status_code=status_code,
             headers=error_headers,
         )
-
-    def embedding(self):
-        pass
@@ -16,7 +16,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     prompt_factory,
 )
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
@@ -2,7 +2,7 @@ from typing import Literal, Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class HuggingfaceError(BaseLLMException):
@@ -9,7 +9,7 @@ from typing import Any, Callable, List, Optional, Union
 from httpx._models import Headers

 import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Union
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.utils import ModelResponse, Usage
@@ -2,7 +2,7 @@ from typing import Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class NLPCloudError(BaseLLMException):
@@ -2,7 +2,7 @@ from typing import Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class OllamaError(BaseLLMException):
@@ -13,7 +13,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     ollama_pt,
 )
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
@@ -2,7 +2,7 @@ from typing import Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class OobaboogaError(BaseLLMException):
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
 from litellm.types.utils import ModelResponse
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 import openai

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class OpenAIError(BaseLLMException):
@@ -6,7 +6,10 @@ import types
 from typing import List, Optional, Union, cast

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.types.llms.openai import (
     AllMessageValues,
     AllPromptValues,
@@ -14,7 +17,6 @@ from litellm.types.llms.openai import (
 )
 from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse

-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
 from ..chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import OpenAIError
 from .utils import is_tokens_or_list_of_tokens
@@ -32,7 +32,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import (
@@ -2,7 +2,7 @@ from typing import Union

 from httpx import Headers

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class PetalsError(BaseLLMException):
@@ -4,7 +4,7 @@ from typing import Any, List, Optional, Union
 from httpx import Headers, Response

 import litellm
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union

 from httpx import Headers, Response

-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse
@@ -2,7 +2,7 @@ from typing import Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class PredibaseError(BaseLLMException):
@@ -4,9 +4,14 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
 from litellm.utils import token_counter
@@ -2,7 +2,7 @@ from typing import Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class ReplicateError(BaseLLMException):
@@ -11,7 +11,7 @@ from typing import Union

 from httpx._models import Headers

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException

 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import SagemakerError
@@ -4,7 +4,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
 import httpx

 from litellm import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import StreamingChatCompletionChunk
@@ -17,7 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
@@ -2,7 +2,7 @@ from typing import Optional, Union

 import httpx

-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException


 class TritonError(BaseLLMException):
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
 from httpx import Headers, Response

 from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 import httpx

 from litellm import supports_response_schema, supports_system_messages, verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.vertex_ai import PartType
@@ -34,7 +34,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_generic_image_chunk_to_openai_image_obj,
     convert_to_anthropic_image_obj,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
litellm/llms/voyage/embedding/transformation.py (new file, 141 lines)
@@ -0,0 +1,141 @@
+import json
+from typing import Any, List, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse, ModelResponse, Usage
+
+
+class VoyageError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Union[dict, httpx.Headers] = {},
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://api.voyageai.com/v1/embeddings"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            headers=headers,
+        )
+
+
+class VoyageEmbeddingConfig(BaseEmbeddingConfig):
+    """
+    Reference: https://docs.voyageai.com/reference/embeddings-api
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            if not api_base.endswith("/embeddings"):
+                api_base = f"{api_base}/embeddings"
+            return api_base
+        return "https://api.voyageai.com/v1/embeddings"
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return [
+            "encoding_format",
+            "dimensions",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        Map OpenAI params to Voyage params
+
+        Reference: https://docs.voyageai.com/reference/embeddings-api
+        """
+        if "encoding_format" in non_default_params:
+            optional_params["encoding_format"] = non_default_params["encoding_format"]
+        if "dimensions" in non_default_params:
+            optional_params["output_dimension"] = non_default_params["dimensions"]
+        return optional_params
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("VOYAGE_API_KEY")
+                or get_secret_str("VOYAGE_AI_API_KEY")
+                or get_secret_str("VOYAGE_AI_TOKEN")
+            )
+        return {
+            "Authorization": f"Bearer {api_key}",
+        }
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: Union[str, List[str], List[float], List[List[float]]],
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {
+            "input": input,
+            "model": model,
+            **optional_params,
+        }
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise VoyageError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        # model_response.usage
+        model_response.model = raw_response_json.get("model")
+        model_response.data = raw_response_json.get("data")
+        model_response.object = raw_response_json.get("object")
+
+        usage = Usage(
+            prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+            total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
+        )
+        model_response.usage = usage
+        return model_response
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return VoyageError(
+            message=error_message, status_code=status_code, headers=headers
+        )
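The `dimensions` to `output_dimension` rename in map_openai_params above is what makes the OpenAI-style parameter work against Voyage's API. A quick check of the mapping (values illustrative):

from litellm.llms.voyage.embedding.transformation import VoyageEmbeddingConfig

params = VoyageEmbeddingConfig().map_openai_params(
    non_default_params={"dimensions": 512, "encoding_format": "float"},
    optional_params={},
    model="voyage-3-lite",
    drop_params=False,
)
# Voyage expects `output_dimension`, not the OpenAI-style `dimensions`.
assert params == {"encoding_format": "float", "output_dimension": 512}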
@@ -5,7 +5,7 @@ import httpx
 import litellm
 from litellm import verbose_logger
 from litellm.caching import InMemoryCache
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAPIParams
@@ -24,7 +24,8 @@ from typing import (
 import httpx

 import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
@@ -35,8 +36,7 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason

 from ...base import BaseLLM
-from ...base_llm.transformation import BaseConfig
-from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from ...base_llm.chat.transformation import BaseConfig
 from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token

 if TYPE_CHECKING:
@@ -3058,6 +3058,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "together_ai"
         or custom_llm_provider == "openai_like"
         or custom_llm_provider == "jina_ai"
+        or custom_llm_provider == "voyage"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3632,10 +3633,10 @@ def embedding(  # noqa: PLR0915
             aembedding=aembedding,
         )
     elif custom_llm_provider == "voyage":
-        api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
-        response = openai_chat_completions.embedding(
+        response = base_llm_http_handler.embedding(
             model=model,
             input=input,
+            custom_llm_provider=custom_llm_provider,
             api_base=api_base,
             api_key=api_key,
             logging_obj=logging,
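Together with the aembedding hunk above, which adds voyage to the awaited providers, both sync and async voyage calls now flow through BaseLLMHTTPHandler. A sketch of the async path (model and inputs illustrative):

import asyncio

import litellm


async def main():
    # Async voyage embeddings are dispatched to BaseLLMHTTPHandler.aembedding.
    response = await litellm.aembedding(
        model="voyage/voyage-3-lite",
        input=["async hello"],
        input_type="query",  # Voyage-specific parameter, passed through as-is
    )
    print(response.usage)


asyncio.run(main())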
@@ -168,7 +168,8 @@ from typing import (

 from openai import OpenAIError as OriginalError

-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig

 from ._logging import verbose_logger
 from .caching.caching import (
@@ -2377,6 +2378,21 @@ def get_optional_params_embeddings(  # noqa: PLR0915
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    elif custom_llm_provider == "voyage":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="voyage",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params={},
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
             model=model,
@@ -6200,6 +6216,15 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         return litellm.OpenAIGPTConfig()

+    @staticmethod
+    def get_provider_embedding_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> BaseEmbeddingConfig:
+        if litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageEmbeddingConfig()
+        raise ValueError(f"Provider {provider} does not support embedding config")
+

 def get_end_user_id_for_cost_tracking(
     litellm_params: dict,
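This lookup is how BaseLLMHTTPHandler.embedding resolves the per-provider config; voyage is the only embedding provider registered here so far:

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_embedding_config(
    model="voyage-3-lite",
    provider=litellm.LlmProviders.VOYAGE,
)
# Returns a VoyageEmbeddingConfig; any other provider raises ValueError for now.
print(type(config).__name__)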
@@ -59,6 +59,10 @@ class BaseLLMEmbeddingTest(ABC):

         print("async embedding response: ", response)

+        from openai.types.create_embedding_response import CreateEmbeddingResponse
+
+        CreateEmbeddingResponse.model_validate(response.model_dump())
+
     def test_embedding_optional_params_max_retries(self):
         embedding_call_args = self.get_base_embedding_call_args()
         optional_params = get_optional_params_embeddings(
tests/llm_translation/test_voyage_ai.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+
+from base_embedding_unit_tests import BaseLLMEmbeddingTest
+import litellm
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from unittest.mock import patch, MagicMock
+
+
+class TestVoyageAI(BaseLLMEmbeddingTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.VOYAGE
+
+    def get_base_embedding_call_args(self) -> dict:
+        return {
+            "model": "voyage/voyage-3-lite",
+        }
+
+
+def test_voyage_ai_embedding_extra_params():
+    try:
+        client = HTTPHandler()
+        litellm.set_verbose = True
+
+        with patch.object(client, "post") as mock_client:
+            response = litellm.embedding(
+                model="voyage/voyage-3-lite",
+                input=["a"],
+                dimensions=512,
+                input_type="document",
+                client=client,
+            )
+
+            mock_client.assert_called_once()
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+
+            print("request data to voyage ai", json.dumps(json_data, indent=4))
+
+            # Assert the request parameters
+            assert json_data["input"] == ["a"]
+            assert json_data["model"] == "voyage-3-lite"
+            assert json_data["output_dimension"] == 512
+            assert json_data["input_type"] == "document"
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@@ -348,7 +348,7 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):

     from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
     from litellm.utils import ProviderConfigManager
-    from litellm.llms.base_llm.transformation import BaseConfig
+    from litellm.llms.base_llm.chat.transformation import BaseConfig


 def _check_provider_config(config: BaseConfig, provider: LlmProviders):