This commit is contained in:
Krrish Dholakia 2024-05-21 17:24:51 -07:00
parent 620e6db027
commit 988970f4c2
3 changed files with 127 additions and 25 deletions

View file

@ -5811,7 +5811,7 @@ def get_optional_params(
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
]:
supported_params += [
supported_params += [ # type: ignore
"functions",
"function_call",
"tools",
@ -6061,6 +6061,47 @@ def get_optional_params(
return optional_params
def get_non_default_params(passed_params: dict) -> dict:
default_params = {
"functions": None,
"function_call": None,
"temperature": None,
"top_p": None,
"n": None,
"stream": None,
"stream_options": None,
"stop": None,
"max_tokens": None,
"presence_penalty": None,
"frequency_penalty": None,
"logit_bias": None,
"user": None,
"model": None,
"custom_llm_provider": "",
"response_format": None,
"seed": None,
"tools": None,
"tool_choice": None,
"max_retries": None,
"logprobs": None,
"top_logprobs": None,
"extra_headers": None,
}
# filter out those parameters that were passed with non-default values
non_default_params = {
k: v
for k, v in passed_params.items()
if (
k != "model"
and k != "custom_llm_provider"
and k in default_params
and v != default_params[k]
)
}
return non_default_params
def calculate_max_parallel_requests(
max_parallel_requests: Optional[int],
rpm: Optional[int],
@ -6287,7 +6328,7 @@ def get_first_chars_messages(kwargs: dict) -> str:
return ""
def get_supported_openai_params(model: str, custom_llm_provider: str):
def get_supported_openai_params(model: str, custom_llm_provider: str) -> Optional[list]:
"""
Returns the supported openai params for a given model + provider
@ -6295,6 +6336,10 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
```
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
```
Returns:
- List if custom_llm_provider is mapped
- None if unmapped
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
@ -6534,6 +6579,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
elif custom_llm_provider == "watsonx":
return litellm.IBMWatsonXAIConfig().get_supported_openai_params()
return None
def get_formatted_prompt(
data: dict,