litellm-mirror/litellm/llms/maritalk.py

"""
MariTalk provider support for litellm: defines the provider error type and
the OpenAI-compatible parameter configuration.
"""

from typing import List, Optional, Union

from httpx._models import Headers

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig


class MaritalkError(BaseLLMException):
    """Error raised for failed requests to the MariTalk API."""

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)
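
# A hedged sketch of how this error type is meant to be raised. The
# `response` object below is hypothetical (anything exposing `status_code`,
# `text`, and `headers` would do); the real call sites live in litellm's
# HTTP handlers:
#
#     if response.status_code >= 400:
#         raise MaritalkError(
#             status_code=response.status_code,
#             message=response.text,
#             headers=response.headers,
#         )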


class MaritalkConfig(OpenAIGPTConfig):
    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[List[str]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        tools: Optional[List[dict]] = None,
        tool_choice: Optional[Union[str, dict]] = None,
    ) -> None:
        locals_ = locals()
        # litellm config convention: any explicitly-passed parameter is
        # promoted to a class-level default shared by all instances.
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "frequency_penalty",
            "presence_penalty",
            "top_p",
            "top_k",
            "temperature",
            "max_tokens",
            "n",
            "stop",
            "stream",
            "stream_options",
            "tools",
            "tool_choice",
        ]

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return MaritalkError(
            status_code=status_code, message=error_message, headers=headers
        )
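

# Usage sketch (hedged): filtering OpenAI-style parameters down to what
# MariTalk supports before building a request. The model name below is
# purely illustrative and may not match litellm's model registry.
if __name__ == "__main__":
    config = MaritalkConfig()
    supported = config.get_supported_openai_params(model="maritalk/sabia-3")
    requested = {"temperature": 0.7, "max_tokens": 256, "logit_bias": {}}
    # Drop anything MariTalk does not understand (e.g. "logit_bias").
    mapped = {k: v for k, v in requested.items() if k in supported}
    print(mapped)  # -> {'temperature': 0.7, 'max_tokens': 256}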