fix(factory.py): add replicate meta llama prompt templating support

Krrish Dholakia 2024-04-25 08:24:28 -07:00
parent 92f21cba30
commit 4f46b4c397
4 changed files with 26 additions and 4 deletions


@@ -1355,7 +1355,9 @@ def prompt_factory(
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
-        elif "meta-llama/llama-3" in model and "instruct" in model:
+        elif (
+            "meta-llama/llama-3" in model or "meta-llama-3" in model
+        ) and "instruct" in model:
             return hf_chat_template(
                 model=model,
                 messages=messages,

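The broadened condition is what lets Replicate's Llama 3 slug pick up the chat template: "meta/meta-llama-3-8b-instruct" contains "meta-llama-3" but not the Hugging Face-style "meta-llama/llama-3" prefix. A minimal standalone sketch of the check (the helper name is invented for illustration and is not part of this commit):

    def routes_to_llama3_chat_template(model: str) -> bool:
        # same boolean logic as the new elif branch above
        return (
            "meta-llama/llama-3" in model or "meta-llama-3" in model
        ) and "instruct" in model

    # Hugging Face-style id: already matched before this change
    assert routes_to_llama3_chat_template("meta-llama/llama-3-8b-instruct")
    # Replicate-style slug: only matches because of the new "meta-llama-3" clause
    assert routes_to_llama3_chat_template("meta/meta-llama-3-8b-instruct")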

@@ -307,9 +307,7 @@ def completion(
         result, logs = handle_prediction_response(
             prediction_url, api_key, print_verbose
         )
-        model_response["ended"] = (
-            time.time()
-        )  # for pricing this must remain right after calling api
         ## LOGGING
         logging_obj.post_call(
             input=prompt,


@@ -31,6 +31,9 @@ model_list:
 - litellm_params:
     model: gpt-3.5-turbo
   model_name: gpt-3.5-turbo
+- model_name: llama-3
+  litellm_params:
+    model: replicate/meta/meta-llama-3-8b-instruct
 router_settings:
   allowed_fails: 3
   context_window_fallbacks: null

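With the new model_list entry, an OpenAI-compatible request for "llama-3" made against the proxy resolves to replicate/meta/meta-llama-3-8b-instruct. A rough usage sketch, assuming the proxy is running locally with this config (the base_url, port, and key below are placeholders, not part of this commit):

    import openai

    client = openai.OpenAI(api_key="sk-anything", base_url="http://0.0.0.0:4000")
    response = client.chat.completions.create(
        model="llama-3",  # model_name alias from the config above
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)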

@@ -1767,6 +1767,25 @@ def test_completion_azure_deployment_id():
 # test_completion_anthropic_openai_proxy()
 
 
+def test_completion_replicate_llama3():
+    litellm.set_verbose = True
+    model_name = "replicate/meta/meta-llama-3-8b-instruct"
+    try:
+        response = completion(
+            model=model_name,
+            messages=messages,
+        )
+        print(response)
+        # Add any assertions here to check the response
+        response_str = response["choices"][0]["message"]["content"]
+        print("RESPONSE STRING\n", response_str)
+        if type(response_str) != str:
+            pytest.fail(f"Error occurred: {e}")
+        raise Exception("it worked!")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 @pytest.mark.skip(reason="replicate endpoints take +2 mins just for this request")
 def test_completion_replicate_vicuna():
     print("TESTING REPLICATE")