add flan + vicuna + fix replicate errors

ishaan-jaff 2023-09-06 11:21:57 -07:00
parent 48ee4a08ac
commit 1da6026622
3 changed files with 3 additions and 3 deletions


@@ -39,7 +39,7 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
         response_data = response.json()
         return response_data.get("urls", {}).get("get")
     else:
-        raise ReplicateError(response.status_code, message=response.text)
+        raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
 
 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):
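For context, a minimal self-contained sketch of the error path this hunk touches. The ReplicateError stand-in and the simplified start_prediction body are assumptions for illustration (the real function also threads a logging_obj through); the endpoint is Replicate's standard predictions API:

import requests

class ReplicateError(Exception):
    # Hypothetical stand-in for litellm's ReplicateError, kept to the
    # (status_code, message) shape the diff relies on.
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(f"{status_code}: {message}")

def start_prediction(version_id, input_data, api_token):
    # Simplified sketch of the surrounding function.
    response = requests.post(
        "https://api.replicate.com/v1/predictions",
        headers={"Authorization": f"Token {api_token}"},
        json={"version": version_id, "input": input_data},
    )
    if response.status_code == 201:
        response_data = response.json()
        # URL to poll for the prediction result
        return response_data.get("urls", {}).get("get")
    else:
        # Raise with an explicit failure message, as in the diff above.
        raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")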


@@ -23,7 +23,7 @@ try:
         "llama2",
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
         top_p=0.1,
-        temperature=0.1,
+        temperature=0.01,
         num_beams=4,
         max_tokens=60,
     )
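Filled out, the test around this hunk might look like the sketch below; the litellm import, the try/except wrapper, and the prints are assumptions, while the call arguments mirror the diff:

from litellm import completion

try:
    response = completion(
        "llama2",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        top_p=0.1,
        temperature=0.01,  # near-greedy sampling makes the test output more stable
        num_beams=4,
        max_tokens=60,
    )
    print(response)
except Exception as e:
    # Surface provider errors so the test output shows the cause.
    print(f"replicate call failed: {e}")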


@@ -706,7 +706,7 @@ def get_optional_params(  # use the openai defaults
         optional_params["stream"] = stream
         return optional_params
     if max_tokens != float("inf"):
-        if "vicuna" in model:
+        if "vicuna" in model or "flan" in model:
             optional_params["max_length"] = max_tokens
         else:
             optional_params["max_new_tokens"] = max_tokens