fix(proxy_server): fix prompt template for proxy server
Parent: 413097f0a7
Commit: 4b0f8825f3
4 changed files with 24 additions and 22 deletions
2 binary files changed (contents not shown).
@@ -245,9 +245,8 @@ def deploy_proxy(model, api_base, debug, temperature, max_tokens, telemetry, dep
     # for streaming
     def data_generator(response):
-        print("inside generator")
+        print_verbose("inside generator")
         for chunk in response:
-            print(f"chunk: {chunk}")
+            print_verbose(f"returned chunk: {chunk}")
             yield f"data: {json.dumps(chunk)}\n\n"
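For context, a minimal sketch of the pattern this hunk touches: an SSE-style generator that yields `data: ...` frames while routing diagnostics through a debug-gated `print_verbose` helper instead of bare `print`. The `user_debug` flag name below is an illustrative assumption, not necessarily the proxy's actual variable.

```python
import json

user_debug = False  # assumed debug flag; the real proxy wires this to its debug option


def print_verbose(print_statement):
    # only emit diagnostics when debugging is enabled, keeping normal runs quiet
    if user_debug:
        print(print_statement)


def data_generator(response):
    print_verbose("inside generator")
    for chunk in response:
        print_verbose(f"returned chunk: {chunk}")
        # server-sent events frame: "data: <json>" terminated by a blank line
        yield f"data: {json.dumps(chunk)}\n\n"


# usage sketch: stream a fake chunked response and collect the SSE frames
if __name__ == "__main__":
    fake_response = [{"choices": [{"delta": {"content": tok}}]} for tok in ("Hi", " there")]
    for frame in data_generator(fake_response):
        print(frame, end="")
```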
@@ -302,26 +301,6 @@ def litellm_completion(data, type):
         data["max_tokens"] = user_max_tokens
     if user_api_base:
         data["api_base"] = user_api_base
-    ## CUSTOM PROMPT TEMPLATE ## - run `litellm --config` to set this
-    litellm.register_prompt_template(
-        model=user_model,
-        roles={
-            "system": {
-                "pre_message": os.getenv("MODEL_SYSTEM_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_SYSTEM_MESSAGE_END_TOKEN", ""),
-            },
-            "assistant": {
-                "pre_message": os.getenv("MODEL_ASSISTANT_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_ASSISTANT_MESSAGE_END_TOKEN", "")
-            },
-            "user": {
-                "pre_message": os.getenv("MODEL_USER_MESSAGE_START_TOKEN", ""),
-                "post_message": os.getenv("MODEL_USER_MESSAGE_END_TOKEN", "")
-            }
-        },
-        initial_prompt_value=os.getenv("MODEL_PRE_PROMPT", ""),
-        final_prompt_value=os.getenv("MODEL_POST_PROMPT", "")
-    )
     if type == "completion":
         response = litellm.text_completion(**data)
     elif type == "chat_completion":
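The removed block registered a custom prompt template from environment variables on every request. The same `litellm.register_prompt_template` API (its arguments are visible in the removed lines above) can instead be called once, e.g. at startup; a minimal sketch below, where the model name and wrapper tokens are illustrative placeholders.

```python
import litellm

# register the template once (e.g. at proxy startup) rather than per request;
# "my-org/my-custom-model" and the wrapper tokens are placeholder assumptions
litellm.register_prompt_template(
    model="my-org/my-custom-model",
    roles={
        "system": {"pre_message": "<<SYS>>\n", "post_message": "\n<</SYS>>\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
        "assistant": {"pre_message": "", "post_message": ""},
    },
    initial_prompt_value="",
    final_prompt_value="",
)
```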
litellm/tests/test_prompt_factory.py (new file, 23 lines)

@@ -0,0 +1,23 @@
+#### What this tests ####
+# This tests if prompts are being correctly formatted
+import sys
+import os
+import io
+
+sys.path.insert(0, os.path.abspath('../..'))
+
+# from litellm.llms.prompt_templates.factory import prompt_factory
+from litellm import completion
+
+def codellama_prompt_format():
+    model = "huggingface/codellama/CodeLlama-7b-Instruct-hf"
+    messages = [{"role": "system", "content": "You are a good bot"}, {"role": "user", "content": "Hey, how's it going?"}]
+    expected_response = """[INST] <<SYS>>
+You are a good bot
+<</SYS>>
+[/INST]
+[INST] Hey, how's it going? [/INST]"""
+    response = completion(model=model, messages=messages)
+    print(response)
+
+# codellama_prompt_format()
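The new test defines `expected_response` but never asserts against it; it is a smoke test that prints the completion. The commented-out import hints at calling the prompt factory directly. A sketch of turning it into an assertion, assuming `prompt_factory(model=..., messages=...)` returns the rendered prompt string; that signature, and whether the `huggingface/` provider prefix must be stripped first, are assumptions not verified against this commit.

```python
# minimal sketch, assuming the factory renders messages into a single prompt string
from litellm.llms.prompt_templates.factory import prompt_factory


def test_codellama_prompt_format():
    # provider prefix dropped; whether the factory needs it stripped is an assumption
    model = "codellama/CodeLlama-7b-Instruct-hf"
    messages = [
        {"role": "system", "content": "You are a good bot"},
        {"role": "user", "content": "Hey, how's it going?"},
    ]
    prompt = prompt_factory(model=model, messages=messages)
    # check the Llama-2-style wrappers rather than the full string, to stay robust to whitespace
    assert "<<SYS>>" in prompt and "<</SYS>>" in prompt
    assert "[INST]" in prompt and "[/INST]" in prompt
```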