add flan + vicuna + fix replicate errors

This commit is contained in:
ishaan-jaff 2023-09-06 11:21:57 -07:00
parent 48ee4a08ac
commit 1da6026622
3 changed files with 3 additions and 3 deletions

View file

@ -39,7 +39,7 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ReplicateError(response.status_code, message=response.text)
raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):

View file

@ -23,7 +23,7 @@ try:
"llama2",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
top_p=0.1,
temperature=0.1,
temperature=0.01,
num_beams=4,
max_tokens=60,
)

View file

@ -706,7 +706,7 @@ def get_optional_params( # use the openai defaults
optional_params["stream"] = stream
return optional_params
if max_tokens != float("inf"):
if "vicuna" in model:
if "vicuna" in model or "flan" in model:
optional_params["max_length"] = max_tokens
else:
optional_params["max_new_tokens"] = max_tokens