From 512a1637eb3f46a3b5048af1fc07546d02289dd3 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 2 Nov 2023 16:23:51 -0700
Subject: [PATCH] feat(completion()): enable setting prompt templates via
 completion()

---
 .gitignore                               |  1 +
 litellm/llms/huggingface_restapi.py      | 16 ++++---
 litellm/llms/prompt_templates/factory.py |  1 +
 litellm/llms/together_ai.py              | 13 +++---
 litellm/main.py                          | 55 +++++++++++++++++++++---
 litellm/tests/test_completion.py         | 24 +++++------
 litellm/utils.py                         |  5 ++-
 litellm_server/.env.template             |  3 ++
 litellm_server/main.py                   | 13 +++---
 9 files changed, 94 insertions(+), 37 deletions(-)

diff --git a/.gitignore b/.gitignore
index 313241e4c..e3e1bee69 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,4 @@ litellm/proxy/litellm_secrets.toml
 litellm/proxy/api_log.json
 .idea/
 router_config.yaml
+litellm_server/config.yaml
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index e62c8e335..c7e5be915 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -193,9 +193,9 @@ def completion(
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
+                role_dict=model_prompt_details.get("roles", None),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                 messages=messages
             )
         else:
@@ -213,10 +213,12 @@ def completion(
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
             prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
-                messages=messages
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
             )
         else:
             prompt = prompt_factory(model=model, messages=messages)
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 959b8759f..f23691612 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -267,6 +267,7 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
             bos_open = False
 
     prompt += final_prompt_value
+    print(f"COMPLETE PROMPT: {prompt}")
     return prompt
 
 def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
diff --git a/litellm/llms/together_ai.py b/litellm/llms/together_ai.py
index f49cd13b7..9d18bfa7c 100644
--- a/litellm/llms/together_ai.py
+++ b/litellm/llms/together_ai.py
@@ -99,15 +99,18 @@ def completion(
         if k not in optional_params: # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v
 
+    print(f"CUSTOM PROMPT DICT: {custom_prompt_dict}; model: {model}")
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages
-        )
+            role_dict=model_prompt_details.get("roles", {}),
+            initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+            bos_token=model_prompt_details.get("bos_token", ""),
+            eos_token=model_prompt_details.get("eos_token", ""),
+            messages=messages,
+        )
    else:
        prompt = prompt_factory(model=model, messages=messages)

