feat(completion()): enable setting prompt templates via completion()

Krrish Dholakia 2023-11-02 16:23:51 -07:00
parent 1fc726d5dd
commit 512a1637eb
9 changed files with 94 additions and 37 deletions
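A minimal usage sketch of the feature this commit adds: the new kwargs let a prompt template be passed per call instead of registering it globally first. The model name and template strings below are illustrative only, not values taken from this commit.

import litellm

response = litellm.completion(
    model="huggingface/meta-llama/Llama-2-7b-chat-hf",  # illustrative model
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    # role-level wrappers applied around each message's content
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
        "assistant": {"post_message": "\n"},
    },
    initial_prompt_value="<s>",   # prepended once, before the first message
    final_prompt_value="",        # appended once, after the last message
    bos_token="<s>",
    eos_token="</s>",
)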


@@ -257,9 +257,15 @@ def completion(
headers = kwargs.get("headers", None)
num_retries = kwargs.get("num_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
### CUSTOM PROMPT TEMPLATE ###
initial_prompt_value = kwargs.get("initial_prompt_value", None)
roles = kwargs.get("roles", None)
final_prompt_value = kwargs.get("final_prompt_value", None)
bos_token = kwargs.get("bos_token", None)
eos_token = kwargs.get("eos_token", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "initial_prompt_value", "roles", "final_prompt_value", "bos_token", "eos_token"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
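A small sketch of how the kwarg split above behaves: keys listed in openai_params or litellm_params are consumed by litellm itself, while everything else is forwarded to the provider as a model-specific parameter (values illustrative):

kwargs = {"temperature": 0.7, "roles": {"user": {"pre_message": "[INST] "}}, "top_k": 40}
default_params = openai_params + litellm_params
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
# non_default_params == {"top_k": 40}; "roles" is filtered out because the new
# prompt-template keys were added to litellm_params above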
@@ -280,6 +286,7 @@ def completion(
model = litellm.model_alias_map[
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
model=deployment_id
custom_llm_provider="azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
custom_prompt_dict = None
if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
custom_prompt_dict = {model: {}}
if initial_prompt_value:
custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
if roles:
custom_prompt_dict[model]["roles"] = roles
if final_prompt_value:
custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
if bos_token:
custom_prompt_dict[model]["bos_token"] = bos_token
if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token
model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
if model_api_key and "sk-litellm" in model_api_key:
api_base = "https://proxy.litellm.ai"
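Given the unpacking above, the per-call custom_prompt_dict is keyed by the resolved model name; an illustrative shape (all values hypothetical):

custom_prompt_dict = {
    "my-model": {
        "roles": {"user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"}},
        "initial_prompt_value": "<s>",
        "final_prompt_value": "",
        "bos_token": "<s>",
        "eos_token": "</s>",
    }
}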
@@ -646,6 +666,10 @@ def completion(
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = anthropic.completion(
model=model,
messages=messages,
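The expression custom_prompt_dict or litellm.custom_prompt_dict, repeated in each provider branch below, gives the per-call template precedence over the module-level registry; a minimal illustration (names hypothetical, and assuming litellm.register_prompt_template is the usual way the global registry gets populated):

per_call = {"my-model": {"initial_prompt_value": "<s>"}}  # built from the completion() kwargs
global_registry = litellm.custom_prompt_dict              # module-level registry
effective = per_call or global_registry                   # the per-call dict wins whenever it is non-empty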
@@ -867,6 +891,11 @@ def completion(
headers
or litellm.headers
)
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = huggingface_restapi.completion(
model=model,
messages=messages,
@@ -880,7 +909,7 @@ def completion(
encoding=encoding,
api_key=huggingface_key,
logging_obj=logging,
custom_prompt_dict=litellm.custom_prompt_dict
custom_prompt_dict=custom_prompt_dict
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
or get_secret("TOGETHERAI_API_BASE")
or "https://api.together.xyz/inference"
)
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = together_ai.completion(
model=model,
@@ -997,7 +1031,8 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
api_key=together_ai_key,
logging_obj=logging
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict
)
if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
# don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
response = model_response
elif custom_llm_provider == "bedrock":
# boto3 reads keys from .env
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = bedrock.completion(
model=model,
messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
"http://localhost:11434"
)
if model in litellm.custom_prompt_dict:
custom_prompt_dict = (
custom_prompt_dict
or litellm.custom_prompt_dict
)
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
@@ -1196,7 +1239,7 @@ def completion(
## LOGGING
logging.pre_call(
input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
)
if kwargs.get('acompletion', False) == True:
if optional_params.get("stream", False) == True:
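For completeness, a rough sketch of what a prompt-template helper along the lines of custom_prompt does with these fields once a model is found in custom_prompt_dict. This is an illustrative reimplementation under assumed semantics, not litellm's actual factory code:

def render_prompt(messages, role_dict, initial_prompt_value="", final_prompt_value="", bos_token="", eos_token=""):
    # Hypothetical helper: wrap each message with its role's pre/post strings,
    # then bracket the whole conversation with the initial/final values and BOS/EOS tokens.
    prompt = bos_token + initial_prompt_value
    for message in messages:
        role = role_dict.get(message["role"], {})
        prompt += role.get("pre_message", "") + message["content"] + role.get("post_message", "")
    return prompt + final_prompt_value + eos_token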