fix replicate error

Krrish Dholakia 2023-10-02 21:35:16 -07:00
parent 345a14483e
commit 5a1a1908c1
7 changed files with 29 additions and 14 deletions

View file

@@ -308,7 +308,7 @@ from .utils import (
     validate_environment,
     check_valid_key,
     get_llm_provider,
-    completion_with_config
+    completion_with_config,
 )
 from .main import *  # type: ignore
 from .integrations import *

View file

@@ -24,7 +24,6 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
     initial_prediction_data = {
         "version": version_id,
         "input": input_data,
-        "max_new_tokens": 500,
     }
     ## LOGGING
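For reference, a minimal sketch (not the library's code) of what the start-prediction payload looks like after this hunk. The version hash and prompt are made up, and it assumes that length limits such as max_new_tokens now arrive inside the "input" dict via optional_params instead of as a top-level field:

# Hypothetical payload for Replicate's POST /v1/predictions call
version_id = "a1b2c3d4"  # placeholder version hash
input_data = {
    "prompt": "What is the capital of France?",
    "max_new_tokens": 500,  # generation params travel inside "input"
}
initial_prediction_data = {
    "version": version_id,
    "input": input_data,
    # "max_new_tokens": 500  <- the top-level key removed by this commit
}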
@@ -109,18 +108,33 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    # Convert messages to prompt
-    prompt = ""
-    for message in messages:
-        prompt += message["content"]
-    # Start a prediction and get the prediction URL
-    version_id = model_to_version_id(model)
-    input_data = {
-        "prompt": prompt,
-        **optional_params
-    }
+    # Start a prediction and get the prediction URL
+    version_id = model_to_version_id(model)
+    if "meta/llama-2-13b-chat" in model:
+        system_prompt = ""
+        prompt = ""
+        for message in messages:
+            if message["role"] == "system":
+                system_prompt = message["content"]
+            else:
+                prompt += message["content"]
+        input_data = {
+            "system_prompt": system_prompt,
+            "prompt": prompt,
+            **optional_params
+        }
+    else:
+        # Convert messages to prompt
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+        input_data = {
+            "prompt": prompt,
+            **optional_params
+        }
     ## COMPLETION CALL
     ## Replicate Compeltion calls have 2 steps
     ## Step1: Start Prediction: gets a prediction url
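To make the new branch concrete, here is a small self-contained sketch of how an OpenAI-style messages list is split into Replicate's llama-2 chat inputs. The messages are invented and optional_params is omitted:

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name one ocean."},
]

system_prompt = ""
prompt = ""
for message in messages:
    if message["role"] == "system":
        system_prompt = message["content"]
    else:
        prompt += message["content"]

input_data = {"system_prompt": system_prompt, "prompt": prompt}
print(input_data)
# {'system_prompt': 'You are a terse assistant.', 'prompt': 'Name one ocean.'}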

View file

@@ -1016,7 +1016,7 @@ def get_optional_params( # use the openai defaults
         optional_params["logit_bias"] = logit_bias
     elif custom_llm_provider == "replicate":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop"]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
         _check_valid_arg(supported_params=supported_params)
         if stream:
@@ -1025,6 +1025,8 @@
         if max_tokens:
             if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
+            elif "meta/codellama-13b" in model:
+                optional_params["max_tokens"] = max_tokens
             else:
                 optional_params["max_new_tokens"] = max_tokens
         if temperature:
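Taken together, the two hunks above accept max_tokens (and now seed) for Replicate and translate the length limit per model family. A rough sketch of that mapping, with example model strings:

def map_max_tokens(model: str, max_tokens: int) -> dict:
    # mirrors the branch added above
    if "vicuna" in model or "flan" in model:
        return {"max_length": max_tokens}
    elif "meta/codellama-13b" in model:
        return {"max_tokens": max_tokens}
    else:
        return {"max_new_tokens": max_tokens}

print(map_max_tokens("replicate/meta/codellama-13b", 256))     # {'max_tokens': 256}
print(map_max_tokens("replicate/meta/llama-2-13b-chat", 256))  # {'max_new_tokens': 256}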
@@ -1289,7 +1291,6 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
     except Exception as e:
         raise e
 def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     api_key = (dynamic_api_key or litellm.api_key)
     # openai

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.813"
+version = "0.1.814"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
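Finally, a hedged end-to-end usage sketch of the fixed path. The version hash is a placeholder, the API token is assumed to be set in the environment, and the model string is assumed to follow litellm's replicate/ prefix with an owner/name:version suffix:

import os
import litellm

os.environ["REPLICATE_API_TOKEN"] = "r8_..."  # placeholder token

response = litellm.completion(
    model="replicate/meta/llama-2-13b-chat:0123abcd",  # placeholder version hash
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    max_tokens=100,  # translated to max_new_tokens for llama-2 per the hunk above
)
print(response)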