feat(completion()): enable setting prompt templates via completion()

Krrish Dholakia 2023-11-02 16:23:51 -07:00
parent 1fc726d5dd
commit 512a1637eb
9 changed files with 94 additions and 37 deletions
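
This commit lets callers pass a prompt template directly to completion() instead of registering one globally. Below is a minimal sketch of the new call pattern, inferred from the kwargs unpacked in litellm/main.py and from the test_customprompt_together_ai test added in this commit; the model, message content, and template strings are illustrative only.

# Illustrative only: kwarg names follow those unpacked in completion();
# initial_prompt_value, bos_token, and eos_token are accepted the same way.
import litellm
from litellm import completion

litellm.set_verbose = True
response = completion(
    model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    roles={
        "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
        "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
        "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
    },
    final_prompt_value="<|im_start|>assistant\n",  # appended after the mapped messages
)
print(response)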

.gitignore
View file

@@ -13,3 +13,4 @@ litellm/proxy/litellm_secrets.toml
 litellm/proxy/api_log.json
 .idea/
 router_config.yaml
+litellm_server/config.yaml

View file

@@ -193,9 +193,9 @@ def completion(
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
             messages=messages
         )
     else:
@@ -213,10 +213,12 @@ def completion(
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
         )
     else:
         prompt = prompt_factory(model=model, messages=messages)

View file

@@ -267,6 +267,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
             bos_open = False
     prompt += final_prompt_value
+    print(f"COMPLETE PROMPT: {prompt}")
     return prompt

 def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
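
For context, a registered template is applied roughly as follows. This is a simplified sketch of what custom_prompt() does with the role_dict, initial/final prompt values, and the bos/eos tokens now passed to it; the real litellm implementation also tracks a per-message bos_open flag, which is omitted here.

# Simplified sketch (not the actual litellm implementation): a role_dict with
# pre_message/post_message entries wraps each chat message in the template.
def apply_template(role_dict, messages, initial_prompt_value="", final_prompt_value="",
                   bos_token="", eos_token=""):
    prompt = bos_token + initial_prompt_value
    for message in messages:
        role_affixes = role_dict.get(message["role"], {})
        prompt += role_affixes.get("pre_message", "")
        prompt += message["content"]
        prompt += role_affixes.get("post_message", "")
    prompt += final_prompt_value
    return prompt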

View file

@@ -99,15 +99,18 @@ def completion(
             if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
                 optional_params[k] = v
+    print(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
-        )
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
+        )
     else:
         prompt = prompt_factory(model=model, messages=messages)

View file

@@ -257,9 +257,15 @@ def completion(
     headers = kwargs.get("headers", None)
     num_retries = kwargs.get("num_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM PROMPT TEMPLATE ###
+    initial_prompt_value = kwargs.get("initial_prompt_value", None)
+    roles = kwargs.get("roles", None)
+    final_prompt_value = kwargs.get("final_prompt_value", None)
+    bos_token = kwargs.get("bos_token", None)
+    eos_token = kwargs.get("eos_token", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "initial_prompt_value", "roles", "final_prompt_value", "bos_token", "eos_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -280,6 +286,7 @@ def completion(
             model = litellm.model_alias_map[
                 model
             ] # update the model to the actual value if an alias has been passed in
     model_response = ModelResponse()
     if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
         model=deployment_id
         custom_llm_provider="azure"
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+    custom_prompt_dict = None
+    if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
+        custom_prompt_dict = {model: {}}
+        if initial_prompt_value:
+            custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
+        if roles:
+            custom_prompt_dict[model]["roles"] = roles
+        if final_prompt_value:
+            custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
+        if bos_token:
+            custom_prompt_dict[model]["bos_token"] = bos_token
+        if eos_token:
+            custom_prompt_dict[model]["eos_token"] = eos_token
     model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
     if model_api_key and "sk-litellm" in model_api_key:
         api_base = "https://proxy.litellm.ai"
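
The block above turns the per-call kwargs into the same per-model dictionary layout that the provider code already consumes via litellm.custom_prompt_dict; the later hunks then make each provider branch fall back to the global litellm.custom_prompt_dict when no per-call template was given. An illustrative value of the per-call dict (the key is whatever `model` resolves to at this point, and keys appear only for kwargs that were passed):

# Illustrative shape only
custom_prompt_dict = {
    "OpenAssistant/llama2-70b-oasst-sft-v10": {
        "roles": {
            "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
            "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
            "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
        },
        "final_prompt_value": "<|im_start|>assistant\n",
    }
}
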
@@ -646,6 +666,10 @@ def completion(
             or get_secret("ANTHROPIC_API_BASE")
             or "https://api.anthropic.com/v1/complete"
         )
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = anthropic.completion(
             model=model,
             messages=messages,
@@ -867,6 +891,11 @@ def completion(
             headers
             or litellm.headers
         )
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = huggingface_restapi.completion(
             model=model,
             messages=messages,
@@ -880,7 +909,7 @@ def completion(
             encoding=encoding,
             api_key=huggingface_key,
             logging_obj=logging,
-            custom_prompt_dict=litellm.custom_prompt_dict
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
             or get_secret("TOGETHERAI_API_BASE")
             or "https://api.together.xyz/inference"
         )
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = together_ai.completion(
             model=model,
@@ -997,7 +1031,8 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=together_ai_key,
-            logging_obj=logging
+            logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
             # don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
         response = model_response
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = bedrock.completion(
             model=model,
             messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
             "http://localhost:11434"
         )
-        if model in litellm.custom_prompt_dict:
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+        if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
-            model_prompt_details = litellm.custom_prompt_dict[model]
+            model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
                 role_dict=model_prompt_details["roles"],
                 initial_prompt_value=model_prompt_details["initial_prompt_value"],
@@ -1196,7 +1239,7 @@ def completion(
         ## LOGGING
         logging.pre_call(
-            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
+            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
         )
         if kwargs.get('acompletion', False) == True:
             if optional_params.get("stream", False) == True:

View file

@@ -431,7 +431,7 @@ def test_completion_text_openai():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_text_openai()
+# test_completion_text_openai()

 def test_completion_openai_with_optional_params():
     try:
@@ -837,18 +837,16 @@ def test_completion_together_ai():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_together_ai()

-# def test_customprompt_together_ai():
-#     try:
-#         litellm.register_prompt_template(
-#             model="OpenAssistant/llama2-70b-oasst-sft-v10",
-#             roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}, # tell LiteLLM how you want to map the openai messages to this model
-#             pre_message_sep= "\n",
-#             post_message_sep= "\n"
-#         )
-#         response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_customprompt_together_ai():
+    try:
+        litellm.set_verbose = True
+        response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages,
+                              roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}})
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_customprompt_together_ai()

 def test_completion_sagemaker():
     try:

View file

@@ -1997,7 +1997,10 @@ def validate_environment(model: Optional[str]=None) -> dict:
     if model is None:
         return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}
     ## EXTRACT LLM PROVIDER - if model name provided
-    custom_llm_provider = get_llm_provider(model=model)
+    try:
+        custom_llm_provider = get_llm_provider(model=model)
+    except:
+        custom_llm_provider = None
     # # check if llm provider part of model name
     # if model.split("/",1)[0] in litellm.provider_list:
     #     custom_llm_provider = model.split("/", 1)[0]
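
With get_llm_provider wrapped in try/except, validate_environment() now degrades gracefully for model names whose provider cannot be inferred. A minimal usage sketch; the return shape follows the context lines above, and the keys reported depend on your environment:

import litellm

# Both calls return {"keys_in_environment": ..., "missing_keys": [...]};
# the second could previously raise from inside get_llm_provider.
print(litellm.validate_environment(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10"))
print(litellm.validate_environment(model="some-unmapped-model-name"))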

View file

@@ -21,6 +21,9 @@ ANTHROPIC_API_KEY = ""
 COHERE_API_KEY = ""

+## CONFIG FILE ##
+# CONFIG_FILE_PATH = "" # uncomment to point to config file
+
 ## LOGGING ##
 SET_VERBOSE = "False" # set to 'True' to see detailed input/output logs

View file

@@ -1,11 +1,14 @@
-import litellm, os, traceback
+import os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
-import json
-import os
+import json, sys
 from typing import Optional
+# sys.path.insert(
+#     0, os.path.abspath("../")
+# ) # Adds the parent directory to the system path - for litellm local dev
+import litellm
 try:
     from utils import set_callbacks, load_router_config, print_verbose
 except ImportError:
@@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
 if "CONFIG_FILE_PATH" in os.environ:
+    print(f"CONFIG FILE DETECTED")
     llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
 else:
     llm_router, llm_model_list = load_router_config(router=llm_router)
@@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         ## CHECK KEYS ##
         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
         env_validation = litellm.validate_environment(model=data["model"])
-        print(f"request headers: {request.headers}")
         if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
             if "authorization" in request.headers:
                 api_key = request.headers.get("authorization")
@@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
                 for key, value in m["litellm_params"].items():
                     data[key] = value
                 break
-        print(f"data going into litellm completion: {data}")
        response = litellm.completion(
            **data
        )
@@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return response
     except Exception as e:
         error_traceback = traceback.format_exc()
+        print(f"{error_traceback}")
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
         # raise HTTPException(status_code=500, detail=error_msg)