build(litellm_server/utils.py): add support for general settings + num retries as a module variable
commit e3a1c58dd9 (parent 3f1b4c0759)
7 changed files with 52 additions and 33 deletions
@@ -46,6 +46,7 @@ add_function_to_prompt: bool = False # if function calling not supported by api,
client_session: Optional[requests.Session] = None
model_fallbacks: Optional[List] = None
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
num_retries: Optional[int] = None
#############################################

def get_model_cost_map(url: str):
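The hunk above introduces num_retries as a module-level variable, so a global retry default can be set once instead of being passed on every completion call. A minimal sketch of how that default could be used (the model name and message are illustrative):

    import litellm

    # Module-level default: picked up by calls that don't pass num_retries themselves
    # (see the client() hunk further down for the precedence logic).
    litellm.num_retries = 3

    response = litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": "hi"}],
    )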
@@ -267,7 +267,6 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
bos_open = False

prompt += final_prompt_value
print(f"COMPLETE PROMPT: {prompt}")
return prompt

def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
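This hunk only drops a debug print from custom_prompt; the function still stitches the prompt together and appends final_prompt_value before returning. A rough sketch of that assembly, simplified and not the library's exact implementation (the pre_message/post_message key names are assumptions):

    def custom_prompt_sketch(role_dict: dict, messages: list,
                             initial_prompt_value: str = "", final_prompt_value: str = "") -> str:
        # Wrap each message in its role's markers, then close with final_prompt_value.
        prompt = initial_prompt_value
        for message in messages:
            role_markers = role_dict.get(message["role"], {})
            prompt += role_markers.get("pre_message", "") + message["content"] + role_markers.get("post_message", "")
        prompt += final_prompt_value
        return prompt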
@@ -270,6 +270,7 @@ def completion(
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response)

try:
logging = litellm_logging_obj
fallbacks = (
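For context, the mock_response short-circuit shown above lets callers exercise the completion() code path without contacting a provider. A hedged usage sketch (the response shape is assumed to follow the usual OpenAI-style object):

    import litellm

    # No provider call is made; mock_completion() returns a canned response instead.
    resp = litellm.completion(
        model="gpt-3.5-turbo",                          # illustrative model name
        messages=[{"role": "user", "content": "ping"}],
        mock_response="pong",
    )
    print(resp.choices[0].message.content)              # expected to echo the mock string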
@@ -826,7 +826,12 @@ def client(original_function):
except Exception as e:
call_type = original_function.__name__
if call_type == CallTypes.completion.value:
num_retries = kwargs.get("num_retries", None)
num_retries = (
kwargs.get("num_retries", None)
or litellm.num_retries
or None
)
litellm.num_retries = None # set retries to None to prevent infinite loops
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", {})

if num_retries:
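The client() wrapper change resolves the retry count from the per-call kwarg first, then the new module variable, and then clears litellm.num_retries so a retried call cannot re-trigger retries forever. The same precedence, isolated as a sketch:

    from typing import Optional
    import litellm

    def resolve_num_retries(kwargs: dict) -> Optional[int]:
        # Per-call value wins; otherwise fall back to the module default; otherwise no retries.
        num_retries = kwargs.get("num_retries", None) or litellm.num_retries or None
        # Mirrors the hunk above: reset the module default to prevent infinite retry loops.
        litellm.num_retries = None
        return num_retries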
@@ -7,4 +7,4 @@ RUN pip install -r requirements.txt

EXPOSE $PORT

CMD exec uvicorn main:app --host 0.0.0.0 --port $PORT
CMD exec uvicorn main:app --host 0.0.0.0 --port $PORT --workers 10
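The Dockerfile change starts uvicorn with 10 worker processes instead of one. The equivalent programmatic launch, if the server were started from Python (the module path and port default are assumptions):

    import os
    import uvicorn

    if __name__ == "__main__":
        # workers > 1 requires passing the app as an import string rather than an app object.
        uvicorn.run(
            "main:app",
            host="0.0.0.0",
            port=int(os.getenv("PORT", "8000")),  # the Dockerfile reads $PORT; 8000 is an assumed fallback
            workers=10,
        )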
@@ -5,10 +5,11 @@ from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import json, sys
from typing import Optional
# sys.path.insert(
# 0, os.path.abspath("../")
# ) # Adds the parent directory to the system path - for litellm local dev
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path - for litellm local dev
import litellm
print(f"litellm: {litellm}")
try:
from utils import set_callbacks, load_router_config, print_verbose
except ImportError:
@@ -30,14 +31,15 @@ app.add_middleware(
#### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None
server_settings: Optional[dict] = None

set_callbacks() # sets litellm callbacks for logging if they exist in the environment

if "CONFIG_FILE_PATH" in os.environ:
print(f"CONFIG FILE DETECTED")
llm_router, llm_model_list = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router, config_file_path=os.getenv("CONFIG_FILE_PATH"))
else:
llm_router, llm_model_list = load_router_config(router=llm_router)
llm_router, llm_model_list, server_settings = load_router_config(router=llm_router)
#### API ENDPOINTS ####
@router.get("/v1/models")
@router.get("/models") # if project requires model list
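With this change load_router_config() returns a third value, server_settings, and the server reads the config path from the CONFIG_FILE_PATH environment variable. A hedged sketch of the expected wiring; the config shape is inferred from the hunks in this commit, and any keys beyond server_settings and litellm_settings are placeholders:

    import os

    # Point the server at a config file before main.py is imported.
    os.environ["CONFIG_FILE_PATH"] = "/app/config.yaml"

    # Roughly what the parsed YAML is expected to look like:
    example_config = {
        "server_settings": {"completion_model": "ollama/mistral"},  # server-wide default model
        "litellm_settings": {"drop_params": True},                  # applied to the litellm module
    }

    # Callers now unpack three values:
    # llm_router, llm_model_list, server_settings = load_router_config(router=None, config_file_path=os.getenv("CONFIG_FILE_PATH"))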
@@ -100,27 +102,31 @@ async def embedding(request: Request):
@router.post("/chat/completions")
@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None):
global llm_model_list
global llm_model_list, server_settings
try:
data = await request.json()
if model:
data["model"] = model
print(f"data: {data}")
data["model"] = (
server_settings.get("completion_model", None) # server default
or model # model passed in url
or data["model"] # default passed in
)
## CHECK KEYS ##
# default to always using the "ENV" variables, only if AUTH_STRATEGY==DYNAMIC then reads headers
env_validation = litellm.validate_environment(model=data["model"])
if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
if "authorization" in request.headers:
api_key = request.headers.get("authorization")
elif "api-key" in request.headers:
api_key = request.headers.get("api-key")
print(f"api_key in headers: {api_key}")
if " " in api_key:
api_key = api_key.split(" ")[1]
print(f"api_key split: {api_key}")
if len(api_key) > 0:
api_key = api_key
data["api_key"] = api_key
print(f"api_key in data: {api_key}")
# env_validation = litellm.validate_environment(model=data["model"])
# if (env_validation['keys_in_environment'] is False or os.getenv("AUTH_STRATEGY", None) == "DYNAMIC") and ("authorization" in request.headers or "api-key" in request.headers): # if users pass LLM api keys as part of header
# if "authorization" in request.headers:
# api_key = request.headers.get("authorization")
# elif "api-key" in request.headers:
# api_key = request.headers.get("api-key")
# print(f"api_key in headers: {api_key}")
# if " " in api_key:
# api_key = api_key.split(" ")[1]
# print(f"api_key split: {api_key}")
# if len(api_key) > 0:
# api_key = api_key
# data["api_key"] = api_key
# print(f"api_key in data: {api_key}")
## CHECK CONFIG ##
if llm_model_list and data["model"] in [m["model_name"] for m in llm_model_list]:
for m in llm_model_list:
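The new data["model"] expression gives the server-level completion_model priority over the model in the URL path, which in turn overrides the model in the request body. The same precedence, isolated as a sketch:

    from typing import Optional

    def resolve_model(server_settings: Optional[dict], url_model: Optional[str], body: dict) -> str:
        # Precedence per the hunk above: server default > model in URL path > model in request body.
        return (server_settings or {}).get("completion_model", None) or url_model or body["model"]

    # resolve_model({"completion_model": "ollama/mistral"}, None, {"model": "gpt-3.5-turbo"}) -> "ollama/mistral"
    # resolve_model({}, "azure/gpt-4", {"model": "gpt-3.5-turbo"})                            -> "azure/gpt-4"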
@@ -133,13 +139,14 @@ async def chat_completion(request: Request, model: Optional[str] = None):
)
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(data_generator(response), media_type='text/event-stream')
print(f"response: {response}")
return response
except Exception as e:
error_traceback = traceback.format_exc()
print(f"{error_traceback}")
error_msg = f"{str(e)}\n\n{error_traceback}"
return {"error": error_msg}
# raise HTTPException(status_code=500, detail=error_msg)
# return {"error": error_msg}
raise HTTPException(status_code=500, detail=error_msg)

@router.post("/router/completions")
async def router_completion(request: Request):
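Two behaviors land here: streamed responses are wrapped in a StreamingResponse fed by data_generator, and errors now surface as an HTTPException(500) instead of an {"error": ...} body. A client can therefore branch on status codes; a minimal sketch using requests (the URL and port are assumptions about a local deployment):

    import requests

    resp = requests.post(
        "http://localhost:8000/chat/completions",  # assumed local address for the server
        json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
    )
    resp.raise_for_status()  # failures now arrive as HTTP 500s rather than a 200 with an "error" key
    print(resp.json())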
@@ -43,16 +43,17 @@ def set_callbacks():

## CACHING
### REDIS
if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
from litellm.caching import Cache
litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
print("\033[92mLiteLLM: Switched on Redis caching\033[0m")
# if len(os.getenv("REDIS_HOST", "")) > 0 and len(os.getenv("REDIS_PORT", "")) > 0 and len(os.getenv("REDIS_PASSWORD", "")) > 0:
# print(f"redis host: {os.getenv('REDIS_HOST')}; redis port: {os.getenv('REDIS_PORT')}; password: {os.getenv('REDIS_PASSWORD')}")
# from litellm.caching import Cache
# litellm.cache = Cache(type="redis", host=os.getenv("REDIS_HOST"), port=os.getenv("REDIS_PORT"), password=os.getenv("REDIS_PASSWORD"))
# print("\033[92mLiteLLM: Switched on Redis caching\033[0m")



def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]='/app/config.yaml'):
config = {}

server_settings = {}
try:
if os.path.exists(config_file_path):
with open(config_file_path, 'r') as file:
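The Redis block switches on LiteLLM's response cache only when all three REDIS_* variables are set. The same switch outside the server, as a sketch (the environment values are placeholders):

    import os
    import litellm
    from litellm.caching import Cache

    # Enable response caching only when a Redis endpoint is fully configured.
    if os.getenv("REDIS_HOST") and os.getenv("REDIS_PORT") and os.getenv("REDIS_PASSWORD"):
        litellm.cache = Cache(
            type="redis",
            host=os.getenv("REDIS_HOST"),
            port=os.getenv("REDIS_PORT"),
            password=os.getenv("REDIS_PASSWORD"),
        )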
@@ -62,6 +63,11 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optio
except:
pass

## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
server_settings = config.get("server_settings", None)
if server_settings:
server_settings = server_settings

## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get('litellm_settings', None)
if litellm_settings:
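The litellm_settings section of the config is meant to map onto attributes of the litellm module (the diff cuts off before showing how the keys are applied). A plausible sketch of that mapping, offered as an assumption rather than the code in this commit:

    import litellm

    litellm_settings = {"drop_params": True, "num_retries": 3}  # e.g. values parsed from config.yaml

    # Assumed behavior: each key becomes a module-level setting, e.g. litellm.drop_params = True.
    for key, value in litellm_settings.items():
        setattr(litellm, key, value)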
@@ -79,4 +85,4 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optio
for key, value in environment_variables.items():
os.environ[key] = value

return router, model_list
return router, model_list, server_settings
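The last hunk exports any config-declared environment_variables into the process environment and returns server_settings alongside the router and model list. A short sketch of the same export step and of consuming the new return value (the variable name inside environment_variables is a placeholder):

    import os

    parsed_config = {"environment_variables": {"OPENAI_API_KEY": "sk-placeholder"}}

    # As in the hunk above, each declared variable is exported before the router is used.
    for key, value in parsed_config["environment_variables"].items():
        os.environ[key] = value

    # load_router_config() now returns a three-tuple instead of a pair:
    # router, model_list, server_settings = load_router_config(router=None)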