Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(completion()): enable setting prompt templates via completion()
Parent: 1fc726d5dd
Commit: 512a1637eb
9 changed files with 94 additions and 37 deletions
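For reference, a minimal usage sketch of what this commit enables: passing a prompt template directly to completion() via the new kwargs unpacked in the first hunk below. The model name, role markers, and message are illustrative, the shape of the roles dict is an assumption about the template format, and running it requires the relevant provider credentials.

    import litellm

    # Illustrative call; every template kwarg below (roles, initial_prompt_value,
    # final_prompt_value, bos_token, eos_token) is introduced by this commit.
    response = litellm.completion(
        model="huggingface/meta-llama/Llama-2-7b-chat-hf",  # illustrative model name
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        roles={  # assumed per-role pre/post markers (Llama-2-style, for illustration)
            "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
            "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
            "assistant": {"post_message": "\n"},
        },
        initial_prompt_value="",
        final_prompt_value="",
        bos_token="<s>",
        eos_token="</s>",
    )
    print(response)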
@@ -257,9 +257,15 @@ def completion(
     headers = kwargs.get("headers", None)
     num_retries = kwargs.get("num_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM PROMPT TEMPLATE ###
+    initial_prompt_value = kwargs.get("initial_prompt_value", None)
+    roles = kwargs.get("roles", None)
+    final_prompt_value = kwargs.get("final_prompt_value", None)
+    bos_token = kwargs.get("bos_token", None)
+    eos_token = kwargs.get("eos_token", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "initial_prompt_value", "roles", "final_prompt_value", "bos_token", "eos_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
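Everything not in default_params is treated as a model-specific param and forwarded to the provider, which is why the new template kwargs are appended to litellm_params above. A small self-contained sketch of that filtering, with abridged lists and invented values:

    # Abridged versions of the two lists above, plus invented kwargs, to show the split.
    openai_params = ["temperature", "top_p", "max_tokens"]
    litellm_params = ["roles", "bos_token", "eos_token"]
    default_params = openai_params + litellm_params

    kwargs = {"temperature": 0.2, "roles": {"user": {}}, "top_k": 40}
    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
    print(non_default_params)  # {'top_k': 40} is passed straight to the model/provider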
@@ -280,6 +286,7 @@ def completion(
         model = litellm.model_alias_map[
             model
         ] # update the model to the actual value if an alias has been passed in

     model_response = ModelResponse()

     if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
         model=deployment_id
         custom_llm_provider="azure"
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+    custom_prompt_dict = None
+    if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
+        custom_prompt_dict = {model: {}}
+        if initial_prompt_value:
+            custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
+        if roles:
+            custom_prompt_dict[model]["roles"] = roles
+        if final_prompt_value:
+            custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
+        if bos_token:
+            custom_prompt_dict[model]["bos_token"] = bos_token
+        if eos_token:
+            custom_prompt_dict[model]["eos_token"] = eos_token
     model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
     if model_api_key and "sk-litellm" in model_api_key:
         api_base = "https://proxy.litellm.ai"
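The block above builds a per-call dict keyed by the model name, containing only the template pieces that were actually passed. For a call that supplies all five kwargs, the result looks roughly like this (model name and values illustrative):

    # Shape of the per-call custom_prompt_dict built above; values are illustrative.
    custom_prompt_dict = {
        "my-model": {  # hypothetical model name
            "roles": {"user": {"pre_message": "[INST] ", "post_message": " [/INST]"}},
            "initial_prompt_value": "<s>",
            "final_prompt_value": "",
            "bos_token": "<s>",
            "eos_token": "</s>",
        }
    }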
@@ -646,6 +666,10 @@ def completion(
             or get_secret("ANTHROPIC_API_BASE")
             or "https://api.anthropic.com/v1/complete"
         )
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = anthropic.completion(
             model=model,
             messages=messages,
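The custom_prompt_dict or litellm.custom_prompt_dict fallback, repeated in each provider branch below, means a template passed per call takes precedence, and the module-level dict is only consulted when no template kwargs were given. A sketch of the global route, assuming only that litellm.custom_prompt_dict is a plain dict as the hunks suggest:

    import litellm

    # Global route: populate the module-level dict referenced in the hunks.
    litellm.custom_prompt_dict = {
        "my-model": {  # hypothetical model name
            "roles": {"user": {"pre_message": "USER: ", "post_message": "\n"}},
            "initial_prompt_value": "",
            "final_prompt_value": "ASSISTANT: ",
        }
    }

    # Per-call route: pass roles / initial_prompt_value / etc. to completion() as in
    # the usage sketch near the top; those kwargs build a per-call dict that shadows
    # the global one for that request via the or-fallback above.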
@@ -867,6 +891,11 @@ def completion(
             headers
             or litellm.headers
         )
+
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = huggingface_restapi.completion(
             model=model,
             messages=messages,
@@ -880,7 +909,7 @@ def completion(
             encoding=encoding,
             api_key=huggingface_key,
             logging_obj=logging,
-            custom_prompt_dict=litellm.custom_prompt_dict
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
             or get_secret("TOGETHERAI_API_BASE")
             or "https://api.together.xyz/inference"
         )
+
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )

         model_response = together_ai.completion(
             model=model,
@@ -997,7 +1031,8 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=together_ai_key,
-            logging_obj=logging
+            logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
             # don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
         response = model_response
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = bedrock.completion(
             model=model,
             messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
             "http://localhost:11434"

         )
-        if model in litellm.custom_prompt_dict:
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+        if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
-            model_prompt_details = litellm.custom_prompt_dict[model]
+            model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
                 role_dict=model_prompt_details["roles"],
                 initial_prompt_value=model_prompt_details["initial_prompt_value"],
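In the branch above, custom_prompt() turns the chat messages into a single prompt string using the registered template pieces. The helper below is a hypothetical re-implementation written only to illustrate the idea; the real function, and the exact keys it expects inside roles, live in litellm's prompt-template code and may differ.

    # Hypothetical sketch, not litellm's actual custom_prompt() implementation.
    def apply_prompt_template(messages, role_dict, initial_prompt_value="", final_prompt_value=""):
        prompt = initial_prompt_value
        for message in messages:
            markers = role_dict.get(message["role"], {})  # assumed pre/post marker keys
            prompt += markers.get("pre_message", "") + message["content"] + markers.get("post_message", "")
        return prompt + final_prompt_value

    messages = [{"role": "user", "content": "Hello, how are you?"}]
    roles = {"user": {"pre_message": "[INST] ", "post_message": " [/INST]"}}
    print(apply_prompt_template(messages, roles, initial_prompt_value="<s>"))
    # -> <s>[INST] Hello, how are you? [/INST]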
@@ -1196,7 +1239,7 @@ def completion(

         ## LOGGING
         logging.pre_call(
-            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
+            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
         )
         if kwargs.get('acompletion', False) == True:
             if optional_params.get("stream", False) == True: