Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
fix replicate error
commit 5a1a1908c1 (parent 345a14483e)

7 changed files with 29 additions and 14 deletions
@@ -308,7 +308,7 @@ from .utils import (
     validate_environment,
     check_valid_key,
     get_llm_provider,
-    completion_with_config
+    completion_with_config,
 )
 from .main import * # type: ignore
 from .integrations import *
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -24,7 +24,6 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
     initial_prediction_data = {
         "version": version_id,
         "input": input_data,
-        "max_new_tokens": 500,
     }
 
     ## LOGGING
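Note: with the hardcoded "max_new_tokens": 500 removed, the prediction payload carries only the model version and the caller-built input, so any token limit has to travel inside input_data (see the get_optional_params hunks further down). A minimal sketch of what a helper like start_prediction might do, for orientation only; the endpoint URL, auth header, and response shape are assumptions based on Replicate's public HTTP API, not taken from this diff, and the function name is hypothetical:

import requests

def start_prediction_sketch(version_id, input_data, api_token):
    # Payload now holds just the version and the user-built input; generation
    # limits (max_new_tokens / max_length / max_tokens) live inside input_data.
    payload = {"version": version_id, "input": input_data}
    headers = {
        "Authorization": f"Token {api_token}",  # assumed auth scheme
        "Content-Type": "application/json",
    }
    response = requests.post("https://api.replicate.com/v1/predictions", json=payload, headers=headers)
    response.raise_for_status()
    return response.json()["urls"]["get"]  # assumed response shape: URL to poll for the result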
@@ -109,17 +108,32 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    # Convert messages to prompt
-    prompt = ""
-    for message in messages:
-        prompt += message["content"]
-
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
-    input_data = {
-        "prompt": prompt,
-        **optional_params
-    }
+    if "meta/llama-2-13b-chat" in model:
+        system_prompt = ""
+        prompt = ""
+        for message in messages:
+            if message["role"] == "system":
+                system_prompt = message["content"]
+            else:
+                prompt += message["content"]
+        input_data = {
+            "system_prompt": system_prompt,
+            "prompt": prompt,
+            **optional_params
+        }
+    else:
+        # Convert messages to prompt
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        input_data = {
+            "prompt": prompt,
+            **optional_params
+        }
+
 
     ## COMPLETION CALL
     ## Replicate Compeltion calls have 2 steps
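Restated outside the diff for readability, the new input handling splits into two cases: meta/llama-2-13b-chat routes the system message to a dedicated system_prompt field, while every other Replicate model keeps the old behavior of concatenating all message contents into one prompt. The helper name and signature below are illustrative, not part of the library:

def build_replicate_input(model, messages, optional_params):
    if "meta/llama-2-13b-chat" in model:
        # llama-2 chat takes the system message in its own field
        system_prompt = ""
        prompt = ""
        for message in messages:
            if message["role"] == "system":
                system_prompt = message["content"]
            else:
                prompt += message["content"]
        return {"system_prompt": system_prompt, "prompt": prompt, **optional_params}
    # other Replicate models: flatten every message into a single prompt
    prompt = ""
    for message in messages:
        prompt += message["content"]
    return {"prompt": prompt, **optional_params}

# Example:
# build_replicate_input("meta/llama-2-13b-chat",
#                       [{"role": "system", "content": "Be terse."},
#                        {"role": "user", "content": "Hi"}],
#                       {"temperature": 0.5})
# -> {"system_prompt": "Be terse.", "prompt": "Hi", "temperature": 0.5}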
@@ -1016,7 +1016,7 @@ def get_optional_params( # use the openai defaults
         optional_params["logit_bias"] = logit_bias
     elif custom_llm_provider == "replicate":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop"]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
         _check_valid_arg(supported_params=supported_params)
 
         if stream:
@@ -1025,6 +1025,8 @@ def get_optional_params( # use the openai defaults
         if max_tokens:
             if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
+            elif "meta/codellama-13b" in model:
+                optional_params["max_tokens"] = max_tokens
             else:
                 optional_params["max_new_tokens"] = max_tokens
         if temperature:
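Taken together, the two get_optional_params hunks above let Replicate calls accept a seed parameter and translate the OpenAI-style max_tokens argument into whichever field the target Replicate model expects. A sketch of just that mapping, with a hypothetical helper name:

def map_replicate_max_tokens(model, max_tokens):
    # vicuna/flan deployments take max_length
    if "vicuna" in model or "flan" in model:
        return {"max_length": max_tokens}
    # the new branch: codellama-13b takes max_tokens directly
    if "meta/codellama-13b" in model:
        return {"max_tokens": max_tokens}
    # default for other Replicate models (e.g. llama-2): max_new_tokens
    return {"max_new_tokens": max_tokens}

# e.g. map_replicate_max_tokens("meta/codellama-13b", 256) -> {"max_tokens": 256}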
@@ -1289,7 +1291,6 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
     except Exception as e:
         raise e
 
-
 def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     api_key = (dynamic_api_key or litellm.api_key)
     # openai
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.813"
+version = "0.1.814"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"