(fix) unable to pass input_type parameter to Voyage AI embedding mode (#7276)

* VoyageEmbeddingConfig

* fix voyage logic to get params

* add voyage embedding transformation

* add get_provider_embedding_config

* use BaseEmbeddingConfig

* voyage clean up

* use llm http handler for embedding transformations

* test_voyage_ai_embedding_extra_params

* add voyage async

* test_voyage_ai_embedding_extra_params

* add async for llm http handler

* update BaseLLMEmbeddingTest

* test_voyage_ai_embedding_extra_params

* fix linting

* fix get_provider_embedding_config

* fix anthropic text test

* update location of base/chat/transformation

* fix import path

* fix IBMWatsonXAIConfig
This commit is contained in:
Ishaan Jaff 2024-12-17 19:23:49 -08:00 committed by GitHub
parent 63172e67f2
commit c7b288ce30
52 changed files with 535 additions and 66 deletions

View file

@ -21,13 +21,15 @@ import litellm.types
import litellm.types.utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import EmbeddingResponse
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
if TYPE_CHECKING:
@ -403,6 +405,139 @@ class BaseLLMHTTPHandler:
return completion_stream, response.headers
def embedding(
self,
model: str,
input: list,
timeout: float,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: Optional[str],
optional_params: dict,
model_response: EmbeddingResponse,
api_key: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
aembedding: bool = False,
headers={},
) -> EmbeddingResponse:
provider_config = ProviderConfigManager.get_provider_embedding_config(
model=model, provider=litellm.LlmProviders(custom_llm_provider)
)
# get config from model, custom llm provider
headers = provider_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=[],
optional_params=optional_params,
)
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
)
data = provider_config.transform_embedding_request(
model=model,
input=input,
optional_params=optional_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"api_base": api_base,
"headers": headers,
},
)
if aembedding is True:
return self.aembedding( # type: ignore
request_data=data,
api_base=api_base,
headers=headers,
model=model,
custom_llm_provider=custom_llm_provider,
provider_config=provider_config,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
timeout=timeout,
client=client,
)
if client is None or not isinstance(client, HTTPHandler):
sync_httpx_client = _get_httpx_client()
else:
sync_httpx_client = client
try:
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(
e=e,
provider_config=provider_config,
)
return provider_config.transform_embedding_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
)
async def aembedding(
self,
request_data: dict,
api_base: str,
headers: dict,
model: str,
custom_llm_provider: str,
provider_config: BaseEmbeddingConfig,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> EmbeddingResponse:
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders(custom_llm_provider)
)
else:
async_httpx_client = client
try:
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(request_data),
timeout=timeout,
)
except Exception as e:
raise self._handle_error(e=e, provider_config=provider_config)
return provider_config.transform_embedding_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=request_data,
)
def _handle_error(self, e: Exception, provider_config: BaseConfig):
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
@ -421,6 +556,3 @@ class BaseLLMHTTPHandler:
status_code=status_code,
headers=error_headers,
)
def embedding(self):
pass