forked from phoenix/litellm-mirror
style(test_completion.py): fix merge conflict
This commit is contained in:
parent
396d9d8e38
commit
dd7e397650
22 changed files with 1535 additions and 250 deletions
|
@ -1202,8 +1202,6 @@ def get_optional_params( # use the openai defaults
|
|||
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
|
||||
if max_tokens:
|
||||
optional_params["max_tokens_to_sample"] = max_tokens
|
||||
else:
|
||||
optional_params["max_tokens_to_sample"] = 256 # anthropic fails without max_tokens_to_sample
|
||||
if temperature:
|
||||
optional_params["temperature"] = temperature
|
||||
if top_p:
|
||||
|
@ -1226,6 +1224,28 @@ def get_optional_params( # use the openai defaults
|
|||
optional_params["topP"] = top_p
|
||||
if stream:
|
||||
optional_params["stream"] = stream
|
||||
elif "cohere" in model: # cohere models on bedrock
|
||||
supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop"]
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
# handle cohere params
|
||||
if stream:
|
||||
optional_params["stream"] = stream
|
||||
if temperature:
|
||||
optional_params["temperature"] = temperature
|
||||
if max_tokens:
|
||||
optional_params["max_tokens"] = max_tokens
|
||||
if n:
|
||||
optional_params["num_generations"] = n
|
||||
if logit_bias != {}:
|
||||
optional_params["logit_bias"] = logit_bias
|
||||
if top_p:
|
||||
optional_params["p"] = top_p
|
||||
if frequency_penalty:
|
||||
optional_params["frequency_penalty"] = frequency_penalty
|
||||
if presence_penalty:
|
||||
optional_params["presence_penalty"] = presence_penalty
|
||||
if stop:
|
||||
optional_params["stop_sequences"] = stop
|
||||
elif model in litellm.aleph_alpha_models:
|
||||
supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"]
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
|
@ -1312,8 +1332,12 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
|
|||
elif model in litellm.cohere_models:
|
||||
custom_llm_provider = "cohere"
|
||||
## replicate
|
||||
elif model in litellm.replicate_models:
|
||||
custom_llm_provider = "replicate"
|
||||
elif model in litellm.replicate_models or ":" in model:
|
||||
model_parts = model.split(":")
|
||||
if len(model_parts) > 1 and len(model_parts[1])==64: ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||
custom_llm_provider = "replicate"
|
||||
elif model in litellm.replicate_models:
|
||||
custom_llm_provider = "replicate"
|
||||
## openrouter
|
||||
elif model in litellm.openrouter_models:
|
||||
custom_llm_provider = "openrouter"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue