fix(add-custom-success-callback-for-streaming): add custom success callback for streaming

Krrish Dholakia 2023-10-06 15:01:50 -07:00
parent 868c1c594f
commit 7e34736a38
8 changed files with 89 additions and 20 deletions
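
Summary: a plain Python callable registered in litellm.success_callback now also fires for streaming completions. A minimal usage sketch, mirroring the test added further down in this diff (illustrative only, not the definitive API surface):

import litellm

def success_callback(kwargs, completion_response, start_time, end_time):
    # Same signature as the test added below: kwargs holds the call arguments,
    # completion_response the output, start_time/end_time the timing.
    print({"input": kwargs, "output": completion_response,
           "start": start_time, "end": end_time})

litellm.success_callback = [success_callback]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)
for chunk in response:  # the registered callback now fires for this streamed call too
    print(chunk["choices"][0])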

View file

@@ -1,6 +1,7 @@
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
+from typing import Optional

 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -26,6 +27,25 @@ def llama_2_chat_pt(messages):
     )
     return prompt

+def ollama_pt(messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "### System:\n",
+                "post_message": "\n"
+            },
+            "user": {
+                "pre_message": "### User:\n",
+                "post_message": "\n",
+            },
+            "assistant": {
+                "pre_message": "### Response:\n",
+                "post_message": "\n",
+            }
+        },
+        final_prompt_value="### Response:"
+    )
+
 def mistral_instruct_pt(messages):
     prompt = custom_prompt(
         initial_prompt_value="<s>",
@@ -190,9 +210,13 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt

-def prompt_factory(model: str, messages: list):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
     original_model_name = model
     model = model.lower()
+    if custom_llm_provider == "ollama":
+        return ollama_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model:
             if "chat" in model:

View file

@@ -961,7 +961,7 @@ def completion(
                 messages=messages
             )
         else:
-            prompt = prompt_factory(model=model, messages=messages)
+            prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)

         ## LOGGING
         logging.pre_call(
@@ -1410,6 +1410,7 @@ def text_completion(*args, **kwargs):
         kwargs["messages"] = messages
         kwargs.pop("prompt")
         response = completion(*args, **kwargs) # assume the response is the openai response object
+        print(f"response: {response}")
         formatted_response_obj = {
             "id": response["id"],
             "object": "text_completion",

View file

@@ -61,7 +61,7 @@ def open_config():
 @click.option('--max_tokens', default=None, help='Set max tokens for the model')
 @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
 @click.option('--config', is_flag=True, help='Create and open .env file from .env.template')
-@click.option('--test', default=None, help='proxy chat completions url to make a test request to')
+@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
 @click.option('--local', is_flag=True, default=False, help='for local debugging')
 def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, telemetry, config, test, local):
     if config:
@@ -82,10 +82,14 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, te
         print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test {url}\"\033[0m")
         return

-    if test != None:
+    if test != False:
         click.echo('LiteLLM: Making a test ChatCompletions request to your proxy')
         import openai
-        openai.api_base = test
+        if test == True: # flag value set
+            api_base = "http://0.0.0.0:8000"
+        else:
+            api_base = test
+        openai.api_base = api_base
         openai.api_key = "temp-key"
         print(openai.api_base)
@@ -107,7 +111,7 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, te
     except:
         raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
     print(f"\033[32mLiteLLM: Deployed Proxy Locally\033[0m\n")
-    print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test http://0.0.0.0:{port}\" [In a new terminal tab]\033[0m\n")
+    print(f"\033[32mLiteLLM: Test your local endpoint with: \"litellm --test\" [In a new terminal tab]\033[0m\n")
     print(f"\033[32mLiteLLM: Deploy your proxy using the following: \"litellm --model claude-instant-1 --deploy\" Get an https://api.litellm.ai/chat/completions endpoint \033[0m\n")
     uvicorn.run(app, host='0.0.0.0', port=port)
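
With --test now a flag, the default target is the local proxy at http://0.0.0.0:8000. Roughly what the test request amounts to, as a standalone snippet using the pre-1.0 OpenAI client the CLI configures above (model name and message are placeholders):

import openai

# Point the pre-1.0 OpenAI client at the locally running proxy, as
# `litellm --test` does when no URL is supplied.
openai.api_base = "http://0.0.0.0:8000"
openai.api_key = "temp-key"  # mirrors the placeholder key set by the CLI above

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "this is a test request"}],
    stream=True,  # also exercises the proxy's SSE streaming path
)
for chunk in response:
    print(chunk)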

View file

@@ -18,10 +18,12 @@ print()
 import litellm
 from fastapi import FastAPI, Request
+from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse
 import json

 app = FastAPI()
+router = APIRouter()

 user_api_base = None
 user_model = None
@@ -109,14 +111,14 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"

 #### API ENDPOINTS ####
-@app.get("/models") # if project requires model list
+@router.get("/models") # if project requires model list
 def model_list():
     return dict(
         data=[{"id": user_model, "object": "model", "created": 1677610602, "owned_by": "openai"}],
         object="list",
     )

-@app.post("/{version}/completions")
+@router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -149,7 +151,7 @@ async def completion(request: Request):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response

-@app.post("/chat/completions")
+@router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -186,4 +188,6 @@ async def chat_completion(request: Request):
     if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     print_verbose(f"response: {response}")
     return response
+
+app.include_router(router)

View file

@@ -709,7 +709,7 @@ def test_completion_sagemaker_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()

 # test on openai completion call
 def test_openai_text_completion_call():
@@ -732,8 +732,16 @@ def test_openai_text_completion_call():
 # # test on ai21 completion call
 def ai21_completion_call():
     try:
+        messages=[{
+                "role": "system",
+                "content": "You are an all-knowing oracle",
+            },
+            {
+                "role": "user",
+                "content": "What is the meaning of the Universe?"
+            }]
         response = completion(
-            model="j2-ultra", messages=messages, stream=True
+            model="j2-ultra", messages=messages, stream=True, max_tokens=500
         )
         print(f"response: {response}")
         has_finished = False
@@ -1262,3 +1270,31 @@ def test_openai_streaming_and_function_calling():
         raise e

 # test_openai_streaming_and_function_calling()
+
+import litellm
+def test_success_callback_streaming():
+    def success_callback(kwargs, completion_response, start_time, end_time):
+        print(
+            {
+                "success": True,
+                "input": kwargs,
+                "output": completion_response,
+                "start_time": start_time,
+                "end_time": end_time,
+            }
+        )
+
+    litellm.success_callback = [success_callback]
+
+    messages = [{"role": "user", "content": "hello"}]
+    response = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+
+    print(response)
+    for chunk in response:
+        print(chunk["choices"][0])
+
+test_success_callback_streaming()

View file

@@ -456,6 +456,14 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if callable(callback): # custom logger functions
+                    customLogger.log_event(
+                        kwargs=self.model_call_details,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        print_verbose=print_verbose,
+                    )
             except Exception as e:
                 print_verbose(
@@ -2022,14 +2030,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
                     litellm_call_id=kwargs["litellm_call_id"],
                     print_verbose=print_verbose,
                 )
-            elif callable(callback): # custom logger functions
-                customLogger.log_event(
-                    kwargs=kwargs,
-                    response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    print_verbose=print_verbose,
-                )
     except Exception as e:
         # LOGGING
         exception_logging(logger_fn=user_logger_fn, exception=e)
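
Moving the callable-callback branch out of handle_success and into the Logging success handler is what makes function callbacks run on the streaming path as well. A small sketch of a callback that uses the timing arguments; only the four-parameter signature comes from the code above, everything else is illustrative:

import litellm

def track_latency(kwargs, completion_response, start_time, end_time):
    # start_time/end_time are supplied by the success handler; subtracting them
    # is assumed to yield a duration (e.g. a timedelta if they are datetimes).
    print(f"call took {end_time - start_time}: {completion_response}")

litellm.success_callback = [track_latency]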