Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Commit 7e34736a38 (parent 868c1c594f)
fix(add-custom-success-callback-for-streaming): add custom success callback for streaming

8 changed files with 89 additions and 20 deletions. Two of the changed files are binary and are not shown below.
@@ -1,6 +1,7 @@
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
+from typing import Optional

 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -26,6 +27,25 @@ def llama_2_chat_pt(messages):
     )
     return prompt

+def ollama_pt(messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "### System:\n",
+                "post_message": "\n"
+            },
+            "user": {
+                "pre_message": "### User:\n",
+                "post_message": "\n",
+            },
+            "assistant": {
+                "pre_message": "### Response:\n",
+                "post_message": "\n",
+            }
+        },
+        final_prompt_value="### Response:"
+    )
+
 def mistral_instruct_pt(messages):
     prompt = custom_prompt(
         initial_prompt_value="<s>",
@@ -190,9 +210,13 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt

-def prompt_factory(model: str, messages: list):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
     original_model_name = model
     model = model.lower()
+
+    if custom_llm_provider == "ollama":
+        return ollama_pt(messages=messages)
+
     try:
         if "meta-llama/llama-2" in model:
             if "chat" in model:
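For orientation, here is a standalone sketch of what the new ollama-style template would produce for a message list. It is illustrative only: render_ollama_prompt is a hypothetical helper, and it assumes custom_prompt() concatenates pre_message + content + post_message per message and then appends final_prompt_value, consistent with the "prompt += final_prompt_value" context line in the last hunk.

# Illustrative sketch, not the committed code: applying the ollama-style role
# template above to a message list. Assumes custom_prompt() concatenates
# pre_message + content + post_message per message, then appends
# final_prompt_value.
from typing import Dict, List

OLLAMA_ROLE_DICT: Dict[str, Dict[str, str]] = {
    "system": {"pre_message": "### System:\n", "post_message": "\n"},
    "user": {"pre_message": "### User:\n", "post_message": "\n"},
    "assistant": {"pre_message": "### Response:\n", "post_message": "\n"},
}

def render_ollama_prompt(messages: List[Dict[str, str]]) -> str:
    prompt = ""
    for message in messages:
        role = OLLAMA_ROLE_DICT.get(message["role"], {})
        prompt += role.get("pre_message", "") + message["content"] + role.get("post_message", "")
    return prompt + "### Response:"  # final_prompt_value

# render_ollama_prompt([{"role": "user", "content": "hi"}])
# -> "### User:\nhi\n### Response:"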
@@ -961,7 +961,7 @@ def completion(
             messages=messages
         )
     else:
-        prompt = prompt_factory(model=model, messages=messages)
+        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)

     ## LOGGING
     logging.pre_call(
@@ -1410,6 +1410,7 @@ def text_completion(*args, **kwargs):
     kwargs["messages"] = messages
     kwargs.pop("prompt")
     response = completion(*args, **kwargs) # assume the response is the openai response object
+    print(f"response: {response}")
    formatted_response_obj = {
        "id": response["id"],
        "object": "text_completion",
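The text_completion hunk only adds a debug print; the surrounding, pre-existing lines show the function converting a raw prompt into a chat-style messages list before delegating to completion(). A minimal sketch of that conversion follows, under the assumption that the prompt is forwarded as a single user message (the exact role used upstream is not visible in this hunk).

# Minimal sketch of the prompt-to-messages conversion that text_completion
# performs before calling completion(). Assumption: the prompt becomes one
# user message; the real role/wrapping is not shown in this diff.
def as_chat_kwargs(**kwargs):
    prompt = kwargs.pop("prompt")                              # raw text prompt
    kwargs["messages"] = [{"role": "user", "content": prompt}]
    return kwargs                                              # ready for completion(**kwargs)

# as_chat_kwargs(model="gpt-3.5-turbo", prompt="hello")
# -> {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hello"}]}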
@@ -61,7 +61,7 @@ def open_config():
 @click.option('--max_tokens', default=None, help='Set max tokens for the model')
 @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
 @click.option('--config', is_flag=True, help='Create and open .env file from .env.template')
-@click.option('--test', default=None, help='proxy chat completions url to make a test request to')
+@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
 @click.option('--local', is_flag=True, default=False, help='for local debugging')
 def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, telemetry, config, test, local):
     if config:
@@ -82,10 +82,14 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, telemetry, config, test, local):

         print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test {url}\"\033[0m")
         return
-    if test != None:
+    if test != False:
         click.echo('LiteLLM: Making a test ChatCompletions request to your proxy')
         import openai
-        openai.api_base = test
+        if test == True: # flag value set
+            api_base = "http://0.0.0.0:8000"
+        else:
+            api_base = test
+        openai.api_base = api_base
         openai.api_key = "temp-key"
         print(openai.api_base)

@@ -107,7 +111,7 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, telemetry, config, test, local):
         except:
             raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
         print(f"\033[32mLiteLLM: Deployed Proxy Locally\033[0m\n")
-        print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test http://0.0.0.0:{port}\" [In a new terminal tab]\033[0m\n")
+        print(f"\033[32mLiteLLM: Test your local endpoint with: \"litellm --test\" [In a new terminal tab]\033[0m\n")
         print(f"\033[32mLiteLLM: Deploy your proxy using the following: \"litellm --model claude-instant-1 --deploy\" Get an https://api.litellm.ai/chat/completions endpoint \033[0m\n")

         uvicorn.run(app, host='0.0.0.0', port=port)
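With --test switched to flag_value=True, the option behaves as a flag in click: omitted it resolves to False, passed it resolves to True, which is why the checks above become "test != False" and "test == True" and the URL defaults to http://0.0.0.0:8000. A standalone sketch of that pattern (illustrative, not the full LiteLLM CLI):

# Standalone click sketch of the flag-style --test option. With flag_value=True
# the option is a flag: omitted -> test is False, passed as "--test" -> test is True.
import click

@click.command()
@click.option('--test', flag_value=True, help='make a test request to the local proxy')
def cli(test):
    if test != False:
        api_base = "http://0.0.0.0:8000"  # local proxy URL taken from the diff above
        click.echo(f"Making a test ChatCompletions request against {api_base}")

if __name__ == "__main__":
    cli()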
@@ -18,10 +18,12 @@ print()

 import litellm
 from fastapi import FastAPI, Request
+from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse
 import json

 app = FastAPI()
+router = APIRouter()

 user_api_base = None
 user_model = None
@@ -109,14 +111,14 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"

 #### API ENDPOINTS ####
-@app.get("/models") # if project requires model list
+@router.get("/models") # if project requires model list
 def model_list():
     return dict(
         data=[{"id": user_model, "object": "model", "created": 1677610602, "owned_by": "openai"}],
         object="list",
     )

-@app.post("/{version}/completions")
+@router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -149,7 +151,7 @@ async def completion(request: Request):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response

-@app.post("/chat/completions")
+@router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -187,3 +189,5 @@ async def chat_completion(request: Request):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     print_verbose(f"response: {response}")
     return response
+
+app.include_router(router)
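The proxy endpoints move from decorators on the FastAPI app onto an APIRouter that is attached at the end with app.include_router(router); the completions route also drops its /{version} path parameter and is pinned to /completions. A minimal sketch of the router pattern with placeholder handlers (the real proxy logic is omitted):

# Minimal FastAPI router sketch with placeholder handlers (not the proxy's
# real request handling). Routes declared on the router only become active
# once it is attached to the app via include_router().
from fastapi import APIRouter, FastAPI, Request

app = FastAPI()
router = APIRouter()

@router.get("/models")
def model_list():
    return {"object": "list", "data": []}

@router.post("/chat/completions")
async def chat_completion(request: Request):
    data = await request.json()
    return {"received": data}  # placeholder echo response

app.include_router(router)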
@@ -709,7 +709,7 @@ def test_completion_sagemaker_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()

 # test on openai completion call
 def test_openai_text_completion_call():
@@ -732,8 +732,16 @@ def test_openai_text_completion_call():
 # # test on ai21 completion call
 def ai21_completion_call():
     try:
+        messages=[{
+            "role": "system",
+            "content": "You are an all-knowing oracle",
+        },
+        {
+            "role": "user",
+            "content": "What is the meaning of the Universe?"
+        }]
         response = completion(
-            model="j2-ultra", messages=messages, stream=True
+            model="j2-ultra", messages=messages, stream=True, max_tokens=500
         )
         print(f"response: {response}")
         has_finished = False
@@ -1262,3 +1270,31 @@ def test_openai_streaming_and_function_calling():
         raise e

 # test_openai_streaming_and_function_calling()
+import litellm
+
+
+def test_success_callback_streaming():
+    def success_callback(kwargs, completion_response, start_time, end_time):
+        print(
+            {
+                "success": True,
+                "input": kwargs,
+                "output": completion_response,
+                "start_time": start_time,
+                "end_time": end_time,
+            }
+        )
+
+
+    litellm.success_callback = [success_callback]
+
+    messages = [{"role": "user", "content": "hello"}]
+
+    response = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+    print(response)
+
+
+    for chunk in response:
+        print(chunk["choices"][0])
+
+test_success_callback_streaming()
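The new test is the commit's main behavioral check: a plain function assigned to litellm.success_callback now fires for streaming completions too. A hedged sketch of the same hook in application code, collecting payloads in memory; the callback signature mirrors the test above, and the latency math assumes start_time and end_time are datetime objects.

import litellm

# Collect success-callback payloads in an in-memory list (illustrative sink).
# Signature mirrors the test above; latency assumes datetime start/end times.
logged_events = []

def track_success(kwargs, completion_response, start_time, end_time):
    logged_events.append({
        "input": kwargs,
        "output": completion_response,
        "latency_s": (end_time - start_time).total_seconds(),
    })

litellm.success_callback = [track_success]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)
for chunk in response:
    pass  # consume the stream so the callback can fire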
@@ -456,6 +456,14 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if callable(callback): # custom logger functions
+                    customLogger.log_event(
+                        kwargs=self.model_call_details,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        print_verbose=print_verbose,
+                    )

             except Exception as e:
                 print_verbose(
@@ -2022,14 +2030,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
                 litellm_call_id=kwargs["litellm_call_id"],
                 print_verbose=print_verbose,
             )
-        elif callable(callback): # custom logger functions
-            customLogger.log_event(
-                kwargs=kwargs,
-                response_obj=result,
-                start_time=start_time,
-                end_time=end_time,
-                print_verbose=print_verbose,
-            )
     except Exception as e:
         # LOGGING
         exception_logging(logger_fn=user_logger_fn, exception=e)
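Net effect of these two hunks: the callable-callback branch is removed from handle_success (the non-streaming path) and added to Logging.success_handler, so custom functions are dispatched through customLogger.log_event on both paths, including streaming. The dispatch itself, sketched standalone (log_event stands in for customLogger.log_event, and callbacks for the litellm.success_callback list):

# Standalone sketch of the callback dispatch that success_handler now performs.
def run_success_callbacks(callbacks, model_call_details, result,
                          start_time, end_time, log_event, print_verbose=print):
    for callback in callbacks:
        try:
            if callable(callback):  # user-supplied function rather than a named integration
                log_event(
                    kwargs=model_call_details,
                    response_obj=result,
                    start_time=start_time,
                    end_time=end_time,
                    print_verbose=print_verbose,
                )
        except Exception as e:
            print_verbose(f"Error in success callback {callback}: {e}")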