Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
feat(completion()): enable setting prompt templates via completion()
This commit is contained in: parent 1fc726d5dd, commit 512a1637eb
9 changed files with 94 additions and 37 deletions
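As a rough usage sketch of the new capability (the model name comes from the test added in this commit; the message content and template values are illustrative assumptions), prompt-template details can now be passed straight to completion():

from litellm import completion

# Sketch only: roles / initial_prompt_value / final_prompt_value / bos_token /
# eos_token are the kwargs this commit unpacks inside completion(); the
# concrete values below are assumptions, not library defaults.
response = completion(
    model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    roles={
        "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
        "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
        "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
    },
    initial_prompt_value="",
    final_prompt_value="<|im_start|>assistant\n",
    bos_token="<s>",
    eos_token="</s>",
)
print(response)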
.gitignore (vendored): 1 addition
@@ -13,3 +13,4 @@ litellm/proxy/litellm_secrets.toml
 litellm/proxy/api_log.json
 .idea/
 router_config.yaml
+litellm_server/config.yaml
@@ -193,9 +193,9 @@ def completion(
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
+            role_dict=model_prompt_details.get("roles", None),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
             messages=messages
         )
     else:
@@ -213,10 +213,12 @@ def completion(
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
         )
     else:
         prompt = prompt_factory(model=model, messages=messages)
@@ -267,6 +267,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
             bos_open = False
 
     prompt += final_prompt_value
+    print(f"COMPLETE PROMPT: {prompt}")
    return prompt
 
 def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
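For orientation, here is a minimal, hypothetical sketch of how a role-dict template of this shape gets assembled into a single prompt string. It mirrors the custom_prompt() arguments in the hunk above, but the body is an assumption rather than the library's exact implementation:

# Hypothetical illustration only; argument names follow the custom_prompt()
# signature above, the assembly logic is an assumption.
def assemble_prompt(role_dict, messages, initial_prompt_value="",
                    final_prompt_value="", bos_token="", eos_token=""):
    prompt = bos_token + initial_prompt_value
    for message in messages:
        role_details = role_dict.get(message["role"], {})
        prompt += (
            role_details.get("pre_message", "")
            + message["content"]
            + role_details.get("post_message", "")
        )
    prompt += final_prompt_value
    return prompt

print(assemble_prompt(
    role_dict={"user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"}},
    messages=[{"role": "user", "content": "Hi"}],
    final_prompt_value="<|im_start|>assistant\n",
))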
@@ -99,15 +99,18 @@ def completion(
         if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
 
+    print(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
-        )
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
+        )
     else:
         prompt = prompt_factory(model=model, messages=messages)
 
@@ -257,9 +257,15 @@ def completion(
     headers = kwargs.get("headers", None)
     num_retries = kwargs.get("num_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM PROMPT TEMPLATE ###
+    initial_prompt_value = kwargs.get("initial_prompt_value", None)
+    roles = kwargs.get("roles", None)
+    final_prompt_value = kwargs.get("final_prompt_value", None)
+    bos_token = kwargs.get("bos_token", None)
+    eos_token = kwargs.get("eos_token", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "initial_prompt_value", "final_prompt_value", "bos_token", "eos_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -280,6 +286,7 @@ def completion(
             model = litellm.model_alias_map[
                 model
             ] # update the model to the actual value if an alias has been passed in
+
     model_response = ModelResponse()
 
     if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
             model=deployment_id
             custom_llm_provider="azure"
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+    custom_prompt_dict = None
+    if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
+        custom_prompt_dict = {model: {}}
+        if initial_prompt_value:
+            custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
+        if roles:
+            custom_prompt_dict[model]["roles"] = roles
+        if final_prompt_value:
+            custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
+        if bos_token:
+            custom_prompt_dict[model]["bos_token"] = bos_token
+        if eos_token:
+            custom_prompt_dict[model]["eos_token"] = eos_token
     model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
     if model_api_key and "sk-litellm" in model_api_key:
         api_base = "https://proxy.litellm.ai"
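To make the effect of this block concrete, a call that passes roles and final_prompt_value would produce a per-model dict roughly like the following (the model key and values are illustrative assumptions; the provider integrations later read it via model_prompt_details.get(...)):

# Hypothetical shape of the custom_prompt_dict built above.
custom_prompt_dict = {
    "OpenAssistant/llama2-70b-oasst-sft-v10": {
        "roles": {
            "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
        },
        "final_prompt_value": "<|im_start|>assistant\n",
    },
}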
@@ -646,6 +666,10 @@ def completion(
             or get_secret("ANTHROPIC_API_BASE")
             or "https://api.anthropic.com/v1/complete"
         )
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = anthropic.completion(
             model=model,
             messages=messages,
@@ -867,6 +891,11 @@ def completion(
             headers
             or litellm.headers
         )
+
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = huggingface_restapi.completion(
             model=model,
             messages=messages,
@@ -880,7 +909,7 @@ def completion(
             encoding=encoding,
             api_key=huggingface_key,
             logging_obj=logging,
-            custom_prompt_dict=litellm.custom_prompt_dict
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
             or get_secret("TOGETHERAI_API_BASE")
             or "https://api.together.xyz/inference"
         )
+
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
 
         model_response = together_ai.completion(
             model=model,
@@ -997,7 +1031,8 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=together_ai_key,
-            logging_obj=logging
+            logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
             # don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
         response = model_response
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = bedrock.completion(
             model=model,
             messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
             "http://localhost:11434"
 
         )
-        if model in litellm.custom_prompt_dict:
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+        if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
-            model_prompt_details = litellm.custom_prompt_dict[model]
+            model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
                 role_dict=model_prompt_details["roles"],
                 initial_prompt_value=model_prompt_details["initial_prompt_value"],
@@ -1196,7 +1239,7 @@ def completion(
 
         ## LOGGING
         logging.pre_call(
-            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
+            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
         )
         if kwargs.get('acompletion', False) == True:
             if optional_params.get("stream", False) == True:
@@ -431,7 +431,7 @@ def test_completion_text_openai():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_text_openai()
+# test_completion_text_openai()
 
 def test_completion_openai_with_optional_params():
     try:
@@ -837,18 +837,16 @@ def test_completion_together_ai():
         pytest.fail(f"Error occurred: {e}")
 
 # test_completion_together_ai()
-# def test_customprompt_together_ai():
-#     try:
-#         litellm.register_prompt_template(
-#             model="OpenAssistant/llama2-70b-oasst-sft-v10",
-#             roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}, # tell LiteLLM how you want to map the openai messages to this model
-#             pre_message_sep= "\n",
-#             post_message_sep= "\n"
-#         )
-#         response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_customprompt_together_ai():
+    try:
+        litellm.set_verbose = True
+        response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages,
+                              roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}})
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_customprompt_together_ai()
 
 def test_completion_sagemaker():
     try:
@@ -1997,7 +1997,10 @@ def validate_environment(model: Optional[str]=None) -> dict:
     if model is None:
         return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}
     ## EXTRACT LLM PROVIDER - if model name provided
-    custom_llm_provider = get_llm_provider(model=model)
+    try:
+        custom_llm_provider = get_llm_provider(model=model)
+    except:
+        custom_llm_provider = None
     # # check if llm provider part of model name
     # if model.split("/",1)[0] in litellm.provider_list:
     #     custom_llm_provider = model.split("/", 1)[0]
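A brief usage sketch of validate_environment, based on the return shape shown in this hunk (the model name is illustrative):

import litellm

# Reports which provider keys are present or missing for the given model.
check = litellm.validate_environment(model="claude-instant-1")
print(check["keys_in_environment"], check["missing_keys"])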
@@ -21,6 +21,9 @@ ANTHROPIC_API_KEY = ""
 
 COHERE_API_KEY = ""
 
+## CONFIG FILE ##
+# CONFIG_FILE_PATH = "" # uncomment to point to config file
+
 ## LOGGING ##
 
 SET_VERBOSE = "False" # set to 'True' to see detailed input/output logs
@@ -1,11 +1,14 @@
-import litellm, os, traceback
+import os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
-import json
-import os
+import json, sys
 from typing import Optional
+# sys.path.insert(
+#     0, os.path.abspath("../")
+# ) # Adds the parent directory to the system path - for litellm local dev
+import litellm
 try:
     from utils import set_callbacks, load_router_config, print_verbose
 except ImportError:
@@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
 
 if "CONFIG_FILE_PATH" in os.environ:
+    print(f"CONFIG FILE DETECTED")
     llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
 else:
     llm_router, llm_model_list = load_router_config(router=llm_router)
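As a hedged example of exercising this branch (the file path is hypothetical), the server picks up a router config when CONFIG_FILE_PATH is set before startup:

import os

# Hypothetical path; must point at a router config such as router_config.yaml.
os.environ["CONFIG_FILE_PATH"] = "/path/to/router_config.yaml"
# With the variable set, the block above calls
# load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH")).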
@@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         ## CHECK KEYS ##
         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
         env_validation = litellm.validate_environment(model=data["model"])
-        print(f"request headers: {request.headers}")
         if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
             if "authorization" in request.headers:
                 api_key = request.headers.get("authorization")
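A sketch of how a client might pass its own LLM key through the header path checked above; the proxy URL, route, and key format are assumptions, since only the authorization / api-key header handling appears in this hunk:

import requests

# Hypothetical local proxy endpoint; adjust host and route to your deployment.
resp = requests.post(
    "http://localhost:8000/chat/completions",
    headers={"authorization": "Bearer sk-my-provider-key"},  # picked up by the header check above
    json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hi"}]},
)
print(resp.json())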
@@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
                 for key, value in m["litellm_params"].items():
                     data[key] = value
                 break
-        print(f"data going into litellm completion: {data}")
         response = litellm.completion(
             **data
         )
@@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return response
     except Exception as e:
         error_traceback = traceback.format_exc()
+        print(f"{error_traceback}")
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
         # raise HTTPException(status_code=500, detail=error_msg)