adding custom prompt templates to ollama

This commit is contained in:
Krrish Dholakia 2023-10-05 10:47:51 -07:00
parent 966ad27662
commit ed31860206
7 changed files with 164 additions and 85 deletions


@@ -151,6 +151,38 @@ $ litellm --model command-nightly
### Deploy Proxy
<Tabs>
<TabItem value="self-hosted" label="Self-Hosted">
**Step 1: Clone the repo**
```shell
git clone https://github.com/BerriAI/liteLLM-proxy.git
```
**Step 2: Put your API keys in .env**
Copy the .env.template to .env and fill in the relevant keys (e.g. OPENAI_API_KEY="sk-..")
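For example (a minimal sketch, assuming you run this from the root of the cloned repo and that the template file is named .env.template):
```shell
cp .env.template .env
# then open .env and fill in the keys you need, e.g. OPENAI_API_KEY
```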
**Step 3: Test your proxy**
Start your proxy server
```shell
cd liteLLM-proxy && python3 main.py
```
Make your first call
```python
import openai
openai.api_key = "sk-litellm-master-key"
openai.api_base = "http://0.0.0.0:8080"
response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey"}])
print(response)
```
</TabItem>
<TabItem value="litellm-hosted" label="LiteLLM-Hosted">
Deploy the proxy to https://api.litellm.ai
```shell
@@ -161,7 +193,6 @@ $ litellm --model claude-instant-1 --deploy
```
This will host a ChatCompletions API at: https://api.litellm.ai/44508ad4
#### Other supported models:
<Tabs>
<TabItem value="anthropic" label="Anthropic">
@@ -280,6 +311,8 @@ curl --location 'https://api.litellm.ai/44508ad4/chat/completions' \
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Setting api base, temperature, max tokens


@@ -29,22 +29,24 @@ from litellm.utils import (
get_api_key,
mock_completion_streaming_obj,
)
from .llms import (
anthropic,
together_ai,
ai21,
sagemaker,
bedrock,
huggingface_restapi,
replicate,
aleph_alpha,
nlp_cloud,
baseten,
vllm,
ollama,
cohere,
petals,
oobabooga,
palm)
from .llms.prompt_templates.factory import prompt_factory, custom_prompt
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@@ -1040,13 +1042,25 @@ def completion(
response = model_response
elif custom_llm_provider == "ollama":
endpoint = (
litellm.api_base
or api_base
or "http://localhost:11434"
)
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
)
else:
prompt = prompt_factory(model=model, messages=messages)
## LOGGING
logging.pre_call(
input=prompt, api_key=None, additional_args={"endpoint": endpoint, "custom_prompt_dict": litellm.custom_prompt_dict}
)
if kwargs.get('acompletion', False) == True:
async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
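For context, the custom_prompt call above comes from litellm/llms/prompt_templates/factory.py (imported earlier in this diff). The helper below is only an illustrative sketch, assuming the factory wraps each message in its role's pre/post markers and adds the registered initial and final prompt values; the real implementation may differ.
```python
# Illustrative sketch (assumption, not the actual factory code) of how a
# registered template could be turned into a single prompt string for Ollama.
def sketch_custom_prompt(role_dict: dict, messages: list,
                         initial_prompt_value: str = "",
                         final_prompt_value: str = "") -> str:
    prompt = initial_prompt_value
    for message in messages:
        markers = role_dict.get(message["role"], {})
        prompt += markers.get("pre_message", "") + message["content"] + markers.get("post_message", "")
    return prompt + final_prompt_value
```
With llama-2 style markers like the ones in the register_prompt_template docstring further down, a single user message "Hey" would come out roughly as "[INST] Hey [/INST]".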


@@ -1,5 +1,5 @@
# # ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# # # https://ollama.ai/
# import sys, os
# import traceback
@@ -16,27 +16,55 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]
# # def test_completion_ollama():
# # try:
# # response = completion(
# # model="ollama/llama2",
# # messages=messages,
# # api_base="http://localhost:11434"
# # )
# # print(response)
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama()
# # def test_completion_ollama_stream():
# # user_message = "what is litellm?"
# # messages = [{ "content": user_message,"role": "user"}]
# # try:
# # response = completion(
# # model="ollama/llama2",
# # messages=messages,
# # stream=True
# # )
# # print(response)
# # for chunk in response:
# # print(chunk)
# # # print(chunk['choices'][0]['delta'])
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama_stream()
# def test_completion_ollama_custom_prompt_template():
# user_message = "what is litellm?"
# litellm.register_prompt_template(
# model="llama2",
# roles={
# "system": {"pre_message": "System: "},
# "user": {"pre_message": "User: "},
# "assistant": {"pre_message": "Assistant: "}
# }
# )
# messages = [{ "content": user_message,"role": "user"}]
# litellm.set_verbose = True
# try:
# response = completion(
# model="ollama/llama2",
# messages=messages,
# api_base="http://localhost:11434",
# stream=True
# )
# print(response)
@@ -45,54 +73,54 @@
# # print(chunk['choices'][0]['delta'])
# except Exception as e:
# traceback.print_exc()
# pytest.fail(f"Error occurred: {e}")
# test_completion_ollama_custom_prompt_template()
# # async def test_completion_ollama_async_stream():
# # user_message = "what is the weather"
# # messages = [{ "content": user_message,"role": "user"}]
# # try:
# # response = await litellm.acompletion(
# # model="ollama/llama2",
# # messages=messages,
# # api_base="http://localhost:11434",
# # stream=True
# # )
# # async for chunk in response:
# # print(chunk)
# # # print(chunk['choices'][0]['delta'])
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # # import asyncio
# # # asyncio.run(test_completion_ollama_async_stream())
# # def prepare_messages_for_chat(text: str) -> list:
# # messages = [
# # {"role": "user", "content": text},
# # ]
# # return messages
# # async def ask_question():
# # params = {
# # "messages": prepare_messages_for_chat("What is litellm? tell me 10 things about it who is sihaan.write an essay"),
# # "api_base": "http://localhost:11434",
# # "model": "ollama/llama2",
# # "stream": True,
# # }
# # response = await litellm.acompletion(**params)
# # return response
# # async def main():
# # response = await ask_question()
# # async for chunk in response:
# # print(chunk)
# # if __name__ == "__main__":
# # import asyncio
# # asyncio.run(main())


@@ -2081,24 +2081,28 @@ def modify_integration(integration_name, integration_params):
# custom prompt helper function
def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""):
"""
Format the OpenAI prompt to follow your custom format.
Example usage:
```
import litellm
litellm.register_prompt_template(
model="llama-2",
initial_prompt_value="You are a good assistant", # [OPTIONAL]
roles={
"system": {
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
},
"user": {
"pre_message": "[INST] ", # [OPTIONAL]
"post_message": " [/INST]" # [OPTIONAL]
},
"assistant": {
"pre_message": "\n", # [OPTIONAL]
"post_message": "\n" # [OPTIONAL]
}
},
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
)
```
"""