Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
add flan + vicuna + fix replicate errors

commit 1da6026622 (parent 48ee4a08ac)
3 changed files with 3 additions and 3 deletions
|
@@ -39,7 +39,7 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
         response_data = response.json()
         return response_data.get("urls", {}).get("get")
     else:
-        raise ReplicateError(response.status_code, message=response.text)
+        raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
 
 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):
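For context, a minimal sketch of the error path this hunk touches. The ReplicateError stand-in, the request shape, and the function name below are assumptions for illustration only; they just mirror the (status_code, message) call pattern visible in the diff, not litellm's actual definitions.

# Hypothetical, simplified sketch -- not litellm's actual implementation.
import requests


class ReplicateError(Exception):
    # Stand-in that only mirrors the (status_code, message) call shape above.
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(f"{status_code}: {message}")


def start_prediction_sketch(url, input_data, api_token):
    # Kick off a prediction and return the polling URL on success.
    response = requests.post(
        url,
        headers={"Authorization": f"Token {api_token}"},
        json={"input": input_data},
    )
    if response.status_code == 201:
        response_data = response.json()
        return response_data.get("urls", {}).get("get")
    else:
        # After this commit the error text says the prediction failed to start,
        # instead of passing response.text as a bare message= keyword.
        raise ReplicateError(
            response.status_code, f"Failed to start prediction {response.text}"
        )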
@@ -23,7 +23,7 @@ try:
         "llama2",
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
         top_p=0.1,
-        temperature=0.1,
+        temperature=0.01,
         num_beams=4,
         max_tokens=60,
     )
@@ -706,7 +706,7 @@ def get_optional_params(  # use the openai defaults
             optional_params["stream"] = stream
             return optional_params
         if max_tokens != float("inf"):
-            if "vicuna" in model:
+            if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
             else:
                 optional_params["max_new_tokens"] = max_tokens
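A standalone sketch of what this last hunk changes: when mapping max_tokens for Replicate models, vicuna-style and (now) flan-style models get a max_length input, while other models get max_new_tokens. The helper name and return shape below are illustrative assumptions, not litellm's exact code; only the branching mirrors the diff.

# Illustrative sketch (assumed helper name); mirrors the branching in the hunk above.
def map_replicate_max_tokens(model: str, max_tokens: int) -> dict:
    optional_params = {}
    if "vicuna" in model or "flan" in model:
        # max_length typically bounds the whole sequence (prompt + output)
        optional_params["max_length"] = max_tokens
    else:
        # max_new_tokens bounds only the newly generated tokens
        optional_params["max_new_tokens"] = max_tokens
    return optional_params


# Example: a flan model now gets max_length rather than max_new_tokens
print(map_replicate_max_tokens("replicate/flan-t5-xl", 60))   # {'max_length': 60}
print(map_replicate_max_tokens("replicate/llama-2-70b", 60))  # {'max_new_tokens': 60}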