forked from phoenix/litellm-mirror
feat(completion()): enable setting prompt templates via completion()
This commit is contained in: parent 1fc726d5dd, commit 512a1637eb

9 changed files with 94 additions and 37 deletions
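In short: prompt-template settings (role markers, initial/final prompt values, BOS/EOS tokens) can now be passed per call as completion() kwargs instead of only being registered globally. A hedged usage sketch based on the kwargs and the roles format introduced in this diff; the model name and ChatML-style markers are just the ones used in the new test, not a recommendation:

    # Sketch: per-call prompt template via completion() kwargs (mirrors this commit's test).
    import litellm
    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    response = completion(
        model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10",
        messages=messages,
        roles={
            "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
            "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
            "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
        },
        initial_prompt_value="",                       # optional: prepended to the assembled prompt
        final_prompt_value="<|im_start|>assistant\n",  # optional: appended, e.g. to cue the assistant turn
        bos_token="",                                  # optional: beginning-of-sequence token
        eos_token="",                                  # optional: end-of-sequence token
    )
    print(response)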
.gitignore (vendored): +1

@@ -13,3 +13,4 @@ litellm/proxy/litellm_secrets.toml
 litellm/proxy/api_log.json
 .idea/
 router_config.yaml
+litellm_server/config.yaml
@@ -193,9 +193,9 @@ def completion(
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                 messages=messages
             )
         else:
@@ -213,10 +213,12 @@ def completion(
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
-                messages=messages
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
             )
         else:
             prompt = prompt_factory(model=model, messages=messages)
@@ -267,6 +267,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
             bos_open = False
 
     prompt += final_prompt_value
+    print(f"COMPLETE PROMPT: {prompt}")
     return prompt
 
 def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
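For orientation, custom_prompt stitches the registered pieces around each message. A rough, simplified approximation of that assembly (illustrative only; the real function also tracks BOS/EOS state via bos_open and handles missing roles):

    # Simplified approximation of custom_prompt's template assembly (not the actual implementation).
    def approx_custom_prompt(role_dict, messages, initial_prompt_value="", final_prompt_value=""):
        prompt = initial_prompt_value
        for message in messages:
            role = message["role"]
            pre = role_dict.get(role, {}).get("pre_message", "")
            post = role_dict.get(role, {}).get("post_message", "")
            prompt += pre + message["content"] + post
        prompt += final_prompt_value
        return prompt

    # With the ChatML-style roles from the new test, a single user message becomes roughly:
    # "<|im_start|>user\nHey, how's it going?<|im_end|>"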
@@ -99,15 +99,18 @@ def completion(
             if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
                 optional_params[k] = v
 
+    print(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
         prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
         )
     else:
         prompt = prompt_factory(model=model, messages=messages)
 
@@ -257,9 +257,15 @@ def completion(
     headers = kwargs.get("headers", None)
     num_retries = kwargs.get("num_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM PROMPT TEMPLATE ###
+    initial_prompt_value = kwargs.get("initial_prompt_value", None)
+    roles = kwargs.get("roles", None)
+    final_prompt_value = kwargs.get("final_prompt_value", None)
+    bos_token = kwargs.get("bos_token", None)
+    eos_token = kwargs.get("eos_token", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "initial_prompt_value", "final_prompt_value", "bos_token", "eos_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -280,6 +286,7 @@ def completion(
             model = litellm.model_alias_map[
                 model
             ] # update the model to the actual value if an alias has been passed in
 
     model_response = ModelResponse()
 
     if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
             model=deployment_id
             custom_llm_provider="azure"
     model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+    custom_prompt_dict = None
+    if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
+        custom_prompt_dict = {model: {}}
+        if initial_prompt_value:
+            custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
+        if roles:
+            custom_prompt_dict[model]["roles"] = roles
+        if final_prompt_value:
+            custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
+        if bos_token:
+            custom_prompt_dict[model]["bos_token"] = bos_token
+        if eos_token:
+            custom_prompt_dict[model]["eos_token"] = eos_token
     model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
     if model_api_key and "sk-litellm" in model_api_key:
         api_base = "https://proxy.litellm.ai"
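Given the kwargs unpacked above, the per-call dictionary is keyed by the model name and mirrors the shape of litellm.custom_prompt_dict. An illustrative example (values are made up):

    # Illustrative shape of the per-call custom_prompt_dict built above.
    custom_prompt_dict = {
        "together_ai/OpenAssistant/llama2-70b-oasst-sft-v10": {
            "roles": {"user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"}},
            "initial_prompt_value": "",
            "final_prompt_value": "<|im_start|>assistant\n",
            "bos_token": "",
            "eos_token": "",
        }
    }
    # Downstream, each provider branch resolves: custom_prompt_dict or litellm.custom_prompt_dict,
    # so per-call settings take precedence over globally registered templates.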
@@ -646,6 +666,10 @@ def completion(
             or get_secret("ANTHROPIC_API_BASE")
             or "https://api.anthropic.com/v1/complete"
         )
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = anthropic.completion(
             model=model,
             messages=messages,
@@ -867,6 +891,11 @@ def completion(
             headers
             or litellm.headers
         )
 
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = huggingface_restapi.completion(
             model=model,
             messages=messages,
@@ -880,7 +909,7 @@ def completion(
             encoding=encoding,
             api_key=huggingface_key,
             logging_obj=logging,
-            custom_prompt_dict=litellm.custom_prompt_dict
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
             or get_secret("TOGETHERAI_API_BASE")
             or "https://api.together.xyz/inference"
         )
 
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
 
         model_response = together_ai.completion(
             model=model,
@@ -997,7 +1031,8 @@ def completion(
             logger_fn=logger_fn,
             encoding=encoding,
             api_key=together_ai_key,
-            logging_obj=logging
+            logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
             # don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
         response = model_response
     elif custom_llm_provider == "bedrock":
         # boto3 reads keys from .env
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
         model_response = bedrock.completion(
             model=model,
             messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
             "http://localhost:11434"
 
         )
-        if model in litellm.custom_prompt_dict:
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+        if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
-            model_prompt_details = litellm.custom_prompt_dict[model]
+            model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
                 role_dict=model_prompt_details["roles"],
                 initial_prompt_value=model_prompt_details["initial_prompt_value"],
@@ -1196,7 +1239,7 @@ def completion(
 
         ## LOGGING
         logging.pre_call(
-            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
+            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
         )
         if kwargs.get('acompletion', False) == True:
             if optional_params.get("stream", False) == True:
@@ -431,7 +431,7 @@ def test_completion_text_openai():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_text_openai()
+# test_completion_text_openai()
 
 def test_completion_openai_with_optional_params():
     try:
@@ -837,18 +837,16 @@ def test_completion_together_ai():
         pytest.fail(f"Error occurred: {e}")
 
 # test_completion_together_ai()
-# def test_customprompt_together_ai():
-#     try:
-#         litellm.register_prompt_template(
-#             model="OpenAssistant/llama2-70b-oasst-sft-v10",
-#             roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}, # tell LiteLLM how you want to map the openai messages to this model
-#             pre_message_sep= "\n",
-#             post_message_sep= "\n"
-#         )
-#         response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_customprompt_together_ai():
+    try:
+        litellm.set_verbose = True
+        response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages,
+                              roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}})
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_customprompt_together_ai()
 
 def test_completion_sagemaker():
     try:
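The older, commented-out version of this test went through the global registration path instead. That path still exists and is roughly equivalent to the per-call kwargs; a hedged sketch (keyword names as used elsewhere in litellm, treat the exact signature as an assumption):

    # Sketch: global registration instead of per-call kwargs (roughly equivalent setup).
    import litellm
    from litellm import completion

    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    litellm.register_prompt_template(
        model="OpenAssistant/llama2-70b-oasst-sft-v10",
        roles={
            "system": {"pre_message": "<|im_start|>system\n", "post_message": "<|im_end|>"},
            "assistant": {"pre_message": "<|im_start|>assistant\n", "post_message": "<|im_end|>"},
            "user": {"pre_message": "<|im_start|>user\n", "post_message": "<|im_end|>"},
        },
    )
    response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)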
@@ -1997,7 +1997,10 @@ def validate_environment(model: Optional[str]=None) -> dict:
     if model is None:
         return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}
     ## EXTRACT LLM PROVIDER - if model name provided
-    custom_llm_provider = get_llm_provider(model=model)
+    try:
+        custom_llm_provider = get_llm_provider(model=model)
+    except:
+        custom_llm_provider = None
     # # check if llm provider part of model name
     # if model.split("/",1)[0] in litellm.provider_list:
     #     custom_llm_provider = model.split("/", 1)[0]
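The try/except guards validate_environment against model strings whose provider cannot be mapped; the proxy server below calls it per request. A quick usage sketch, with the return keys taken from the hunk above:

    # validate_environment reports which provider keys are present or missing for a model.
    import litellm

    check = litellm.validate_environment(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10")
    print(check["keys_in_environment"], check["missing_keys"])
    # With the change above, an unmappable model name no longer raises inside get_llm_provider;
    # custom_llm_provider simply falls back to None.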
@@ -21,6 +21,9 @@ ANTHROPIC_API_KEY = ""
 
 COHERE_API_KEY = ""
 
+## CONFIG FILE ##
+# CONFIG_FILE_PATH = "" # uncomment to point to config file
+
 ## LOGGING ##
 
 SET_VERBOSE = "False" # set to 'True' to see detailed input/output logs
@@ -1,11 +1,14 @@
-import litellm, os, traceback
+import os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
-import json
-import os
+import json, sys
 from typing import Optional
+# sys.path.insert(
+#     0, os.path.abspath("../")
+# ) # Adds the parent directory to the system path - for litellm local dev
+import litellm
 try:
     from utils import set_callbacks, load_router_config, print_verbose
 except ImportError:
@@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
 set_callbacks() # sets litellm callbacks for logging if they exist in the environment
 
 if "CONFIG_FILE_PATH" in os.environ:
+    print(f"CONFIG FILE DETECTED")
     llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
 else:
     llm_router, llm_model_list = load_router_config(router=llm_router)
@@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         ## CHECK KEYS ##
         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
         env_validation = litellm.validate_environment(model=data["model"])
-        print(f"request headers: {request.headers}")
         if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
             if "authorization" in request.headers:
                 api_key = request.headers.get("authorization")
@@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
                     for key, value in m["litellm_params"].items():
                         data[key] = value
                     break
-        print(f"data going into litellm completion: {data}")
         response = litellm.completion(
             **data
         )
@@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return response
     except Exception as e:
         error_traceback = traceback.format_exc()
+        print(f"{error_traceback}")
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
         # raise HTTPException(status_code=500, detail=error_msg)