Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 02:34:29 +00:00)
fix(add-custom-success-callback-for-streaming): add custom success callback for streaming

parent 868c1c594f
commit 7e34736a38

8 changed files with 89 additions and 20 deletions
2 binary files changed (not shown)
@@ -1,6 +1,7 @@
 import requests, traceback
 import json
 from jinja2 import Template, exceptions, Environment, meta
+from typing import Optional
 
 def default_pt(messages):
     return " ".join(message["content"] for message in messages)
@@ -26,6 +27,25 @@ def llama_2_chat_pt(messages):
     )
     return prompt
 
+def ollama_pt(messages): # https://github.com/jmorganca/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
+    prompt = custom_prompt(
+        role_dict={
+            "system": {
+                "pre_message": "### System:\n",
+                "post_message": "\n"
+            },
+            "user": {
+                "pre_message": "### User:\n",
+                "post_message": "\n",
+            },
+            "assistant": {
+                "pre_message": "### Response:\n",
+                "post_message": "\n",
+            }
+        },
+        final_prompt_value="### Response:"
+    )
+
 def mistral_instruct_pt(messages):
     prompt = custom_prompt(
         initial_prompt_value="<s>",
@@ -190,9 +210,13 @@ def custom_prompt(role_dict: dict, messages: list, initial_prompt_value: str="",
     prompt += final_prompt_value
     return prompt
 
-def prompt_factory(model: str, messages: list):
+def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str]=None):
     original_model_name = model
     model = model.lower()
+
+    if custom_llm_provider == "ollama":
+        return ollama_pt(messages=messages)
+
     try:
         if "meta-llama/llama-2" in model:
             if "chat" in model:
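Note: the hunk above routes Ollama models through the new "### System / ### User / ### Response" template before any model-name matching. For orientation, here is a minimal standalone sketch of the layout that role_dict describes. It deliberately does not import litellm's custom_prompt (whose internals are not shown in this diff); it only assumes that pre_message and post_message wrap each message's content and that final_prompt_value is appended at the end.

def render_ollama_style_prompt(messages):
    # Hypothetical re-implementation for illustration only.
    role_dict = {
        "system": {"pre_message": "### System:\n", "post_message": "\n"},
        "user": {"pre_message": "### User:\n", "post_message": "\n"},
        "assistant": {"pre_message": "### Response:\n", "post_message": "\n"},
    }
    prompt = ""
    for message in messages:
        fmt = role_dict.get(message["role"], {"pre_message": "", "post_message": "\n"})
        prompt += fmt["pre_message"] + message["content"] + fmt["post_message"]
    return prompt + "### Response:"  # final_prompt_value cues the model to answer

print(render_ollama_style_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]))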
@@ -961,7 +961,7 @@ def completion(
                 messages=messages
             )
         else:
-            prompt = prompt_factory(model=model, messages=messages)
+            prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
 
         ## LOGGING
         logging.pre_call(
@@ -1410,6 +1410,7 @@ def text_completion(*args, **kwargs):
         kwargs["messages"] = messages
         kwargs.pop("prompt")
         response = completion(*args, **kwargs) # assume the response is the openai response object
+        print(f"response: {response}")
         formatted_response_obj = {
             "id": response["id"],
             "object": "text_completion",
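Note: the text_completion hunk converts the legacy prompt argument into a chat messages list, calls completion, and re-labels the result as a "text_completion" object. A rough usage sketch, assuming text_completion is exported at the package level and an OPENAI_API_KEY is set in the environment (model name and prompt are illustrative):

import litellm

# Assumes OPENAI_API_KEY is exported before running.
response = litellm.text_completion(
    model="gpt-3.5-turbo",
    prompt="Say hello in one word",
)
# "id" and "object" are the fields assembled in formatted_response_obj above.
print(response["id"], response["object"])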
@@ -61,7 +61,7 @@ def open_config():
 @click.option('--max_tokens', default=None, help='Set max tokens for the model')
 @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
 @click.option('--config', is_flag=True, help='Create and open .env file from .env.template')
-@click.option('--test', default=None, help='proxy chat completions url to make a test request to')
+@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
 @click.option('--local', is_flag=True, default=False, help='for local debugging')
 def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, telemetry, config, test, local):
     if config:
@@ -82,10 +82,14 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, te
 
         print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test {url}\"\033[0m")
         return
-    if test != None:
+    if test != False:
         click.echo('LiteLLM: Making a test ChatCompletions request to your proxy')
         import openai
-        openai.api_base = test
+        if test == True: # flag value set
+            api_base = "http://0.0.0.0:8000"
+        else:
+            api_base = test
+        openai.api_base = api_base
         openai.api_key = "temp-key"
         print(openai.api_base)
 
@@ -107,7 +111,7 @@ def run_server(port, api_base, model, deploy, debug, temperature, max_tokens, te
     except:
         raise ImportError("Uvicorn needs to be imported. Run - `pip install uvicorn`")
     print(f"\033[32mLiteLLM: Deployed Proxy Locally\033[0m\n")
-    print(f"\033[32mLiteLLM: Test your URL using the following: \"litellm --test http://0.0.0.0:{port}\" [In a new terminal tab]\033[0m\n")
+    print(f"\033[32mLiteLLM: Test your local endpoint with: \"litellm --test\" [In a new terminal tab]\033[0m\n")
     print(f"\033[32mLiteLLM: Deploy your proxy using the following: \"litellm --model claude-instant-1 --deploy\" Get an https://api.litellm.ai/chat/completions endpoint \033[0m\n")
 
     uvicorn.run(app, host='0.0.0.0', port=port)
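Note: with flag_value=True, a bare `litellm --test` now targets the locally running proxy at http://0.0.0.0:8000 with a throwaway key. The same smoke test can be done by hand; the sketch below assumes the proxy is already running and uses the pre-1.0 openai Python SDK (the one that exposes openai.api_base, as the CLI code above does):

import openai

# Same settings `litellm --test` applies when no URL is given (see the diff above).
openai.api_base = "http://0.0.0.0:8000"
openai.api_key = "temp-key"  # placeholder key, mirroring the CLI test

# Model name is illustrative; the proxy forwards to whatever model it was started with.
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(response)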
@@ -18,10 +18,12 @@ print()
 
 import litellm
 from fastapi import FastAPI, Request
+from fastapi.routing import APIRouter
 from fastapi.responses import StreamingResponse
 import json
 
 app = FastAPI()
+router = APIRouter()
 
 user_api_base = None
 user_model = None
@@ -109,14 +111,14 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"
 
 #### API ENDPOINTS ####
-@app.get("/models") # if project requires model list
+@router.get("/models") # if project requires model list
 def model_list():
     return dict(
         data=[{"id": user_model, "object": "model", "created": 1677610602, "owned_by": "openai"}],
         object="list",
     )
 
-@app.post("/{version}/completions")
+@router.post("/completions")
 async def completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -149,7 +151,7 @@ async def completion(request: Request):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
 
-@app.post("/chat/completions")
+@router.post("/chat/completions")
 async def chat_completion(request: Request):
     data = await request.json()
     print_verbose(f"data passed in: {data}")
@@ -186,4 +188,6 @@ async def chat_completion(request: Request):
     if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
+    print_verbose(f"response: {response}")
+    return response
 
+app.include_router(router)
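Note: when the request sets stream, the /chat/completions route returns a StreamingResponse whose body is the `data: {...}` lines produced by data_generator. A rough client-side sketch of reading that stream with requests, assuming the proxy is reachable at http://0.0.0.0:8000 (the default used by `litellm --test`) and that each data: line carries one JSON chunk, as data_generator emits above:

import json
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # illustrative; the proxy routes to its configured model
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    # data_generator prefixes each JSON chunk with "data: " followed by a blank separator line.
    if line and line.startswith(b"data: "):
        chunk = json.loads(line[len(b"data: "):])
        print(chunk["choices"][0])  # same field the new streaming test prints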
@@ -709,7 +709,7 @@ def test_completion_sagemaker_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_sagemaker_stream()
+# test_completion_sagemaker_stream()
 
 # test on openai completion call
 def test_openai_text_completion_call():
@@ -732,8 +732,16 @@ def test_openai_text_completion_call():
 # # test on ai21 completion call
 def ai21_completion_call():
     try:
+        messages=[{
+            "role": "system",
+            "content": "You are an all-knowing oracle",
+        },
+        {
+            "role": "user",
+            "content": "What is the meaning of the Universe?"
+        }]
         response = completion(
-            model="j2-ultra", messages=messages, stream=True
+            model="j2-ultra", messages=messages, stream=True, max_tokens=500
         )
         print(f"response: {response}")
         has_finished = False
@@ -1262,3 +1270,31 @@ def test_openai_streaming_and_function_calling():
         raise e
 
 # test_openai_streaming_and_function_calling()
+import litellm
+
+
+def test_success_callback_streaming():
+    def success_callback(kwargs, completion_response, start_time, end_time):
+        print(
+            {
+                "success": True,
+                "input": kwargs,
+                "output": completion_response,
+                "start_time": start_time,
+                "end_time": end_time,
+            }
+        )
+
+
+    litellm.success_callback = [success_callback]
+
+    messages = [{"role": "user", "content": "hello"}]
+
+    response = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)
+    print(response)
+
+
+    for chunk in response:
+        print(chunk["choices"][0])
+
+test_success_callback_streaming()
@@ -456,6 +456,14 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
+                if callable(callback): # custom logger functions
+                    customLogger.log_event(
+                        kwargs=self.model_call_details,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        print_verbose=print_verbose,
+                    )
 
             except Exception as e:
                 print_verbose(
@@ -2022,14 +2030,6 @@ def handle_success(args, kwargs, result, start_time, end_time):
                 litellm_call_id=kwargs["litellm_call_id"],
                 print_verbose=print_verbose,
             )
-            elif callable(callback): # custom logger functions
-                customLogger.log_event(
-                    kwargs=kwargs,
-                    response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    print_verbose=print_verbose,
-                )
     except Exception as e:
         # LOGGING
         exception_logging(logger_fn=user_logger_fn, exception=e)
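Note: taken together, these two hunks move the custom-logger branch out of the legacy handle_success path and into the Logging class, so a plain Python function registered in litellm.success_callback is also invoked on streaming success with (kwargs, completion_response, start_time, end_time), exactly as the new test above exercises. A minimal registration sketch; the "model" key read from kwargs is an assumption about what model_call_details carries:

import litellm

def log_success(kwargs, completion_response, start_time, end_time):
    # Custom logger function: receives call details, the response object, and timing.
    print("model:", kwargs.get("model"), "took:", end_time - start_time)

# Callable entries are detected via callable(callback) in the loop shown above;
# string-named integrations can sit in the same list.
litellm.success_callback = [log_success]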