feat(completion()): enable setting prompt templates via completion()

Krrish Dholakia 2023-11-02 16:23:51 -07:00
parent 1fc726d5dd
commit 512a1637eb
9 changed files with 94 additions and 37 deletions
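The change lets callers pass prompt-template pieces (roles, initial_prompt_value, final_prompt_value, bos_token, eos_token) directly to completion() instead of registering them globally with litellm.register_prompt_template(). A minimal usage sketch, assuming only the kwargs shown in this diff; the role mapping is taken from the test added below and the messages list is illustrative:

import litellm
from litellm import completion

litellm.set_verbose = True

# Per-call prompt template: "roles" maps each OpenAI-style role to its
# pre/post message wrappers for this model.
response = completion(
    model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],  # illustrative
    roles={
        "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
        "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
        "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
    },
)
print(response)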

.gitignore vendored
View file

@@ -13,3 +13,4 @@ litellm/proxy/litellm_secrets.toml
litellm/proxy/api_log.json
.idea/
router_config.yaml
litellm_server/config.yaml

View file

@@ -193,9 +193,9 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
role_dict=model_prompt_details.get("roles", None),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages
)
else:
@@ -213,10 +213,12 @@ def completion(
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)

View file

@@ -267,6 +267,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
bos_open = False
prompt += final_prompt_value
print(f"COMPLETE PROMPT: {prompt}")
return prompt
def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
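The hunk above only shows the tail of custom_prompt(); assuming the unchanged body wraps each message in its role's pre_message/post_message and brackets the result with initial_prompt_value/final_prompt_value, a single user message under the test's role mapping would be assembled roughly like this (a sketch, not the function's literal code):

# Rough rendering for one user message; the assembly loop is not part of this hunk.
prompt = (
    ""                        # initial_prompt_value (defaults to "")
    + "<|im_start|>user\n"    # roles["user"]["pre_message"]
    + "Hey, how's it going?"  # message content
    + "<|im_end|>"            # roles["user"]["post_message"]
    + ""                      # final_prompt_value (defaults to "")
)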

View file

@@ -99,15 +99,18 @@ def completion(
if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
print(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
)
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
bos_token=model_prompt_details.get("bos_token", ""),
eos_token=model_prompt_details.get("eos_token", ""),
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)

View file

@@ -257,9 +257,15 @@ def completion(
headers = kwargs.get("headers", None)
num_retries = kwargs.get("num_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
### CUSTOM PROMPT TEMPLATE ###
initial_prompt_value = kwargs.get("initial_prompt_value", None)
roles = kwargs.get("roles", None)
final_prompt_value = kwargs.get("final_prompt_value", None)
bos_token = kwargs.get("bos_token", None)
eos_token = kwargs.get("eos_token", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "initial_prompt_value", "roles", "final_prompt_value", "bos_token", "eos_token"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
@@ -280,6 +286,7 @@ def completion(
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
model=deployment_id
custom_llm_provider="azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
custom_prompt_dict = None
if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
custom_prompt_dict = {model: {}}
if initial_prompt_value:
custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
if roles:
custom_prompt_dict[model]["roles"] = roles
if final_prompt_value:
custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
if bos_token:
custom_prompt_dict[model]["bos_token"] = bos_token
if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token
model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
if model_api_key and "sk-litellm" in model_api_key:
api_base = "https://proxy.litellm.ai"
@@ -646,6 +666,10 @@ def completion(
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = anthropic.completion(
model=model,
messages=messages,
@@ -867,6 +891,11 @@ def completion(
headers
or litellm.headers
)
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = huggingface_restapi.completion(
model=model,
messages=messages,
@@ -880,7 +909,7 @@ def completion(
encoding=encoding,
api_key=huggingface_key,
logging_obj=logging,
custom_prompt_dict=litellm.custom_prompt_dict
custom_prompt_dict=custom_prompt_dict
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
or get_secret("TOGETHERAI_API_BASE")
or "https://api.together.xyz/inference"
)
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = together_ai.completion(
model=model,
@@ -997,7 +1031,8 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
api_key=together_ai_key,
logging_obj=logging
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict
)
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
# don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
response = model_response
elif custom_llm_provider == "bedrock":
# boto3 reads keys from .env
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = bedrock.completion(
model=model,
messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
"http://localhost:11434"
)
if model in litellm.custom_prompt_dict:
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
@@ -1196,7 +1239,7 @@ def completion(
## LOGGING
logging.pre_call(
input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
)
if kwargs.get('acompletion', False) == True:
if optional_params.get("stream", False) == True:
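Taken together, the main.py hunks above collect the new kwargs into a per-call custom_prompt_dict keyed by the model name, and each provider branch falls back to the global litellm.custom_prompt_dict when no per-call template was given. A sketch of the resulting structure (only keys the caller actually passed are set; values are illustrative):

# Built inside completion() when any template kwarg is supplied.
custom_prompt_dict = {
    "OpenAssistant/llama2-70b-oasst-sft-v10": {  # keyed by the model returned from get_llm_provider()
        "roles": {"user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"}},
        # "initial_prompt_value", "final_prompt_value", "bos_token", "eos_token"
        # appear only when the caller passed them.
    }
}
# Each provider branch then does:
# custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict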

View file

@@ -431,7 +431,7 @@ def test_completion_text_openai():
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_completion_text_openai()
# test_completion_text_openai()
def test_completion_openai_with_optional_params():
try:
@@ -837,18 +837,16 @@ def test_completion_together_ai():
pytest.fail(f"Error occurred: {e}")
# test_completion_together_ai()
# def test_customprompt_together_ai():
# try:
# litellm.register_prompt_template(
# model="OpenAssistant/llama2-70b-oasst-sft-v10",
# roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}, # tell LiteLLM how you want to map the openai messages to this model
# pre_message_sep= "\n",
# post_message_sep= "\n"
# )
# response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
def test_customprompt_together_ai():
try:
litellm.set_verbose = True
response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages,
roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}})
print(response)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
test_customprompt_together_ai()
def test_completion_sagemaker():
try:

View file

@@ -1997,7 +1997,10 @@ def validate_environment(model: Optional[str]=None) -> dict:
if model is None:
return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}
## EXTRACT LLM PROVIDER - if model name provided
custom_llm_provider = get_llm_provider(model=model)
try:
custom_llm_provider = get_llm_provider(model=model)
except:
custom_llm_provider = None
# # check if llm provider part of model name
# if model.split("/",1)[0] in litellm.provider_list:
# custom_llm_provider = model.split("/", 1)[0]

View file

@@ -21,6 +21,9 @@ ANTHROPIC_API_KEY = ""
COHERE_API_KEY = ""
## CONFIG FILE ##
# CONFIG_FILE_PATH = "" # uncomment to point to config file
## LOGGING ##
SET_VERBOSE = "False" # set to 'True' to see detailed input/output logs

View file

@@ -1,11 +1,14 @@
import litellm, os, traceback
import os, traceback
from fastapi import FastAPI, Request, HTTPException
from fastapi.routing import APIRouter
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import json
import os
import json, sys
from typing import Optional
# sys.path.insert(
# 0, os.path.abspath("../")
# ) # Adds the parent directory to the system path - for litellm local dev
import litellm
try:
from utils import set_callbacks, load_router_config, print_verbose
except ImportError:
@@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
set_callbacks() # sets litellm callbacks for logging if they exist in the environment
if "CONFIG_FILE_PATH" in os.environ:
print(f"CONFIG FILE DETECTED")
llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
else:
llm_router, llm_model_list = load_router_config(router=llm_router)
@@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
## CHECK KEYS ##
# default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
env_validation = litellm.validate_environment(model=data["model"])
print(f"request headers: {request.headers}")
if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
if "authorization" in request.headers:
api_key = request.headers.get("authorization")
@@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
for key, value in m["litellm_params"].items():
data[key] = value
break
print(f"data going into litellm completion: {data}")
response = litellm.completion(
**data
)
@@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
return response
except Exception as e:
error_traceback = traceback.format_exc()
print(f"{error_traceback}")
error_msg = f"{str(e)}\n\n{error_traceback}"
return {"error": error_msg}
# raise HTTPException(status_code=500, detail=error_msg)