diff --git a/litellm/main.py b/litellm/main.py
index b4ecc67d4..03a96b58d 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -257,9 +257,15 @@ def completion(
     headers = kwargs.get("headers", None)
     num_retries = kwargs.get("num_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM PROMPT TEMPLATE ###
+    initial_prompt_value = kwargs.get("initial_prompt_value", None)
+    roles = kwargs.get("roles", None)
+    final_prompt_value = kwargs.get("final_prompt_value", None)
+    bos_token = kwargs.get("bos_token", None)
+    eos_token = kwargs.get("eos_token", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "initial_prompt_value", "roles", "final_prompt_value", "bos_token", "eos_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -280,6 +286,7 @@ def completion(
             model = litellm.model_alias_map[
                 model
             ]  # update the model to the actual value if an alias has been passed in
+
         model_response = ModelResponse()
 
         if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -288,6 +295,19 @@ def completion(
             model=deployment_id
             custom_llm_provider="azure"
         model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)
+        custom_prompt_dict = None
+        if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
+            custom_prompt_dict = {model: {}}
+            if initial_prompt_value:
+                custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value
+            if roles:
+                custom_prompt_dict[model]["roles"] = roles
+            if final_prompt_value:
+                custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value
+            if bos_token:
+                custom_prompt_dict[model]["bos_token"] = bos_token
+            if eos_token:
+                custom_prompt_dict[model]["eos_token"] = eos_token
         model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model
         if model_api_key and "sk-litellm" in model_api_key:
             api_base = "https://proxy.litellm.ai"
@@ -646,6 +666,10 @@ def completion(
                 or get_secret("ANTHROPIC_API_BASE")
                 or "https://api.anthropic.com/v1/complete"
             )
+            custom_prompt_dict = (
+                custom_prompt_dict
+                or litellm.custom_prompt_dict
+            )
             model_response = anthropic.completion(
                 model=model,
                 messages=messages,
@@ -867,6 +891,11 @@ def completion(
                 headers
                 or litellm.headers
             )
+
+            custom_prompt_dict = (
+                custom_prompt_dict
+                or litellm.custom_prompt_dict
+            )
             model_response = huggingface_restapi.completion(
                 model=model,
                 messages=messages,
@@ -880,7 +909,7 @@ def completion(
                 encoding=encoding,
                 api_key=huggingface_key,
                 logging_obj=logging,
-                custom_prompt_dict=litellm.custom_prompt_dict
+                custom_prompt_dict=custom_prompt_dict
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
@@ -985,6 +1014,11 @@ def completion(
                 or get_secret("TOGETHERAI_API_BASE")
                 or "https://api.together.xyz/inference"
             )
+
+            custom_prompt_dict = (
+                custom_prompt_dict
+                or litellm.custom_prompt_dict
+            )
 
             model_response = together_ai.completion(
                 model=model,
@@ -997,7 +1031,8 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=together_ai_key,
-                logging_obj=logging
+                logging_obj=logging,
+                custom_prompt_dict=custom_prompt_dict
             )
             if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True:
                 # don't try to access stream object,
@@ -1129,6 +1164,10 @@ def completion(
             response = model_response
         elif custom_llm_provider == "bedrock":
             # boto3 reads keys from .env
+            custom_prompt_dict = (
+                custom_prompt_dict
+                or litellm.custom_prompt_dict
+            )
             model_response = bedrock.completion(
                 model=model,
                 messages=messages,
@@ -1182,9 +1221,13 @@ def completion(
                 "http://localhost:11434"
             )
 
-            if model in litellm.custom_prompt_dict:
+            custom_prompt_dict = (
+                custom_prompt_dict
+                or litellm.custom_prompt_dict
+            )
+            if model in custom_prompt_dict:
                 # check if the model has a registered custom prompt
-                model_prompt_details = litellm.custom_prompt_dict[model]
+                model_prompt_details = custom_prompt_dict[model]
                 prompt = custom_prompt(
                     role_dict=model_prompt_details["roles"],
                     initial_prompt_value=model_prompt_details["initial_prompt_value"],
@@ -1196,7 +1239,7 @@ def completion(
 
             ## LOGGING
             logging.pre_call(
-                input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
+                input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
             )
             if kwargs.get('acompletion', False) == True:
                 if optional_params.get("stream", False) == True:
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index c6d92b023..e1da13c79 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -431,7 +431,7 @@ def test_completion_text_openai():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_text_openai()
+# test_completion_text_openai()
 
 def test_completion_openai_with_optional_params():
     try:
@@ -837,18 +837,16 @@ def test_completion_together_ai():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_together_ai()
 
-# def test_customprompt_together_ai():
-#     try:
-#         litellm.register_prompt_template(
-#             model="OpenAssistant/llama2-70b-oasst-sft-v10",
-#             roles={"system":"<|im_start|>system", "assistant":"<|im_start|>assistant", "user":"<|im_start|>user"}, # tell LiteLLM how you want to map the openai messages to this model
-#             pre_message_sep= "\n",
-#             post_message_sep= "\n"
-#         )
-#         response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_customprompt_together_ai():
+    try:
+        litellm.set_verbose = True
+        response = completion(model="together_ai/OpenAssistant/llama2-70b-oasst-sft-v10", messages=messages,
+                              roles={"system":{"pre_message":"<|im_start|>system\n", "post_message":"<|im_end|>"}, "assistant":{"pre_message":"<|im_start|>assistant\n","post_message":"<|im_end|>"}, "user":{"pre_message":"<|im_start|>user\n","post_message":"<|im_end|>"}})
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_customprompt_together_ai()
 
 def test_completion_sagemaker():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index f8069e9da..3b1a640e7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1997,7 +1997,10 @@ def validate_environment(model: Optional[str]=None) -> dict:
     if model is None:
         return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}
     ## EXTRACT LLM PROVIDER - if model name provided
-    custom_llm_provider = get_llm_provider(model=model)
+    try:
+        custom_llm_provider = get_llm_provider(model=model)
+    except:
+        custom_llm_provider = None
     # # check if llm provider part of model name
     # if model.split("/",1)[0] in litellm.provider_list:
     #     custom_llm_provider = model.split("/", 1)[0]
diff --git a/litellm_server/.env.template b/litellm_server/.env.template
index 280c38912..a87ae9cf3 100644
--- a/litellm_server/.env.template
+++ b/litellm_server/.env.template
@@ -21,6 +21,9 @@ ANTHROPIC_API_KEY = ""
 
 COHERE_API_KEY = ""
 
+## CONFIG FILE ##
+# CONFIG_FILE_PATH = "" # uncomment to point to config file
+
 ## LOGGING ##
 
 SET_VERBOSE = "False" # set to 'True' to see detailed input/output logs
diff --git a/litellm_server/main.py b/litellm_server/main.py
index 6eb158361..c69bda620 100644
--- a/litellm_server/main.py
+++ b/litellm_server/main.py
@@ -1,11 +1,14 @@
-import litellm, os, traceback
+import os, traceback
 from fastapi import FastAPI, Request, HTTPException
 from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
-import json
-import os
+import json, sys
 from typing import Optional
+# sys.path.insert(
+#     0, os.path.abspath("../")
+# )  # Adds the parent directory to the system path - for litellm local dev
+import litellm
 try:
     from utils import set_callbacks, load_router_config, print_verbose
 except ImportError:
@@ -31,6 +34,7 @@ llm_model_list: Optional[list] = None
 
 set_callbacks()  # sets litellm callbacks for logging if they exist in the environment
 if "CONFIG_FILE_PATH" in os.environ:
+    print(f"CONFIG FILE DETECTED")
     llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
 else:
     llm_router, llm_model_list = load_router_config(router=llm_router)
@@ -104,7 +108,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         ## CHECK KEYS ##
         # default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
         env_validation = litellm.validate_environment(model=data["model"])
-        print(f"request headers: {request.headers}")
         if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
             if "authorization" in request.headers:
                 api_key = request.headers.get("authorization")
@@ -125,7 +128,6 @@ async def chat_completion(request: Request, model: Optional[str] = None):
                 for key, value in m["litellm_params"].items():
                     data[key] = value
                 break
-        print(f"data going into litellm completion: {data}")
         response = litellm.completion(
             **data
         )
@@ -134,6 +136,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return response
     except Exception as e:
         error_traceback = traceback.format_exc()
+        print(f"{error_traceback}")
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
         # raise HTTPException(status_code=500, detail=error_msg)