LiteLLM merge PR (#7161)

* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
Author: Krish Dholakia · 2024-12-10 22:49:26 -08:00 · committed by GitHub
Commit: 350cfc36f7 (parent: d5aae81c6d)
88 changed files with 3617 additions and 4421 deletions
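
Every hunk below applies one refactor: hand-rolled OpenAI-parameter translation inside get_optional_params is replaced with a call to the provider's config class, and each migrated provider is registered in ProviderConfigManager at the end. A toy, self-contained sketch of the pattern; the class, mapping table, and error message are invented for illustration and are not litellm's:

from typing import Any, Dict

class ToyProviderConfig:
    def map_openai_params(
        self,
        non_default_params: Dict[str, Any],
        optional_params: Dict[str, Any],
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]:
        # One place to rename OpenAI-style params to the provider's names,
        # replacing the inline if-chains deleted in the hunks below.
        mapping = {"max_tokens": "max_new_tokens", "stop": "stop_sequences"}
        for key, value in non_default_params.items():
            if key in mapping:
                optional_params[mapping[key]] = value
            elif drop_params:
                continue  # silently drop params the provider cannot accept
            else:
                raise ValueError(f"unsupported parameter for {model}: {key}")
        return optional_params

print(ToyProviderConfig().map_openai_params(
    non_default_params={"max_tokens": 256, "stop": ["###"]},
    optional_params={},
    model="toy/model",
    drop_params=False,
))
# {'max_new_tokens': 256, 'stop_sequences': ['###']}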


@@ -2923,22 +2923,16 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
-        if stream:
-            optional_params["stream"] = stream
-            # return optional_params
-        if max_tokens is not None:
-            if "vicuna" in model or "flan" in model:
-                optional_params["max_length"] = max_tokens
-            elif "meta/codellama-13b" in model:
-                optional_params["max_tokens"] = max_tokens
-            else:
-                optional_params["max_new_tokens"] = max_tokens
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stop is not None:
-            optional_params["stop_sequences"] = stop
+        optional_params = litellm.ReplicateConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "predibase":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
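
Every migrated call site repeats the same guard around drop_params, collapsing None and any non-bool value to False before handing it to the config class. The `is not None` check is redundant (None is never an instance of bool) but harmless. An equivalent helper, shown purely as a reading aid; litellm inlines the expression instead of defining one:

from typing import Any, Optional

def normalize_drop_params(drop_params: Optional[Any]) -> bool:
    # None and non-bool values collapse to False.
    return (
        drop_params
        if drop_params is not None and isinstance(drop_params, bool)
        else False
    )

assert normalize_drop_params(None) is False
assert normalize_drop_params("yes") is False
assert normalize_drop_params(True) is True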
@@ -2954,7 +2948,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.HuggingfaceConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
@@ -2973,53 +2974,6 @@ def get_optional_params(  # noqa: PLR0915
                 else False
             ),
         )
-    elif custom_llm_provider == "ai21":
-        ## check if unsupported param passed in
-        supported_params = get_supported_openai_params(
-            model=model, custom_llm_provider=custom_llm_provider
-        )
-        _check_valid_arg(supported_params=supported_params)
-        if stream:
-            optional_params["stream"] = stream
-        if n is not None:
-            optional_params["numResults"] = n
-        if max_tokens is not None:
-            optional_params["maxTokens"] = max_tokens
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["topP"] = top_p
-        if stop is not None:
-            optional_params["stopSequences"] = stop
-        if frequency_penalty is not None:
-            optional_params["frequencyPenalty"] = {"scale": frequency_penalty}
-        if presence_penalty is not None:
-            optional_params["presencePenalty"] = {"scale": presence_penalty}
-    elif (
-        custom_llm_provider == "palm"
-    ):  # https://developers.generativeai.google/tutorials/curl_quickstart
-        ## check if unsupported param passed in
-        supported_params = get_supported_openai_params(
-            model=model, custom_llm_provider=custom_llm_provider
-        )
-        _check_valid_arg(supported_params=supported_params)
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if stream:
-            optional_params["stream"] = stream
-        if n is not None:
-            optional_params["candidate_count"] = n
-        if stop is not None:
-            if isinstance(stop, str):
-                optional_params["stop_sequences"] = [stop]
-            elif isinstance(stop, list):
-                optional_params["stop_sequences"] = stop
-        if max_tokens is not None:
-            optional_params["max_output_tokens"] = max_tokens
     elif custom_llm_provider == "vertex_ai" and (
         model in litellm.vertex_chat_models
         or model in litellm.vertex_code_chat_models
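
Neither deleted branch reappears inline: `ai21` is folded into the `ai21_chat` branch in the -3491 hunk below, and the PaLM branch is removed outright. For orientation, here is the deleted AI21 rename table expressed in the config style this commit adopts; a hypothetical illustration, not litellm's actual AI21ChatConfig:

from typing import Any, Dict

OPENAI_TO_AI21 = {
    "n": "numResults",
    "max_tokens": "maxTokens",
    "temperature": "temperature",
    "top_p": "topP",
    "stop": "stopSequences",
    "stream": "stream",
}
PENALTY_KEYS = {
    "frequency_penalty": "frequencyPenalty",
    "presence_penalty": "presencePenalty",
}

def map_ai21_params(non_default_params: Dict[str, Any]) -> Dict[str, Any]:
    mapped: Dict[str, Any] = {}
    for key, value in non_default_params.items():
        if key in PENALTY_KEYS:
            # the deleted code wrapped penalties as {"scale": value}
            mapped[PENALTY_KEYS[key]] = {"scale": value}
        elif key in OPENAI_TO_AI21:
            mapped[OPENAI_TO_AI21[key]] = value
    return mapped

print(map_ai21_params({"max_tokens": 100, "frequency_penalty": 0.5}))
# {'maxTokens': 100, 'frequencyPenalty': {'scale': 0.5}}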
@@ -3120,12 +3074,25 @@ def get_optional_params(  # noqa: PLR0915
         _check_valid_arg(supported_params=supported_params)
         if "codestral" in model:
             optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
-                non_default_params=non_default_params, optional_params=optional_params
+                model=model,
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
             )
         else:
             optional_params = litellm.MistralConfig().map_openai_params(
-                non_default_params=non_default_params, optional_params=optional_params
+                model=model,
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
             )
     elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_ai_ai21_models:
         supported_params = get_supported_openai_params(
@@ -3326,29 +3293,28 @@ def get_optional_params(  # noqa: PLR0915
             model=model,
             non_default_params=non_default_params,
             optional_params=optional_params,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "nlp_cloud":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        if max_tokens is not None:
-            optional_params["max_length"] = max_tokens
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if presence_penalty is not None:
-            optional_params["presence_penalty"] = presence_penalty
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if n is not None:
-            optional_params["num_return_sequences"] = n
-        if stop is not None:
-            optional_params["stop_sequences"] = stop
+        optional_params = litellm.NLPCloudConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
+        )
     elif custom_llm_provider == "petals":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
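
At this point every migrated branch calls map_openai_params with the same four keyword arguments. The shared call contract, reconstructed from the call sites in this diff (litellm's real BaseConfig interface is not shown here):

from typing import Any, Dict, Protocol

class OpenAIParamMapper(Protocol):
    # Keyword order varies between call sites, which is fine since every
    # site passes these as keywords.
    def map_openai_params(
        self,
        non_default_params: Dict[str, Any],
        optional_params: Dict[str, Any],
        model: str,
        drop_params: bool,
    ) -> Dict[str, Any]: ...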
@@ -3435,7 +3401,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.MistralConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "text-completion-codestral":
         supported_params = get_supported_openai_params(
@@ -3443,7 +3416,14 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "databricks":
@@ -3470,6 +3450,11 @@ def get_optional_params(  # noqa: PLR0915
             model=model,
             non_default_params=non_default_params,
             optional_params=optional_params,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "cerebras":
         supported_params = get_supported_openai_params(
@@ -3480,6 +3465,11 @@ def get_optional_params(  # noqa: PLR0915
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "xai":
         supported_params = get_supported_openai_params(
@@ -3491,7 +3481,7 @@ def get_optional_params(  # noqa: PLR0915
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
-    elif custom_llm_provider == "ai21_chat":
+    elif custom_llm_provider == "ai21_chat" or custom_llm_provider == "ai21":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
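
Read together with the -2973 hunk above, which deleted the standalone `ai21` branch: `ai21` becomes an alias of `ai21_chat`, and both route through litellm.AI21ChatConfig (registered in the ProviderConfigManager hunk at the end of this diff).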
@@ -3500,6 +3490,11 @@ def get_optional_params(  # noqa: PLR0915
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "fireworks_ai":
         supported_params = get_supported_openai_params(
@@ -3525,6 +3520,11 @@ def get_optional_params(  # noqa: PLR0915
             non_default_params=non_default_params,
             optional_params=optional_params,
             model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "hosted_vllm":
         supported_params = get_supported_openai_params(
@@ -3594,55 +3594,17 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
-        if functions is not None:
-            optional_params["functions"] = functions
-        if function_call is not None:
-            optional_params["function_call"] = function_call
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["n"] = n
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
-        if presence_penalty is not None:
-            optional_params["presence_penalty"] = presence_penalty
-        if frequency_penalty is not None:
-            optional_params["frequency_penalty"] = frequency_penalty
-        if logit_bias is not None:
-            optional_params["logit_bias"] = logit_bias
-        if user is not None:
-            optional_params["user"] = user
-        if response_format is not None:
-            optional_params["response_format"] = response_format
-        if seed is not None:
-            optional_params["seed"] = seed
-        if tools is not None:
-            optional_params["tools"] = tools
-        if tool_choice is not None:
-            optional_params["tool_choice"] = tool_choice
-        if max_retries is not None:
-            optional_params["max_retries"] = max_retries
-        # OpenRouter-only parameters
-        extra_body = {}
-        transforms = passed_params.pop("transforms", None)
-        models = passed_params.pop("models", None)
-        route = passed_params.pop("route", None)
-        if transforms is not None:
-            extra_body["transforms"] = transforms
-        if models is not None:
-            extra_body["models"] = models
-        if route is not None:
-            extra_body["route"] = route
-        optional_params["extra_body"] = (
-            extra_body  # openai client supports `extra_body` param
+        optional_params = litellm.OpenrouterConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     elif custom_llm_provider == "watsonx":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
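
The deleted OpenRouter block carried the only provider-specific logic in this stretch: collecting the OpenRouter-only fields transforms, models, and route from passed_params into extra_body, which the OpenAI client forwards verbatim. That logic must now live inside OpenrouterConfig.map_openai_params, which this diff does not show. A hypothetical sketch of that half, for orientation only:

from typing import Any, Dict

def collect_openrouter_extra_body(non_default_params: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper, not litellm's actual OpenrouterConfig internals.
    extra_body = {
        key: non_default_params.pop(key)
        for key in ("transforms", "models", "route")
        if key in non_default_params
    }
    # the openai client forwards `extra_body` verbatim in the request payload
    return {"extra_body": extra_body} if extra_body else {}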
@@ -3727,7 +3689,11 @@ def get_optional_params(  # noqa: PLR0915
             optional_params=optional_params,
             model=model,
             api_version=api_version,  # type: ignore
-            drop_params=drop_params,
+            drop_params=(
+                drop_params
+                if drop_params is not None and isinstance(drop_params, bool)
+                else False
+            ),
         )
     else:  # assume passing in params for text-completion openai
         supported_params = get_supported_openai_params(
@@ -6271,7 +6237,7 @@ from litellm.llms.base_llm.transformation import BaseConfig
 class ProviderConfigManager:
     @staticmethod
-    def get_provider_chat_config(
+    def get_provider_chat_config(  # noqa: PLR0915
         model: str, provider: litellm.LlmProviders
     ) -> BaseConfig:
         """
@@ -6333,6 +6299,60 @@ class ProviderConfigManager:
             return litellm.LMStudioChatConfig()
         elif litellm.LlmProviders.GALADRIEL == provider:
             return litellm.GaladrielChatConfig()
+        elif litellm.LlmProviders.REPLICATE == provider:
+            return litellm.ReplicateConfig()
+        elif litellm.LlmProviders.HUGGINGFACE == provider:
+            return litellm.HuggingfaceConfig()
+        elif litellm.LlmProviders.TOGETHER_AI == provider:
+            return litellm.TogetherAIConfig()
+        elif litellm.LlmProviders.OPENROUTER == provider:
+            return litellm.OpenrouterConfig()
+        elif litellm.LlmProviders.GEMINI == provider:
+            return litellm.GoogleAIStudioGeminiConfig()
+        elif (
+            litellm.LlmProviders.AI21 == provider
+            or litellm.LlmProviders.AI21_CHAT == provider
+        ):
+            return litellm.AI21ChatConfig()
+        elif litellm.LlmProviders.AZURE == provider:
+            return litellm.AzureOpenAIConfig()
+        elif litellm.LlmProviders.AZURE_AI == provider:
+            return litellm.AzureAIStudioConfig()
+        elif litellm.LlmProviders.AZURE_TEXT == provider:
+            return litellm.AzureOpenAITextConfig()
+        elif litellm.LlmProviders.HOSTED_VLLM == provider:
+            return litellm.HostedVLLMChatConfig()
+        elif litellm.LlmProviders.NLP_CLOUD == provider:
+            return litellm.NLPCloudConfig()
+        elif litellm.LlmProviders.OOBABOOGA == provider:
+            return litellm.OobaboogaConfig()
+        elif litellm.LlmProviders.OLLAMA_CHAT == provider:
+            return litellm.OllamaChatConfig()
+        elif litellm.LlmProviders.DEEPINFRA == provider:
+            return litellm.DeepInfraConfig()
+        elif litellm.LlmProviders.PERPLEXITY == provider:
+            return litellm.PerplexityChatConfig()
+        elif (
+            litellm.LlmProviders.MISTRAL == provider
+            or litellm.LlmProviders.CODESTRAL == provider
+        ):
+            return litellm.MistralConfig()
+        elif litellm.LlmProviders.NVIDIA_NIM == provider:
+            return litellm.NvidiaNimConfig()
+        elif litellm.LlmProviders.CEREBRAS == provider:
+            return litellm.CerebrasConfig()
+        elif litellm.LlmProviders.VOLCENGINE == provider:
+            return litellm.VolcEngineConfig()
+        elif litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL == provider:
+            return litellm.MistralTextCompletionConfig()
+        elif litellm.LlmProviders.SAMBANOVA == provider:
+            return litellm.SambanovaConfig()
+        elif litellm.LlmProviders.MARITALK == provider:
+            return litellm.MaritalkConfig()
+        elif litellm.LlmProviders.CLOUDFLARE == provider:
+            return litellm.CloudflareChatConfig()
+        elif litellm.LlmProviders.ANTHROPIC_TEXT == provider:
+            return litellm.AnthropicTextConfig()
         elif litellm.LlmProviders.VLLM == provider:
             return litellm.VLLMConfig()
         elif litellm.LlmProviders.OLLAMA == provider:
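
The new arms register each migrated provider with the factory (the excerpt ends mid-chain at the OLLAMA branch). Callers resolve a config and invoke the shared mapping roughly as below; the usage shape is implied by the signatures in this diff, the import path assumes ProviderConfigManager lives in litellm.utils (the file these hunks appear to modify), and the model string is a placeholder:

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="replicate/meta/llama-2-70b-chat",  # placeholder model name
    provider=litellm.LlmProviders.REPLICATE,
)
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
    model="replicate/meta/llama-2-70b-chat",
    drop_params=False,
)
print(type(config).__name__, optional_params)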