fix(add-custom-success-callback-for-streaming): add custom success callback for streaming

Krrish Dholakia 2023-10-06 15:01:50 -07:00
parent 868c1c594f
commit 7e34736a38
8 changed files with 89 additions and 20 deletions

View file

@@ -1,6 +1,7 @@
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
+from typing import Optional
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -26,6 +27,25 @@ def llama_2_chat_pt(messages):
     )
     return prompt
+def ollama_pt(messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "### System:\n",
+                "post_message": "\n"
+            },
+            "user": {
+                "pre_message": "### User:\n",
+                "post_message": "\n",
+            },
+            "assistant": {
+                "pre_message": "### Response:\n",
+                "post_message": "\n",
+            }
+        },
+        final_prompt_value="### Response:"
+    )
 def mistral_instruct_pt(messages):
     prompt = custom_prompt(
         initial_prompt_value="<s>",
@@ -190,9 +210,13 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt
-def prompt_factory(model: str, messages: list):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
     original_model_name = model
     model = model.lower()
+    if custom_llm_provider == "ollama":
+        return ollama_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model:
             if "chat" in model:

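With this change, prompt_factory() routes any call carrying custom_llm_provider="ollama" straight to ollama_pt() before any model-name matching. For reference, the role_dict added above encodes an instruct-style template: each message gets a "### System:", "### User:", or "### Response:" header, and the trailing "### Response:" leaves the model to continue as the assistant. The snippet below re-renders that template by hand purely for illustration; it is not LiteLLM code, and the helper name is made up.

# Illustrative only: hand-rolled rendering of the ### System / ### User / ### Response
# template encoded by the role_dict above.
ROLE_TAGS = {
    "system": "### System:\n",
    "user": "### User:\n",
    "assistant": "### Response:\n",
}

def render_ollama_style_prompt(messages: list) -> str:
    prompt = ""
    for message in messages:
        # pre_message + content + post_message, mirroring the role_dict entries
        prompt += ROLE_TAGS[message["role"]] + message["content"] + "\n"
    return prompt + "### Response:"  # final_prompt_value

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
print(render_ollama_style_prompt(messages))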
View file

@@ -961,7 +961,7 @@ def completion(
                 messages=messages
             )
         else:
-            prompt = prompt_factory(model=model, messages=messages)
+            prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
         ## LOGGING
         logging.pre_call(
@@ -1410,6 +1410,7 @@ def text_completion(*args, **kwargs):
         kwargs["messages"] = messages
         kwargs.pop("prompt")
         response = completion(*args, **kwargs) # assume the response is the openai response object
+        print(f"response: {response}")
         formatted_response_obj = {
             "id": response["id"],
             "object": "text_completion",

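The second hunk above sits in the text_completion() wrapper, which converts an OpenAI text-completion style call (a prompt string) into a chat completion and then reshapes the chat response into a text_completion-shaped object. A minimal sketch of that pattern follows, for orientation only; the role chosen for the converted prompt and the exact fields copied into the response are assumptions, not taken from this diff.

# Sketch of the text_completion wrapper pattern; not LiteLLM's implementation.
import litellm

def text_completion_sketch(prompt: str, **kwargs):
    # turn the prompt into chat messages (the role used here is an assumption)
    kwargs["messages"] = [{"role": "user", "content": prompt}]
    kwargs.pop("prompt", None)
    response = litellm.completion(**kwargs)
    # reshape the chat response into an OpenAI text_completion-style dict
    return {
        "id": response["id"],
        "object": "text_completion",
        "choices": [
            {"index": 0, "text": response["choices"][0]["message"]["content"]},
        ],
    }

# usage: text_completion_sketch("Say hello", model="gpt-3.5-turbo")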
View file

@@ -61,7 +61,7 @@ def open_config():
 @click.option('--max_tokens', default=None, help='Set max tokens for the model')
 @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
 @click.option('--config', is_flag=True, help='Create and open .env file from .env.template')
-@click.option('--test', default=None, help='proxy chat completions url to make a test request to')
+@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
 @click.option('--local', is_flag=True, default=False, help='for local debugging')
 def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, telemetry, config, test, local):
     if config:
@@ -82,10 +82,14 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, te
         print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test {url}\"\033[0m")
         return
-    if test != None:
+    if test != False:
         click.echo('LiteLLM: Making a test ChatCompletions request to your proxy')
         import openai
-        openai.api_base = test
+        if test == True: # flag value set
+            api_base = "http://0.0.0.0:8000"
+        else:
+            api_base = test
+        openai.api_base = api_base
         openai.api_key = "temp-key"
         print(openai.api_base)
@@ -107,7 +111,7 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, te
     except:
         raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
     print(f"\033[32mLiteLLM: Deployed Proxy Locally\033[0m\n")
-    print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test http://0.0.0.0:{port}\" [In a new terminal tab]\033[0m\n")
+    print(f"\033[32mLiteLLM: Test your local endpoint with: \"litellm --test\" [In a new terminal tab]\033[0m\n")
     print(f"\033[32mLiteLLM: Deploy your proxy using the following: \"litellm --model claude-instant-1 --deploy\" Get an https://api.litellm.ai/chat/completions endpoint \033[0m\n")
     uvicorn.run(app, host='0.0.0.0', port=port)

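For context, the --test flow above simply points the pre-1.0 openai SDK at the local proxy and fires a ChatCompletions request. Doing the same thing by hand looks roughly like this, assuming the proxy is already running on the default port; the model name and prompt are arbitrary placeholders.

# Hand-rolled equivalent of `litellm --test` against a proxy on http://0.0.0.0:8000.
# Uses the pre-1.0 openai SDK interface, as the CLI code above does.
import openai

openai.api_base = "http://0.0.0.0:8000"
openai.api_key = "temp-key"  # the local proxy does not validate this

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # placeholder; which model is used depends on how the proxy was started
    messages=[{"role": "user", "content": "this is a test request"}],
    stream=True,
)
for chunk in response:
    print(chunk)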
View file

@@ -18,10 +18,12 @@ print()
 import litellm
 from fastapi import FastAPI, Request
+from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse
 import json
 app = FastAPI()
+router = APIRouter()
 user_api_base = None
 user_model = None
@@ -109,14 +111,14 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"
 #### API ENDPOINTS ####
-@app.get("/models") # if project requires model list
+@router.get("/models") # if project requires model list
 def model_list():
     return dict(
         data=[{"id": user_model, "object": "model", "created": 1677610602, "owned_by": "openai"}],
         object="list",
     )
-@app.post("/{version}/completions")
+@router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -149,7 +151,7 @@ async def completion(request: Request):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
-@app.post("/chat/completions")
+@router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -186,4 +188,6 @@ async def chat_completion(request: Request):
     if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     print_verbose(f"response: {response}")
-    return response
+    return response
+app.include_router(router)

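The routing change above is the standard FastAPI pattern: declare endpoints on an APIRouter, then attach the router to the app at the end of the module with include_router(), with streaming responses emitted as SSE-style "data: ..." lines. Below is a stripped-down, self-contained version of that shape with dummy payloads; it is a sketch of the pattern, not the proxy's actual handlers.

# Minimal FastAPI app using the APIRouter + include_router pattern shown above.
import json
from fastapi import FastAPI, Request
from fastapi.routing import APIRouter
from fastapi.responses import StreamingResponse

app = FastAPI()
router = APIRouter()

def data_generator(chunks):
    # stream each chunk as a server-sent-event style line
    for chunk in chunks:
        yield f"data: {json.dumps(chunk)}\n\n"

@router.post("/chat/completions")
async def chat_completion(request: Request):
    data = await request.json()
    response = {"demo": True, "echo": data.get("messages")}
    if data.get("stream") == True:
        return StreamingResponse(data_generator([response]), media_type="text/event-stream")
    return response

app.include_router(router)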
View file

@@ -709,7 +709,7 @@ def test_completion_sagemaker_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()
 # test on openai completion call
 def test_openai_text_completion_call():
@@ -732,8 +732,16 @@ def test_openai_text_completion_call():
 # # test on ai21 completion call
 def ai21_completion_call():
     try:
+        messages=[{
+            "role": "system",
+            "content": "You are an all-knowing oracle",
+        },
+        {
+            "role": "user",
+            "content": "What is the meaning of the Universe?"
+        }]
         response = completion(
-            model="j2-ultra", messages=messages, stream=True
+            model="j2-ultra", messages=messages, stream=True, max_tokens=500
         )
         print(f"response: {response}")
         has_finished = False
@@ -1262,3 +1270,31 @@ def test_openai_streaming_and_function_calling():
         raise e
 # test_openai_streaming_and_function_calling()
+import litellm
+def test_success_callback_streaming():
+    def success_callback(kwargs, completion_response, start_time, end_time):
+        print(
+            {
+                "success": True,
+                "input": kwargs,
+                "output": completion_response,
+                "start_time": start_time,
+                "end_time": end_time,
+            }
+        )
+    litellm.success_callback = [success_callback]
+    messages = [{"role": "user", "content": "hello"}]
+    response = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+    print(response)
+    for chunk in response:
+        print(chunk["choices"][0])
+test_success_callback_streaming()

View file

@@ -456,6 +456,14 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if callable(callback): # custom logger functions
+                    customLogger.log_event(
+                        kwargs=self.model_call_details,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        print_verbose=print_verbose,
+                    )
             except Exception as e:
                 print_verbose(
@@ -2022,14 +2030,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
                     litellm_call_id=kwargs["litellm_call_id"],
                     print_verbose=print_verbose,
                 )
-            elif callable(callback): # custom logger functions
-                customLogger.log_event(
-                    kwargs=kwargs,
-                    response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    print_verbose=print_verbose,
-                )
     except Exception as e:
         # LOGGING
         exception_logging(logger_fn=user_logger_fn, exception=e)
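Taken together, the two hunks above move custom-callable handling out of the module-level handle_success() path and into the Logging class's success handler, which is presumably what lets a plain Python function registered in litellm.success_callback fire for streaming calls as exercised by the new test. The schematic below shows the dispatch shape only; the class name, function names, and log_event signature are illustrative, not LiteLLM's actual internals.

# Schematic of the success-callback dispatch pattern; names and signatures are illustrative.
import time

class CustomLoggerSketch:
    def log_event(self, callback_func, kwargs, response_obj, start_time, end_time, print_verbose):
        try:
            # user-supplied callbacks are invoked as fn(kwargs, completion_response, start_time, end_time)
            callback_func(kwargs, response_obj, start_time, end_time)
        except Exception as e:
            print_verbose(f"custom success callback raised: {e}")

def success_handler_sketch(callbacks, model_call_details, result, start_time, end_time, print_verbose=print):
    custom_logger = CustomLoggerSketch()
    for callback in callbacks:
        if callable(callback):  # custom logger functions
            custom_logger.log_event(
                callback_func=callback,
                kwargs=model_call_details,
                response_obj=result,
                start_time=start_time,
                end_time=end_time,
                print_verbose=print_verbose,
            )

# usage sketch
def my_callback(kwargs, completion_response, start_time, end_time):
    print({"input": kwargs, "output": completion_response, "latency": end_time - start_time})

start = time.time()
success_handler_sketch([my_callback], {"model": "gpt-3.5-turbo"}, {"text": "hi"}, start, time.time())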