style(test_completion.py): fix merge conflict

Krrish Dholakia 2023-10-05 22:09:38 -07:00
parent 396d9d8e38
commit dd7e397650
22 changed files with 1535 additions and 250 deletions


@@ -1202,8 +1202,6 @@ def get_optional_params( # use the openai defaults
             # e.g. anthropic request body: {"max_tokens_to_sample": 300, "temperature": 0.5, "top_p": 1, "stop_sequences": ["\n\nHuman:"]}
             if max_tokens:
                 optional_params["max_tokens_to_sample"] = max_tokens
-            else:
-                optional_params["max_tokens_to_sample"] = 256 # anthropic fails without max_tokens_to_sample
             if temperature:
                 optional_params["temperature"] = temperature
             if top_p:
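
For context: after this hunk, the OpenAI-style max_tokens argument is what populates Anthropic's max_tokens_to_sample. A minimal caller-side sketch, assuming litellm's public completion() entry point and an illustrative Bedrock Anthropic model id (neither is taken from this diff):

import litellm

# Sketch only: model id and values are assumptions, not part of the commit.
response = litellm.completion(
    model="bedrock/anthropic.claude-v2",  # assumed model id format
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=300,  # forwarded as optional_params["max_tokens_to_sample"]
)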
@@ -1226,6 +1224,28 @@ def get_optional_params( # use the openai defaults
                 optional_params["topP"] = top_p
             if stream:
                 optional_params["stream"] = stream
+        elif "cohere" in model: # cohere models on bedrock
+            supported_params = ["stream", "temperature", "max_tokens", "logit_bias", "top_p", "frequency_penalty", "presence_penalty", "stop"]
+            _check_valid_arg(supported_params=supported_params)
+            # handle cohere params
+            if stream:
+                optional_params["stream"] = stream
+            if temperature:
+                optional_params["temperature"] = temperature
+            if max_tokens:
+                optional_params["max_tokens"] = max_tokens
+            if n:
+                optional_params["num_generations"] = n
+            if logit_bias != {}:
+                optional_params["logit_bias"] = logit_bias
+            if top_p:
+                optional_params["p"] = top_p
+            if frequency_penalty:
+                optional_params["frequency_penalty"] = frequency_penalty
+            if presence_penalty:
+                optional_params["presence_penalty"] = presence_penalty
+            if stop:
+                optional_params["stop_sequences"] = stop
     elif model in litellm.aleph_alpha_models:
         supported_params = ["max_tokens", "stream", "top_p", "temperature", "presence_penalty", "frequency_penalty", "n", "stop"]
         _check_valid_arg(supported_params=supported_params)
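
The new branch translates OpenAI-style arguments into Cohere's parameter names: top_p becomes p, n becomes num_generations, stop becomes stop_sequences. A minimal sketch of a call that exercises each renamed parameter, assuming litellm's completion() entry point and an illustrative Bedrock Cohere model id:

import litellm

# Sketch only: model id and values are assumptions, not part of the commit.
response = litellm.completion(
    model="bedrock/cohere.command-text-v14",  # assumed model id format
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.5,
    max_tokens=100,  # passed through as "max_tokens"
    top_p=0.9,       # mapped to optional_params["p"]
    n=2,             # mapped to optional_params["num_generations"]
    stop=["\n\n"],   # mapped to optional_params["stop_sequences"]
)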
@@ -1312,8 +1332,12 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
     elif model in litellm.cohere_models:
         custom_llm_provider = "cohere"
     ## replicate
-    elif model in litellm.replicate_models:
-        custom_llm_provider = "replicate"
+    elif model in litellm.replicate_models or ":" in model:
+        model_parts = model.split(":")
+        if len(model_parts) > 1 and len(model_parts[1]) == 64: ## checks whether the model name carries a 64-character version hash - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+            custom_llm_provider = "replicate"
+        elif model in litellm.replicate_models:
+            custom_llm_provider = "replicate"
     ## openrouter
     elif model in litellm.openrouter_models:
         custom_llm_provider = "openrouter"
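
With this change, any model string carrying a 64-character version hash after a colon resolves to the replicate provider, even when the base name is not listed in litellm.replicate_models. A minimal sketch, assuming get_llm_provider is importable from the litellm package and written to tolerate either a 2- or 4-tuple return value:

from litellm import get_llm_provider  # assumed import path

# Version hash taken from the comment in the diff above.
model = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
model, provider, *_ = get_llm_provider(model)
print(provider)  # "replicate"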