forked from phoenix/litellm-mirror
fix replicate error

commit 5a1a1908c1
parent 345a14483e

7 changed files with 29 additions and 14 deletions
litellm/__init__.py

```diff
@@ -308,7 +308,7 @@ from .utils import (
     validate_environment,
     check_valid_key,
     get_llm_provider,
-    completion_with_config
+    completion_with_config,
 )
 from .main import * # type: ignore
 from .integrations import *
```
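The only change in this hunk is a trailing comma after `completion_with_config`. The tuple import runs at package import time, so these helpers remain available at the top level, e.g.:

```python
# Package-level re-exports from litellm.utils (behavior unchanged by this commit):
from litellm import validate_environment, check_valid_key, get_llm_provider, completion_with_config
```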
3 binary files changed (not shown).
litellm/llms/replicate.py

```diff
@@ -24,7 +24,6 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
     initial_prediction_data = {
         "version": version_id,
         "input": input_data,
-        "max_new_tokens": 500,
     }

     ## LOGGING
```
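Dropping the hardcoded `"max_new_tokens": 500` matters because Replicate reads generation parameters from the `input` payload, not from the top level of the prediction request; the cap now arrives through `optional_params` instead. A rough sketch of the corrected payload shape (version hash and values are placeholders):

```python
# Hypothetical request body after the fix, assuming the caller set max_tokens=500:
input_data = {
    "prompt": "Hello",
    "max_new_tokens": 500,  # mapped from max_tokens in get_optional_params
}
initial_prediction_data = {
    "version": "<replicate-version-hash>",  # placeholder
    "input": input_data,                    # generation params live inside "input"
}
```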
```diff
@@ -109,17 +108,32 @@ def completion(
     litellm_params=None,
     logger_fn=None,
 ):
-    # Convert messages to prompt
-    prompt = ""
-    for message in messages:
-        prompt += message["content"]
-
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
-    input_data = {
-        "prompt": prompt,
-        **optional_params
-    }
+    if "meta/llama-2-13b-chat" in model:
+        system_prompt = ""
+        prompt = ""
+        for message in messages:
+            if message["role"] == "system":
+                system_prompt = message["content"]
+            else:
+                prompt += message["content"]
+        input_data = {
+            "system_prompt": system_prompt,
+            "prompt": prompt,
+            **optional_params
+        }
+    else:
+        # Convert messages to prompt
+        prompt = ""
+        for message in messages:
+            prompt += message["content"]
+
+        input_data = {
+            "prompt": prompt,
+            **optional_params
+        }
+

     ## COMPLETION CALL
     ## Replicate Completion calls have 2 steps
```
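For llama-2-13b-chat, the system message now feeds Replicate's dedicated `system_prompt` input rather than being concatenated into `prompt`. A standalone illustration of the mapping (message contents invented):

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who won the world series in 2020?"},
]

system_prompt, prompt = "", ""
for message in messages:
    if message["role"] == "system":
        system_prompt = message["content"]
    else:
        prompt += message["content"]

assert system_prompt == "You are a helpful assistant."
assert prompt == "Who won the world series in 2020?"
```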
litellm/utils.py

```diff
@@ -1016,7 +1016,7 @@ def get_optional_params( # use the openai defaults
         optional_params["logit_bias"] = logit_bias
     elif custom_llm_provider == "replicate":
         ## check if unsupported param passed in
-        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop"]
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
         _check_valid_arg(supported_params=supported_params)

         if stream:
```
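Adding `"seed"` to the allow-list means a caller can request reproducible sampling from Replicate models without tripping `_check_valid_arg`. A hedged usage sketch (model string and kwarg plumbing assumed, not taken from this diff):

```python
import litellm

# Assumes completion() forwards extra kwargs like seed into optional_params.
response = litellm.completion(
    model="replicate/meta/llama-2-13b-chat",  # illustrative model name
    messages=[{"role": "user", "content": "Tell me a joke."}],
    seed=42,  # accepted now that "seed" is in supported_params
)
```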
```diff
@@ -1025,6 +1025,8 @@ def get_optional_params( # use the openai defaults
         if max_tokens:
             if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
+            elif "meta/codellama-13b" in model:
+                optional_params["max_tokens"] = max_tokens
             else:
                 optional_params["max_new_tokens"] = max_tokens
         if temperature:
```
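Replicate-hosted models disagree on the name of the token-limit input, which is why the mapping grows a third branch. Condensed here as a standalone helper (a rewrite for illustration, not the library's code):

```python
def replicate_max_tokens_key(model: str) -> str:
    """Pick the Replicate input name that carries the token cap."""
    if "vicuna" in model or "flan" in model:
        return "max_length"
    elif "meta/codellama-13b" in model:
        return "max_tokens"  # the branch this commit adds
    return "max_new_tokens"

assert replicate_max_tokens_key("meta/codellama-13b") == "max_tokens"
assert replicate_max_tokens_key("a16z-infra/llama-2-13b-chat") == "max_new_tokens"
```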
@ -1289,7 +1291,6 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
|
def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
|
||||||
api_key = (dynamic_api_key or litellm.api_key)
|
api_key = (dynamic_api_key or litellm.api_key)
|
||||||
# openai
|
# openai
|
||||||
|
|
|
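This hunk only drops a stray blank line; the neighboring `get_api_key` context shows the key-resolution order, where a per-call key wins over the module-level default:

```python
# Same precedence as `dynamic_api_key or litellm.api_key` (dummy values):
dynamic_api_key = "sk-per-call"
module_api_key = "sk-module-default"
api_key = dynamic_api_key or module_api_key
assert api_key == "sk-per-call"
```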
pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.813"
+version = "0.1.814"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
```