diff --git a/litellm/__init__.py b/litellm/__init__.py
index d806f534f..9afa4aeeb 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -308,7 +308,7 @@ from .utils import (
     validate_environment,
     check_valid_key,
     get_llm_provider,
-    completion_with_config
+    completion_with_config,
 )
 from .main import *  # type: ignore
 from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 72b2291b4..191de47ab 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index b103a749a..6281df1f8 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 6a5357ac2..d4e524aeb 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index b0e9f55bb..43ece5947 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -24,7 +24,6 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
     initial_prediction_data = {
         "version": version_id,
         "input": input_data,
-        "max_new_tokens": 500,
     }

     ## LOGGING
@@ -109,17 +108,32 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    # Convert messages to prompt
-    prompt = ""
-    for message in messages:
-        prompt += message["content"]
-
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
-    input_data = {
-        "prompt": prompt,
-        **optional_params
-    }
+    if "meta/llama-2-13b-chat" in model:
+        system_prompt = ""
+        prompt = ""
+        for message in messages:
+            if message["role"] == "system":
+                system_prompt = message["content"]
+            else:
+                prompt += message["content"]
+        input_data = {
+            "system_prompt": system_prompt,
+            "prompt": prompt,
+            **optional_params
+        }
+    else:
+        # Convert messages to prompt
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        input_data = {
+            "prompt": prompt,
+            **optional_params
+        }
+

     ## COMPLETION CALL
     ## Replicate Compeltion calls have 2 steps
diff --git a/litellm/utils.py b/litellm/utils.py
index e25dab254..cd48ab718 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1016,7 +1016,7 @@ def get_optional_params(  # use the openai defaults
         optional_params["logit_bias"] = logit_bias
     elif custom_llm_provider == "replicate":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop"]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
         _check_valid_arg(supported_params=supported_params)

         if stream:
@@ -1025,6 +1025,8 @@ def get_optional_params(  # use the openai defaults
         if max_tokens:
             if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
+            elif "meta/codellama-13b" in model:
+                optional_params["max_tokens"] = max_tokens
             else:
                 optional_params["max_new_tokens"] = max_tokens
         if temperature:
@@ -1289,7 +1291,6 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
     except Exception as e:
         raise e

-
 def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     api_key = (dynamic_api_key or litellm.api_key)
     # openai
diff --git a/pyproject.toml b/pyproject.toml
index 8acac6eba..8fc37b24a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.813"
+version = "0.1.814"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
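Usage sketch (not part of the patch): a minimal, hedged example of how the Replicate changes above would be exercised through litellm.completion, assuming a "replicate/meta/llama-2-13b-chat" model string; the version hash and message contents are hypothetical.

    # Hedged sketch, not part of the diff. Only litellm.completion and the
    # parameters touched by this patch are assumed; the version hash is a placeholder.
    import litellm

    response = litellm.completion(
        model="replicate/meta/llama-2-13b-chat:<version-hash>",  # hypothetical version hash
        messages=[
            # With this patch, a system message is forwarded as Replicate's "system_prompt"
            {"role": "system", "content": "You are a terse assistant."},
            # Remaining message contents are concatenated into "prompt"
            {"role": "user", "content": "Summarize what a unified diff hunk is."},
        ],
        max_tokens=200,  # mapped to max_new_tokens (or max_tokens for meta/codellama-13b)
        seed=42,         # "seed" now passes the Replicate supported-params check
    )
    print(response)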