fix(factory.py): fixing llama-2 non-chat models prompt templating

Krrish Dholakia 2023-11-07 21:33:46 -08:00
parent ce27e08e7d
commit e7735274de
6 changed files with 118 additions and 97 deletions
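Context for the fix: the logged request in the deleted file below shows the prompt string LiteLLM builds for Llama-2 style instruct models, where the user turn is wrapped as <s>[INST] ... [/INST]. A minimal illustrative sketch of that wrapping follows; it is not the factory.py code this commit changes, and the function name is made up.

# Illustrative only: builds a Llama-2 style instruct prompt like the
# "<s>[INST] hey [/INST]\n" string seen in the deleted log below.
# This is NOT the factory.py implementation fixed by this commit.
def llama_2_instruct_prompt(messages: list) -> str:
    prompt = ""
    for message in messages:
        if message["role"] == "user":
            prompt += f"<s>[INST] {message['content']} [/INST]\n"
        else:
            # assistant / system handling is out of scope for this sketch
            prompt += f"{message['content']}\n"
    return prompt

print(llama_2_instruct_prompt([{"role": "user", "content": "hey"}]))
# prints: <s>[INST] hey [/INST]  (with a trailing newline)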

@@ -1,87 +0,0 @@
{
  "20231014160921359878": {
    "pre_api_call": {
      "model": "codellama/CodeLlama-7b-Instruct-hf",
      "messages": [
        {
          "role": "user",
          "content": "hey"
        }
      ],
      "optional_params": {
        "temperature": 0.5,
        "stream": true,
        "max_new_tokens": 1024,
        "details": true,
        "return_full_text": false
      },
      "litellm_params": {
        "return_async": false,
        "api_key": null,
        "force_timeout": 600,
        "logger_fn": null,
        "verbose": false,
        "custom_llm_provider": "huggingface",
        "api_base": "https://app.baseten.co/models/pP8JeaB/predict",
        "litellm_call_id": "d75891a0-d567-470a-a6cd-137e698da092",
        "model_alias_map": {},
        "completion_call_id": null,
        "metadata": null,
        "stream_response": {}
      },
      "input": "<s>[INST] hey [/INST]\n",
      "api_key": "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",
      "additional_args": {
        "complete_input_dict": {
          "inputs": "<s>[INST] hey [/INST]\n",
          "parameters": {
            "temperature": 0.5,
            "stream": true,
            "max_new_tokens": 1024,
            "details": true,
            "return_full_text": false
          },
          "stream": true
        },
        "task": "text-generation-inference",
        "headers": {
          "Authorization": "Api-Key SQqH1uZg.SSN79Bq997k4TRdzW9HBCghx9KyL0EJA"
        }
      },
      "log_event_type": "pre_api_call"
    },
    "post_api_call": {
      "model": "codellama/CodeLlama-7b-Instruct-hf",
      "messages": [
        {
          "role": "user",
          "content": "hey"
        }
      ],
      "optional_params": {
        "temperature": 0.5,
        "stream": true,
        "max_new_tokens": 1024,
        "details": true,
        "return_full_text": false
      },
      "litellm_params": {
        "return_async": false,
        "api_key": null,
        "force_timeout": 600,
        "logger_fn": null,
        "verbose": false,
        "custom_llm_provider": "huggingface",
        "api_base": "https://app.baseten.co/models/pP8JeaB/predict",
        "litellm_call_id": "d75891a0-d567-470a-a6cd-137e698da092",
        "model_alias_map": {},
        "completion_call_id": null,
        "metadata": null,
        "stream_response": {}
      },
      "input": null,
      "api_key": null,
      "additional_args": {},
      "log_event_type": "post_api_call",
      "original_response": "<class 'generator'>",
      "end_time":


@@ -1,8 +0,0 @@
{
  "Oct-12-2023": {
    "claude-2": {
      "cost": 0.02365918,
      "num_requests": 1
    }
  }
}


@@ -120,6 +120,7 @@ config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.getenv(
     "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
 )
+experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
@@ -354,7 +355,7 @@ def initialize(
     save,
     config
 ):
-    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, llm_model_list, llm_router, server_settings
+    global user_model, user_api_base, user_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, server_settings
     generate_feedback_box()
     user_model = model
     user_debug = debug
@@ -393,6 +394,8 @@ def initialize(
         dynamic_config["general"]["max_budget"] = max_budget
     if debug: # litellm-specific param
         litellm.set_verbose = True
+    if experimental:
+        pass
     if save:
         save_params_to_config(dynamic_config)
         with open(user_config_path) as f:
@@ -537,6 +540,22 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
+@router.post("/router/chat/completions")
+async def router_completion(request: Request):
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+        return {"data": data}
+    except Exception as e:
+        print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
+        error_traceback = traceback.format_exc()
+        error_msg = f"{str(e)}\n\n{error_traceback}"
+        return {"error": error_msg}
+
 @router.get("/ollama_logs")
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser("~/.ollama/logs/server.log")
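The new /router/chat/completions route added in the last hunk only parses the request body (ast.literal_eval, falling back to json.loads) and echoes it back under "data". A quick local smoke test of that stub; the host and port are assumptions about a default local proxy, not taken from this diff.

# Minimal smoke test for the stub route added above. The address is an
# assumption; the endpoint currently just echoes the parsed body back.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/router/chat/completions",  # assumed proxy address
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hey"}],
    },
)
print(resp.json())  # expected shape: {"data": <the body we sent>}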


@@ -1,3 +1,2 @@
 #!/bin/bash
-python3 proxy_cli.py --config -f ../../secrets_template.toml
 python3 proxy_cli.py

litellm/proxy/utils.py (new file)

@@ -0,0 +1,98 @@
# import threading, time, litellm
# import concurrent.futures

# """
# v1:
# 1. `--experimental_async` starts 2 background threads:
#   - 1. to check the redis queue:
#     - if job available
#       - it dequeues as many jobs as healthy endpoints
#       - calls llm api -> saves response in redis cache
#   - 2. to check the llm apis:
#     - check if endpoints are healthy (unhealthy = 4xx / 5xx call or >1min. queue)
#     - which one is least busy
# 2. /router/chat/completions: receives request -> adds to redis queue -> returns {run_id, started_at, request_obj}
# 3. /router/chat/completions/runs/{run_id}: returns {status: _, [optional] response_obj: _}
# """

# def _start_health_check_thread():
#     """
#     Starts a separate thread to perform health checks periodically.
#     """
#     health_check_thread = threading.Thread(target=_perform_health_checks, daemon=True)
#     health_check_thread.start()
#     llm_call_thread = threading.Thread(target=_llm_call_thread, daemon=True)
#     llm_call_thread.start()

# def _llm_call_thread():
#     """
#     Periodically performs job checks on the redis queue.
#     If available, make llm api calls.
#     Write result to redis cache (1 min ttl)
#     """
#     with concurrent.futures.ThreadPoolExecutor() as executor:
#         while True:
#             job_checks = _job_check()
#             future_to_job = {executor.submit(_llm_api_call, job): job for job in job_checks}
#             for future in concurrent.futures.as_completed(future_to_job):
#                 job = future_to_job[future]
#                 try:
#                     result = future.result()
#                 except Exception as exc:
#                     print(f'{job} generated an exception: {exc}')
#                 else:
#                     _write_to_cache(job, result, ttl=1*60)
#             time.sleep(1) # sleep 1 second to avoid overloading the server

# def _perform_health_checks():
#     """
#     Periodically performs health checks on the servers.
#     Updates the list of healthy servers accordingly.
#     """
#     while True:
#         healthy_deployments = _health_check()
#         # Adjust the time interval based on your needs
#         time.sleep(15)

# def _job_check():
#     """
#     Periodically performs job checks on the redis queue.
#     Returns the list of available jobs - len(available_jobs) == len(healthy_endpoints),
#     e.g. don't dequeue a gpt-3.5-turbo job if there's no healthy deployments left
#     """
#     pass

# def _llm_api_call(**data):
#     """
#     Makes the litellm.completion() call with 3 retries
#     """
#     return litellm.completion(num_retries=3, **data)

# def _write_to_cache():
#     """
#     Writes the result to a redis cache in the form (key:job_id, value: <response_object>)
#     """
#     pass

# def _health_check():
#     """
#     Performs a health check on the deployments
#     Returns the list of healthy deployments
#     """
#     healthy_deployments = []
#     for deployment in model_list:
#         litellm_args = deployment["litellm_params"]
#         try:
#             start_time = time.time()
#             litellm.completion(messages=[{"role": "user", "content": ""}], max_tokens=1, **litellm_args) # hit the server with a blank message to see how long it takes to respond
#             end_time = time.time()
#             response_time = end_time - start_time
#             logging.debug(f"response_time: {response_time}")
#             healthy_deployments.append((deployment, response_time))
#             healthy_deployments.sort(key=lambda x: x[1])
#         except Exception as e:
#             pass
#     return healthy_deployments
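The commented docstring at the top of this new file describes a planned v1 flow: /router/chat/completions enqueues a job and returns {run_id, started_at, request_obj}, and /router/chat/completions/runs/{run_id} reports status until a response_obj is available. None of that is implemented in this commit; below is a hypothetical client-side sketch of that flow, with the endpoints, proxy address, and field names all taken as assumptions from the comment.

# Hypothetical client for the queued-run flow described in the commented
# docstring above. Endpoint paths, host/port, and response fields are
# assumptions; this commit's /router/chat/completions only echoes the body.
import time
import requests

BASE = "http://0.0.0.0:8000"  # assumed proxy address

def submit_and_poll(payload: dict, timeout_s: float = 60.0) -> dict:
    # 1. enqueue the request; the sketch says this returns {run_id, started_at, request_obj}
    run = requests.post(f"{BASE}/router/chat/completions", json=payload).json()
    run_id = run["run_id"]

    # 2. poll the runs endpoint until a response_obj shows up or we time out
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        status = requests.get(f"{BASE}/router/chat/completions/runs/{run_id}").json()
        if "response_obj" in status:
            return status["response_obj"]
        time.sleep(1)
    raise TimeoutError(f"run {run_id} did not finish within {timeout_s}s")

# submit_and_poll({"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hey"}]})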