Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
Litellm code qa common config (#7113)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 44s
* feat(base_llm): initial commit for common base config class. Addresses code QA critique: https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132
* feat(base_llm/): add transform request/response abstract methods to base config class
* feat(cohere-+-clarifai): refactor integrations to use common base config class
* fix: fix linting errors
* refactor(anthropic/): move anthropic + vertex anthropic to use base config
* test: fix xai test
* test: fix tests
* fix: fix linting errors
* test: comment out WIP test
* fix(transformation.py): fix "is pdf used" check
* fix: fix linting error
Parent: 51ead67b4f
Commit: 5bbf906c83
41 changed files with 1877 additions and 1998 deletions
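The commit message above describes a shared abstract base that every provider config implements. A minimal sketch of that pattern, before the hunks below (illustrative only: apart from `BaseConfig`, `map_openai_params`, and the argument names visible in the diff, the method names and signatures here are assumptions, not litellm's actual API):

```python
# Illustrative sketch of the "common base config" pattern. Only BaseConfig,
# map_openai_params, and its arguments appear in the diff below; the
# transform_request/transform_response names and signatures are assumptions
# based on the commit message ("add transform request/response abstract
# methods to base config class").
from abc import ABC, abstractmethod
from typing import List


class BaseConfig(ABC):
    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map OpenAI-style params onto this provider's param names."""

    @abstractmethod
    def transform_request(
        self, model: str, messages: List[dict], optional_params: dict
    ) -> dict:
        """Build the provider-specific request body."""

    @abstractmethod
    def transform_response(self, model: str, raw_response: dict) -> dict:
        """Normalize the provider response back to the OpenAI shape."""
```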
```diff
@@ -2821,9 +2821,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.AnthropicConfig().map_openai_params(
             model=model,
             non_default_params=non_default_params,
             optional_params=optional_params,
             messages=messages,
             drop_params=(
                 drop_params
                 if drop_params is not None and isinstance(drop_params, bool)
                 else False
             ),
         )
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
```
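Every refactored branch repeats the same `drop_params` guard. A small self-contained illustration of its behavior (the helper name is hypothetical; the body is exactly the ternary from the hunks):

```python
# Hypothetical helper wrapping the repeated ternary. Anything that is not a
# bool (including None) collapses to False, so downstream code always sees
# a real boolean.
def normalize_drop_params(drop_params) -> bool:
    return (
        drop_params
        if drop_params is not None and isinstance(drop_params, bool)
        else False
    )


assert normalize_drop_params(True) is True
assert normalize_drop_params(False) is False
assert normalize_drop_params(None) is False
assert normalize_drop_params("yes") is False  # non-bool input is coerced to False
```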
```diff
@@ -2832,24 +2837,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
-        # handle cohere params
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if n is not None:
-            optional_params["num_generations"] = n
-        if logit_bias is not None:
-            optional_params["logit_bias"] = logit_bias
-        if top_p is not None:
-            optional_params["p"] = top_p
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if presence_penalty is not None:
-            optional_params["presence_penalty"] = presence_penalty
-        if stop is not None:
-            optional_params["stop_sequences"] = stop
+        optional_params = litellm.CohereConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "cohere_chat":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
```
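A hedged usage sketch of the new Cohere path, assuming `map_openai_params` reproduces the renames the deleted block did by hand (`n` to `num_generations`, `top_p` to `p`, `stop` to `stop_sequences`):

```python
# Sketch only: the exact output depends on CohereConfig internals; the
# expected shape is inferred from the manual mapping removed above.
import litellm

mapped = litellm.CohereConfig().map_openai_params(
    non_default_params={"n": 2, "top_p": 0.9, "stop": ["END"]},
    optional_params={},
    model="command-r",
    drop_params=False,
)
# Expected shape, per the removed manual mapping:
# {"num_generations": 2, "p": 0.9, "stop_sequences": ["END"]}
```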
```diff
@@ -2857,26 +2854,17 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
-        # handle cohere params
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if n is not None:
-            optional_params["num_generations"] = n
-        if top_p is not None:
-            optional_params["p"] = top_p
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if presence_penalty is not None:
-            optional_params["presence_penalty"] = presence_penalty
-        if stop is not None:
-            optional_params["stop_sequences"] = stop
-        if tools is not None:
-            optional_params["tools"] = tools
-        if seed is not None:
-            optional_params["seed"] = seed
+        optional_params = litellm.CohereChatConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )

     elif custom_llm_provider == "maritalk":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
```
```diff
@@ -3071,8 +3059,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.VertexAIAnthropicConfig().map_openai_params(
             model=model,
             non_default_params=non_default_params,
             optional_params=optional_params,
             drop_params=(
                 drop_params
                 if drop_params is not None and isinstance(drop_params, bool)
                 else False
             ),
         )
     elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models:
         supported_params = get_supported_openai_params(
```
```diff
@@ -6220,14 +6214,14 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
     return messages


-from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.base_llm.transformation import BaseConfig


 class ProviderConfigManager:
     @staticmethod
-    def get_provider_config(
+    def get_provider_chat_config(
         model: str, provider: litellm.LlmProviders
-    ) -> OpenAIGPTConfig:
+    ) -> BaseConfig:
         """
         Returns the provider config for a given provider.
         """
```
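Widening the return type from `OpenAIGPTConfig` to `BaseConfig` lets callers target the shared interface rather than one concrete config. A sketch under that assumption (the `apply_params` wrapper is hypothetical, and the base-class signature is inferred from the argument names in the hunks above):

```python
# Sketch only: assumes BaseConfig declares map_openai_params with the
# argument names visible in this diff; apply_params is a hypothetical caller.
from litellm.llms.base_llm.transformation import BaseConfig


def apply_params(config: BaseConfig, model: str, non_default_params: dict) -> dict:
    # Any provider config works here now that the dispatcher returns BaseConfig.
    return config.map_openai_params(
        non_default_params=non_default_params,
        optional_params={},
        model=model,
        drop_params=False,
    )
```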
```diff
@@ -6239,8 +6233,23 @@ class ProviderConfigManager:
             return litellm.GroqChatConfig()
         elif litellm.LlmProviders.DATABRICKS == provider:
             return litellm.DatabricksConfig()
+        elif litellm.LlmProviders.XAI == provider:
+            return litellm.XAIChatConfig()
+        elif litellm.LlmProviders.TEXT_COMPLETION_OPENAI == provider:
+            return litellm.OpenAITextCompletionConfig()
+        elif litellm.LlmProviders.COHERE_CHAT == provider:
+            return litellm.CohereChatConfig()
+        elif litellm.LlmProviders.COHERE == provider:
+            return litellm.CohereConfig()
+        elif litellm.LlmProviders.CLARIFAI == provider:
+            return litellm.ClarifaiConfig()
+        elif litellm.LlmProviders.ANTHROPIC == provider:
+            return litellm.AnthropicConfig()
+        elif litellm.LlmProviders.VERTEX_AI == provider:
+            if "claude" in model:
+                return litellm.VertexAIAnthropicConfig()

-        return OpenAIGPTConfig()
+        return litellm.OpenAIGPTConfig()


 def get_end_user_id_for_cost_tracking(
```
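A usage sketch of the new dispatcher; the import path is an assumption (the class sits in the same file as `get_optional_params`), while the provider branches and the `litellm.OpenAIGPTConfig()` fallback come from the hunk above:

```python
import litellm
# Assumed import path; the class is defined alongside get_optional_params.
from litellm.utils import ProviderConfigManager

# Dispatch on provider (and on model for Vertex AI, where "claude" models get
# the Anthropic config); providers without a branch fall back to OpenAIGPTConfig.
config = ProviderConfigManager.get_provider_chat_config(
    model="command-r", provider=litellm.LlmProviders.COHERE_CHAT
)
print(type(config).__name__)  # CohereChatConfig, per the branch above
```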