Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
add flan + vicuna + fix replicate errors
This commit is contained in:
parent 48ee4a08ac
commit 1da6026622
3 changed files with 3 additions and 3 deletions
@@ -39,7 +39,7 @@ def start_prediction(version_id, input_data, api_token, logging_obj):
         response_data = response.json()
         return response_data.get("urls", {}).get("get")
     else:
-        raise ReplicateError(response.status_code, message=response.text)
+        raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
 
 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):

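For context, the error path being fixed behaves roughly like the sketch below. The function and exception names come from the diff; the request URL, headers, and the shape of the ReplicateError constructor are assumptions for illustration, not the exact litellm implementation.

# Sketch only: mirrors the error path changed above, assuming a
# requests-style response; the ReplicateError constructor shape is an assumption.
import requests

class ReplicateError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(message)

def start_prediction_sketch(input_data, api_token):
    # Hypothetical endpoint and headers, included only to make the sketch runnable.
    response = requests.post(
        "https://api.replicate.com/v1/predictions",
        json={"input": input_data},
        headers={"Authorization": f"Token {api_token}"},
    )
    if response.status_code == 201:
        # Replicate returns a polling URL under "urls" -> "get"
        return response.json().get("urls", {}).get("get")
    # After this commit the exception message carries the response body,
    # which makes a failed start_prediction call easier to debug.
    raise ReplicateError(response.status_code, f"Failed to start prediction {response.text}")
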
@@ -23,7 +23,7 @@ try:
         "llama2",
         messages=[{"role": "user", "content": "Hey, how's it going?"}],
         top_p=0.1,
-        temperature=0.1,
+        temperature=0.01,
         num_beams=4,
         max_tokens=60,
     )

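Reconstructed in full, the test call around this hunk looks roughly like the following; the import and the try/except scaffolding are assumptions based on the "try:" context line, and only the keyword arguments appear in the diff itself.

# Sketch of the surrounding test call; only the kwargs are taken from the diff.
from litellm import completion

try:
    response = completion(
        "llama2",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        top_p=0.1,
        temperature=0.01,  # lowered from 0.1 in this commit
        num_beams=4,
        max_tokens=60,
    )
    print(response)
except Exception as e:
    print(f"completion call failed: {e}")
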
@@ -706,7 +706,7 @@ def get_optional_params(  # use the openai defaults
             optional_params["stream"] = stream
             return optional_params
         if max_tokens != float("inf"):
-            if "vicuna" in model:
+            if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
             else:
                 optional_params["max_new_tokens"] = max_tokens

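The gist of the utils change: the diff routes both flan and vicuna models to a max_length parameter, while other Replicate models keep using max_new_tokens. A standalone sketch of that mapping follows; the helper name map_replicate_max_tokens is made up for illustration, and only the branching condition and parameter names come from the diff.

# Illustrative helper; the condition and param names follow the diff,
# everything else (name, signature) is an assumption.
def map_replicate_max_tokens(model: str, max_tokens: float) -> dict:
    optional_params = {}
    if max_tokens != float("inf"):
        if "vicuna" in model or "flan" in model:
            # vicuna/flan style models take "max_length"
            optional_params["max_length"] = max_tokens
        else:
            # other replicate models take "max_new_tokens"
            optional_params["max_new_tokens"] = max_tokens
    return optional_params

# Example usage:
#   map_replicate_max_tokens("replicate/flan-t5-xl", 60)   -> {"max_length": 60}
#   map_replicate_max_tokens("replicate/vicuna-13b", 60)   -> {"max_length": 60}
#   map_replicate_max_tokens("replicate/llama-2-70b", 60)  -> {"max_new_tokens": 60}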