Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
(fix) unable to pass input_type parameter to Voyage AI embedding mode (#7276)
* VoyageEmbeddingConfig
* fix voyage logic to get params
* add voyage embedding transformation
* add get_provider_embedding_config
* use BaseEmbeddingConfig
* voyage clean up
* use llm http handler for embedding transformations
* test_voyage_ai_embedding_extra_params
* add voyage async
* test_voyage_ai_embedding_extra_params
* add async for llm http handler
* update BaseLLMEmbeddingTest
* test_voyage_ai_embedding_extra_params
* fix linting
* fix get_provider_embedding_config
* fix anthropic text test
* update location of base/chat/transformation
* fix import path
* fix IBMWatsonXAIConfig
parent 63172e67f2
commit c7b288ce30
52 changed files with 535 additions and 66 deletions
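
In short: provider-specific embedding params such as input_type (and the OpenAI-style dimensions, which Voyage exposes as output_dimension) are now forwarded to the Voyage AI embeddings API instead of being dropped. A minimal sketch of the call this fix enables, based on the test_voyage_ai_embedding_extra_params test added in this commit (assumes VOYAGE_API_KEY is set in the environment):

    import litellm

    # the "voyage/" prefix routes the request through the new VoyageEmbeddingConfig
    response = litellm.embedding(
        model="voyage/voyage-3-lite",
        input=["hello world"],
        dimensions=512,          # mapped to Voyage's "output_dimension"
        input_type="document",   # provider-specific param, now passed through
    )
    print(response.usage)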

@@ -1083,6 +1083,7 @@ from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
 from .llms.deepinfra.chat.transformation import DeepInfraConfig
 from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
+from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.openai.chat.o1_transformation import (

@@ -536,14 +536,6 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
         ) = litellm.XAIChatConfig()._get_openai_compatible_provider_info(
             api_base, api_key
         )
-    elif custom_llm_provider == "voyage":
-        # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
-        api_base = (
-            api_base
-            or get_secret_str("VOYAGE_API_BASE")
-            or "https://api.voyageai.com/v1"
-        )  # type: ignore
-        dynamic_api_key = api_key or get_secret_str("VOYAGE_API_KEY")
     elif custom_llm_provider == "together_ai":
         api_base = (
             api_base

@@ -34,6 +34,8 @@ def get_supported_openai_params( # noqa: PLR0915
         return litellm.OllamaChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "anthropic":
         return litellm.AnthropicConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "anthropic_text":
+        return litellm.AnthropicTextConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "fireworks_ai":
         if request_type == "embeddings":
             return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(

@@ -168,10 +170,17 @@
         ]
     elif custom_llm_provider == "watsonx":
         return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
-    elif custom_llm_provider == "custom_openai" or "text-completion-openai":
+    elif custom_llm_provider == "watsonx_text":
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params(model=model)
+    elif (
+        custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "text-completion-openai"
+    ):
         return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
             model=model
         )
     elif custom_llm_provider == "predibase":
         return litellm.PredibaseConfig().get_supported_openai_params(model=model)
+    elif custom_llm_provider == "voyage":
+        return litellm.VoyageEmbeddingConfig().get_supported_openai_params(model=model)
     return None
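
With the two hunks above, parameter discovery for Voyage also resolves through the new embedding config. A small sketch (request_type="embeddings" mirrors how litellm.utils calls this helper later in this diff):

    import litellm

    params = litellm.get_supported_openai_params(
        model="voyage-3-lite",
        custom_llm_provider="voyage",
        request_type="embeddings",
    )
    print(params)  # ["encoding_format", "dimensions"], per VoyageEmbeddingConfig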

@@ -21,7 +21,7 @@ import litellm
 from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.anthropic import (
     AllAnthropicToolsValues,
     AnthropicComputerTool,

@@ -6,7 +6,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class AnthropicError(BaseLLMException):

@@ -11,13 +11,16 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 import httpx
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,
 )
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,

@@ -7,7 +7,7 @@ import litellm
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_azure_openai_messages,
 )
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import ModelResponse
 
 from ....exceptions import UnsupportedParamsError

@@ -18,7 +18,7 @@ from ....types.llms.openai import (
     ChatCompletionToolParam,
     ChatCompletionToolParamFunctionChunk,
 )
-from ...base_llm.transformation import BaseConfig
+from ...base_llm.chat.transformation import BaseConfig
 from ..common_utils import AzureOpenAIError
 
 if TYPE_CHECKING:

@@ -3,7 +3,7 @@ from typing import Callable, Optional, Union
 import httpx
 
 from litellm._logging import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 
 

litellm/llms/base_llm/embedding/transformation.py (new file, 94 lines)

@@ -0,0 +1,94 @@
+import types
+from abc import ABC, abstractmethod
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    TypedDict,
+    Union,
+)
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse, ModelResponse
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class BaseEmbeddingConfig(BaseConfig, ABC):
+
+    @abstractmethod
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: Union[str, List[str], List[float], List[List[float]]],
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {}
+
+    @abstractmethod
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        return model_response
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        return api_base or ""
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "EmbeddingConfig does not need a request transformation for chat models"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "EmbeddingConfig does not need a response transformation for chat models"
+        )
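
For orientation, a trimmed sketch of what a provider subclass of BaseEmbeddingConfig looks like, loosely modeled on the Voyage config added later in this diff. The provider name, URL, and response fields are illustrative, and the remaining BaseConfig abstract methods (validate_environment, map_openai_params, get_error_class, ...) are omitted, so this fragment shows the shape of the two embedding hooks rather than a complete, instantiable config:

    import httpx

    from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
    from litellm.types.utils import EmbeddingResponse


    class MyProviderEmbeddingConfig(BaseEmbeddingConfig):  # hypothetical provider
        def get_complete_url(self, api_base, model) -> str:
            # point litellm's shared HTTP handler at the provider's embeddings endpoint
            return api_base or "https://api.example.com/v1/embeddings"

        def transform_embedding_request(self, model, input, optional_params, headers) -> dict:
            # build the provider-specific request body
            return {"model": model, "input": input, **optional_params}

        def transform_embedding_response(
            self, model, raw_response: httpx.Response, model_response: EmbeddingResponse,
            logging_obj, api_key=None, request_data={}, optional_params={}, litellm_params={},
        ) -> EmbeddingResponse:
            # map the provider's raw JSON back onto litellm's EmbeddingResponse
            payload = raw_response.json()
            model_response.model = payload.get("model")
            model_response.data = payload.get("data")
            return model_response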

@@ -9,7 +9,7 @@ from litellm.litellm_core_utils.prompt_templates.common_utils import (
     convert_content_list_to_str,
 )
 from litellm.llms.base_llm.base_model_iterator import FakeStreamResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,

@@ -1,6 +1,6 @@
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class ClarifaiError(BaseLLMException):

@@ -6,7 +6,7 @@ import httpx
 
 import litellm
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
 

@@ -1,7 +1,7 @@
 import json
 from typing import List, Optional
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,

@@ -6,8 +6,10 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional,
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import (
     ChatCompletionToolCallChunk,

@@ -21,13 +21,15 @@ import litellm.types
 import litellm.types.utils
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
 )
+from litellm.types.utils import EmbeddingResponse
 from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
 
 if TYPE_CHECKING:

@@ -403,6 +405,139 @@ class BaseLLMHTTPHandler:
 
         return completion_stream, response.headers
 
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        timeout: float,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: Optional[str],
+        optional_params: dict,
+        model_response: EmbeddingResponse,
+        api_key: Optional[str] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        aembedding: bool = False,
+        headers={},
+    ) -> EmbeddingResponse:
+
+        provider_config = ProviderConfigManager.get_provider_embedding_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers,
+            model=model,
+            messages=[],
+            optional_params=optional_params,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+        )
+
+        data = provider_config.transform_embedding_request(
+            model=model,
+            input=input,
+            optional_params=optional_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if aembedding is True:
+            return self.aembedding(  # type: ignore
+                request_data=data,
+                api_base=api_base,
+                headers=headers,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                timeout=timeout,
+                client=client,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        try:
+            response = sync_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(
+                e=e,
+                provider_config=provider_config,
+            )
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+        )
+
+    async def aembedding(
+        self,
+        request_data: dict,
+        api_base: str,
+        headers: dict,
+        model: str,
+        custom_llm_provider: str,
+        provider_config: BaseEmbeddingConfig,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> EmbeddingResponse:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
+        try:
+            response = await async_httpx_client.post(
+                url=api_base,
+                headers=headers,
+                data=json.dumps(request_data),
+                timeout=timeout,
+            )
+        except Exception as e:
+            raise self._handle_error(e=e, provider_config=provider_config)
+
+        return provider_config.transform_embedding_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+        )
+
     def _handle_error(self, e: Exception, provider_config: BaseConfig):
         status_code = getattr(e, "status_code", 500)
         error_headers = getattr(e, "headers", None)

@@ -421,6 +556,3 @@ class BaseLLMHTTPHandler:
             status_code=status_code,
             headers=error_headers,
         )
-
-    def embedding(self):
-        pass

@@ -16,7 +16,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     prompt_factory,
 )
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage

@@ -2,7 +2,7 @@ from typing import Literal, Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class HuggingfaceError(BaseLLMException):

@@ -9,7 +9,7 @@ from typing import Any, Callable, List, Optional, Union
 from httpx._models import Headers
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.utils import Choices, Message, ModelResponse, Usage
 

@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Union
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,

@@ -4,8 +4,10 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.utils import ModelResponse, Usage
 

@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class NLPCloudError(BaseLLMException):

@@ -2,7 +2,7 @@ from typing import Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class OllamaError(BaseLLMException):

@@ -13,7 +13,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     ollama_pt,
 )
 from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import (
     AllMessageValues,

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage

@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class OobaboogaError(BaseLLMException):

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
 from litellm.types.utils import ModelResponse
 

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import httpx
 import openai
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class OpenAIError(BaseLLMException):

@@ -6,7 +6,10 @@ import types
 from typing import List, Optional, Union, cast
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.types.llms.openai import (
     AllMessageValues,
     AllPromptValues,

@@ -14,7 +17,6 @@ from litellm.types.llms.openai import (
 )
 from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
 
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
 from ..chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import OpenAIError
 from .utils import is_tokens_or_list_of_tokens

@@ -32,7 +32,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.utils import (

@@ -2,7 +2,7 @@ from typing import Union
 
 from httpx import Headers
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class PetalsError(BaseLLMException):

@@ -4,7 +4,7 @@ from typing import Any, List, Optional, Union
 from httpx import Headers, Response
 
 import litellm
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
 
 from httpx import Headers, Response
 
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse
 

@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class PredibaseError(BaseLLMException):

@@ -4,9 +4,14 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
-from litellm.litellm_core_utils.prompt_templates.common_utils import convert_content_list_to_str
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import Choices, Message, ModelResponse, Usage
 from litellm.utils import token_counter

@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class ReplicateError(BaseLLMException):

@@ -11,7 +11,7 @@ from typing import Union
 
 from httpx._models import Headers
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
 from ..common_utils import SagemakerError

@@ -4,7 +4,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
 import httpx
 
 from litellm import verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import StreamingChatCompletionChunk
 

@@ -17,7 +17,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import ModelResponse, Usage
 

@@ -2,7 +2,7 @@ from typing import Optional, Union
 
 import httpx
 
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 
 
 class TritonError(BaseLLMException):

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
 from httpx import Headers, Response
 
 from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
-from litellm.llms.base_llm.transformation import (
+from litellm.llms.base_llm.chat.transformation import (
     BaseConfig,
     BaseLLMException,
     LiteLLMLoggingObj,

@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional, Tuple, Union
 import httpx
 
 from litellm import supports_response_schema, supports_system_messages, verbose_logger
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.types.llms.vertex_ai import PartType
 
 

@@ -34,7 +34,7 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_generic_image_chunk_to_openai_image_obj,
     convert_to_anthropic_image_obj,
 )
-from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,

litellm/llms/voyage/embedding/transformation.py (new file, 141 lines)

@@ -0,0 +1,141 @@
+import json
+from typing import Any, List, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import EmbeddingResponse, ModelResponse, Usage
+
+
+class VoyageError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        headers: Union[dict, httpx.Headers] = {},
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://api.voyageai.com/v1/embeddings"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            headers=headers,
+        )
+
+
+class VoyageEmbeddingConfig(BaseEmbeddingConfig):
+    """
+    Reference: https://docs.voyageai.com/reference/embeddings-api
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base:
+            if not api_base.endswith("/embeddings"):
+                api_base = f"{api_base}/embeddings"
+            return api_base
+        return "https://api.voyageai.com/v1/embeddings"
+
+    def get_supported_openai_params(self, model: str) -> list:
+        return [
+            "encoding_format",
+            "dimensions",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        """
+        Map OpenAI params to Voyage params
+
+        Reference: https://docs.voyageai.com/reference/embeddings-api
+        """
+        if "encoding_format" in non_default_params:
+            optional_params["encoding_format"] = non_default_params["encoding_format"]
+        if "dimensions" in non_default_params:
+            optional_params["output_dimension"] = non_default_params["dimensions"]
+        return optional_params
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = (
+                get_secret_str("VOYAGE_API_KEY")
+                or get_secret_str("VOYAGE_AI_API_KEY")
+                or get_secret_str("VOYAGE_AI_TOKEN")
+            )
+        return {
+            "Authorization": f"Bearer {api_key}",
+        }
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: Union[str, List[str], List[float], List[List[float]]],
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        return {
+            "input": input,
+            "model": model,
+            **optional_params,
+        }
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> EmbeddingResponse:
+        try:
+            raw_response_json = raw_response.json()
+        except Exception:
+            raise VoyageError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+
+        # model_response.usage
+        model_response.model = raw_response_json.get("model")
+        model_response.data = raw_response_json.get("data")
+        model_response.object = raw_response_json.get("object")
+
+        usage = Usage(
+            prompt_tokens=raw_response_json.get("usage", {}).get("prompt_tokens", 0),
+            total_tokens=raw_response_json.get("usage", {}).get("total_tokens", 0),
+        )
+        model_response.usage = usage
+        return model_response
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return VoyageError(
+            message=error_message, status_code=status_code, headers=headers
+        )
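
To make the mapping concrete, a quick sketch of the two transformation hooks above in isolation (values are illustrative):

    from litellm.llms.voyage.embedding.transformation import VoyageEmbeddingConfig

    config = VoyageEmbeddingConfig()

    # OpenAI-style "dimensions" becomes Voyage's "output_dimension"
    optional_params = config.map_openai_params(
        non_default_params={"dimensions": 512, "encoding_format": "float"},
        optional_params={},
        model="voyage-3-lite",
        drop_params=False,
    )
    # -> {"encoding_format": "float", "output_dimension": 512}

    # extra provider params such as input_type ride along via **optional_params
    request_body = config.transform_embedding_request(
        model="voyage-3-lite",
        input=["hello world"],
        optional_params={**optional_params, "input_type": "document"},
        headers={},
    )
    # -> {"input": ["hello world"], "model": "voyage-3-lite",
    #     "encoding_format": "float", "output_dimension": 512, "input_type": "document"}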

@@ -5,7 +5,7 @@ import httpx
 import litellm
 from litellm import verbose_logger
 from litellm.caching import InMemoryCache
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.watsonx import WatsonXAPIParams
 

@@ -24,7 +24,8 @@ from typing import (
 import httpx
 
 import litellm
-from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,

@@ -35,8 +36,7 @@ from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
 
 from ...base import BaseLLM
-from ...base_llm.transformation import BaseConfig
-from litellm.litellm_core_utils.prompt_templates import factory as ptf
+from ...base_llm.chat.transformation import BaseConfig
 from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
 
 if TYPE_CHECKING:

@@ -3058,6 +3058,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "together_ai"
         or custom_llm_provider == "openai_like"
         or custom_llm_provider == "jina_ai"
+        or custom_llm_provider == "voyage"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)

@@ -3632,10 +3633,10 @@ def embedding(  # noqa: PLR0915
             aembedding=aembedding,
         )
     elif custom_llm_provider == "voyage":
         api_key = api_key or litellm.api_key or get_secret_str("VOYAGE_API_KEY")
-        response = openai_chat_completions.embedding(
+        response = base_llm_http_handler.embedding(
             model=model,
             input=input,
+            custom_llm_provider=custom_llm_provider,
             api_base=api_base,
             api_key=api_key,
             logging_obj=logging,

@@ -168,7 +168,8 @@ from typing import (
 
 from openai import OpenAIError as OriginalError
 
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 
 from ._logging import verbose_logger
 from .caching.caching import (

@@ -2377,6 +2378,21 @@ def get_optional_params_embeddings(  # noqa: PLR0915
         )
         final_params = {**optional_params, **kwargs}
         return final_params
+    elif custom_llm_provider == "voyage":
+        supported_params = get_supported_openai_params(
+            model=model,
+            custom_llm_provider="voyage",
+            request_type="embeddings",
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VoyageEmbeddingConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params={},
+            model=model,
+            drop_params=drop_params if drop_params is not None else False,
+        )
+        final_params = {**optional_params, **kwargs}
+        return final_params
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
             model=model,

@@ -6200,6 +6216,15 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         return litellm.OpenAIGPTConfig()
 
+    @staticmethod
+    def get_provider_embedding_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> BaseEmbeddingConfig:
+        if litellm.LlmProviders.VOYAGE == provider:
+            return litellm.VoyageEmbeddingConfig()
+        raise ValueError(f"Provider {provider} does not support embedding config")
+
 
 def get_end_user_id_for_cost_tracking(
     litellm_params: dict,
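
A short sketch of how the new lookup is meant to be used (it mirrors the call made by BaseLLMHTTPHandler.embedding earlier in this diff; any provider other than Voyage currently raises ValueError, per the hunk above):

    import litellm
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_embedding_config(
        model="voyage-3-lite",
        provider=litellm.LlmProviders.VOYAGE,
    )
    print(type(provider_config).__name__)  # VoyageEmbeddingConfig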

@@ -59,6 +59,10 @@ class BaseLLMEmbeddingTest(ABC):
 
         print("async embedding response: ", response)
 
+        from openai.types.create_embedding_response import CreateEmbeddingResponse
+
+        CreateEmbeddingResponse.model_validate(response.model_dump())
+
     def test_embedding_optional_params_max_retries(self):
         embedding_call_args = self.get_base_embedding_call_args()
         optional_params = get_optional_params_embeddings(

tests/llm_translation/test_voyage_ai.py (new file, 56 lines)

@@ -0,0 +1,56 @@
+import json
+import os
+import sys
+from datetime import datetime
+from unittest.mock import AsyncMock
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+
+from base_embedding_unit_tests import BaseLLMEmbeddingTest
+import litellm
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+from unittest.mock import patch, MagicMock
+
+
+class TestVoyageAI(BaseLLMEmbeddingTest):
+    def get_custom_llm_provider(self) -> litellm.LlmProviders:
+        return litellm.LlmProviders.VOYAGE
+
+    def get_base_embedding_call_args(self) -> dict:
+        return {
+            "model": "voyage/voyage-3-lite",
+        }
+
+
+def test_voyage_ai_embedding_extra_params():
+    try:
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
+
+        with patch.object(client, "post") as mock_client:
+            response = litellm.embedding(
+                model="voyage/voyage-3-lite",
+                input=["a"],
+                dimensions=512,
+                input_type="document",
+                client=client,
+            )
+
+            mock_client.assert_called_once()
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+
+            print("request data to voyage ai", json.dumps(json_data, indent=4))
+
+            # Assert the request parameters
+            assert json_data["input"] == ["a"]
+            assert json_data["model"] == "voyage-3-lite"
+            assert json_data["output_dimension"] == 512
+            assert json_data["input_type"] == "document"
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
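
The async path is covered by the same change set ("voyage" was added to the aembedding provider check in main.py above), so an equivalent asynchronous sketch would be (assumes VOYAGE_API_KEY is set; model and input are illustrative):

    import asyncio

    import litellm


    async def main():
        response = await litellm.aembedding(
            model="voyage/voyage-3-lite",
            input=["hello world"],
            input_type="document",  # forwarded to Voyage AI, same as the sync test above
        )
        print(response.usage)


    asyncio.run(main())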

@@ -348,7 +348,7 @@ async def test_add_and_delete_deployments(llm_router, model_list_flag_value):
 
 from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
 from litellm.utils import ProviderConfigManager
-from litellm.llms.base_llm.transformation import BaseConfig
+from litellm.llms.base_llm.chat.transformation import BaseConfig
 
 
 def _check_provider_config(config: BaseConfig, provider: LlmProviders):