adding custom prompt templates to ollama

Krrish Dholakia 2023-10-05 10:47:51 -07:00
parent 966ad27662
commit ed31860206
7 changed files with 164 additions and 85 deletions

View file

@@ -151,6 +151,38 @@ $ litellm --model command-nightly
### Deploy Proxy
<Tabs>
<TabItem value="self-hosted" label="Self-Hosted">
**Step 1: Clone the repo**
```shell
git clone https://github.com/BerriAI/liteLLM-proxy.git
```
**Step 2: Put your API keys in .env**
Copy .env.template to .env and fill in the relevant keys (e.g. OPENAI_API_KEY="sk-..")
**Step 3: Test your proxy**
Start your proxy server
```shell
cd litellm-proxy && python3 main.py
```
Make your first call
```python
import openai
openai.api_key = "sk-litellm-master-key"
openai.api_base = "http://0.0.0.0:8080"
response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey"}])
print(response)
```
</TabItem>
<TabItem value="litellm-hosted" label="LiteLLM-Hosted">
Deploy the proxy to https://api.litellm.ai
```shell
@@ -161,7 +193,6 @@ $ litellm --model claude-instant-1 --deploy
```
This will host a ChatCompletions API at: https://api.litellm.ai/44508ad4
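The hosted deployment can also be called with the OpenAI SDK, mirroring the self-hosted example above. A minimal sketch (not part of this diff): the `/44508ad4` path is the example deployment id from these docs, and the api key handling is an assumption since the hosted endpoint's auth isn't shown in this hunk.
```python
import openai

# Placeholder key: auth requirements for the hosted endpoint aren't shown here.
openai.api_key = "sk-litellm-master-key"
# Point the SDK at the example deployment created above.
openai.api_base = "https://api.litellm.ai/44508ad4"

response = openai.ChatCompletion.create(
    model="claude-instant-1",
    messages=[{"role": "user", "content": "Hey"}],
)
print(response)
```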
#### Other supported models:
<Tabs>
<TabItem value="anthropic" label="Anthropic">
@@ -280,6 +311,8 @@ curl --location 'https://api.litellm.ai/44508ad4/chat/completions' \
```
</TabItem>
</Tabs>
</TabItem>
</Tabs>
## Setting api base, temperature, max tokens

View file

@@ -29,22 +29,24 @@ from litellm.utils import (
get_api_key,
mock_completion_streaming_obj,
)
from .llms import anthropic
from .llms import together_ai
from .llms import ai21
from .llms import sagemaker
from .llms import bedrock
from .llms import huggingface_restapi
from .llms import replicate
from .llms import aleph_alpha
from .llms import nlp_cloud
from .llms import baseten
from .llms import vllm
from .llms import ollama
from .llms import cohere
from .llms import petals
from .llms import oobabooga
from .llms import palm
from .llms import (
anthropic,
together_ai,
ai21,
sagemaker,
bedrock,
huggingface_restapi,
replicate,
aleph_alpha,
nlp_cloud,
baseten,
vllm,
ollama,
cohere,
petals,
oobabooga,
palm)
from .llms.prompt_templates.factory import prompt_factory, custom_prompt
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@@ -1040,13 +1042,25 @@ def completion(
response = model_response
elif custom_llm_provider == "ollama":
endpoint = (
litellm.api_base if litellm.api_base is not None else api_base
litellm.api_base
or api_base
or "http://localhost:11434"
)
prompt = " ".join([message["content"] for message in messages])
if model in litellm.custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = litellm.custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages
)
else:
prompt = prompt_factory(model=model, messages=messages)
## LOGGING
logging.pre_call(
input=prompt, api_key=None, additional_args={"endpoint": endpoint}
input=prompt, api_key=None, additional_args={"endpoint": endpoint, "custom_prompt_dict": litellm.custom_prompt_dict}
)
if kwargs.get('acompletion', False) == True:
async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
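Taken together with `register_prompt_template` (see the docstring change further down), this new branch lets Ollama calls use a user-registered template instead of `prompt_factory`. A minimal end-to-end sketch, based on the test and docstring in this commit and assuming a local Ollama server with `llama2` pulled:
```python
import litellm
from litellm import completion

# The test in this commit registers the bare model name ("llama2") and then
# calls "ollama/llama2"; completion() looks the bare name up in custom_prompt_dict.
litellm.register_prompt_template(
    model="llama2",
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
        "assistant": {"post_message": "\n"},
    },
)

response = completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "what is litellm?"}],
    api_base="http://localhost:11434",  # optional: this is also the new default endpoint
    stream=True,
)
for chunk in response:
    print(chunk)
```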

View file

@@ -1,5 +1,5 @@
# ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# # https://ollama.ai/
# # ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# # # https://ollama.ai/
# import sys, os
# import traceback
@@ -16,27 +16,55 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]
# def test_completion_ollama():
# try:
# response = completion(
# model="ollama/llama2",
# messages=messages,
# api_base="http://localhost:11434"
# )
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # def test_completion_ollama():
# # try:
# # response = completion(
# # model="ollama/llama2",
# # messages=messages,
# # api_base="http://localhost:11434"
# # )
# # print(response)
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama()
# def test_completion_ollama_stream():
# # def test_completion_ollama_stream():
# # user_message = "what is litellm?"
# # messages = [{ "content": user_message,"role": "user"}]
# # try:
# # response = completion(
# # model="ollama/llama2",
# # messages=messages,
# # stream=True
# # )
# # print(response)
# # for chunk in response:
# # print(chunk)
# # # print(chunk['choices'][0]['delta'])
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama_stream()
# def test_completion_ollama_custom_prompt_template():
# user_message = "what is litellm?"
# litellm.register_prompt_template(
# model="llama2",
# roles={
# "system": {"pre_message": "System: "},
# "user": {"pre_message": "User: "},
# "assistant": {"pre_message": "Assistant: "}
# }
# )
# messages = [{ "content": user_message,"role": "user"}]
# litellm.set_verbose = True
# try:
# response = completion(
# model="ollama/llama2",
# messages=messages,
# api_base="http://localhost:11434",
# stream=True
# )
# print(response)
@@ -45,54 +73,54 @@
# # print(chunk['choices'][0]['delta'])
# except Exception as e:
# traceback.print_exc()
# pytest.fail(f"Error occurred: {e}")
# test_completion_ollama_stream()
# test_completion_ollama_custom_prompt_template()
# # async def test_completion_ollama_async_stream():
# # user_message = "what is the weather"
# # messages = [{ "content": user_message,"role": "user"}]
# # try:
# # response = await litellm.acompletion(
# # model="ollama/llama2",
# # messages=messages,
# # api_base="http://localhost:11434",
# # stream=True
# # )
# # async for chunk in response:
# # print(chunk)
# # # print(chunk['choices'][0]['delta'])
# # except Exception as e:
# # pytest.fail(f"Error occurred: {e}")
# # # import asyncio
# # # asyncio.run(test_completion_ollama_async_stream())
# # def prepare_messages_for_chat(text: str) -> list:
# # messages = [
# # {"role": "user", "content": text},
# # ]
# # return messages
# async def test_completion_ollama_async_stream():
# user_message = "what is the weather"
# messages = [{ "content": user_message,"role": "user"}]
# try:
# response = await litellm.acompletion(
# model="ollama/llama2",
# messages=messages,
# api_base="http://localhost:11434",
# stream=True
# )
# async for chunk in response:
# print(chunk)
# # async def ask_question():
# # params = {
# # "messages": prepare_messages_for_chat("What is litellm? tell me 10 things about it who is sihaan.write an essay"),
# # "api_base": "http://localhost:11434",
# # "model": "ollama/llama2",
# # "stream": True,
# # }
# # response = await litellm.acompletion(**params)
# # return response
# # print(chunk['choices'][0]['delta'])
# # async def main():
# # response = await ask_question()
# # async for chunk in response:
# # print(chunk)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # import asyncio
# # asyncio.run(test_completion_ollama_async_stream())
# def prepare_messages_for_chat(text: str) -> list:
# messages = [
# {"role": "user", "content": text},
# ]
# return messages
# async def ask_question():
# params = {
# "messages": prepare_messages_for_chat("What is litellm? tell me 10 things about it who is sihaan.write an essay"),
# "api_base": "http://localhost:11434",
# "model": "ollama/llama2",
# "stream": True,
# }
# response = await litellm.acompletion(**params)
# return response
# async def main():
# response = await ask_question()
# async for chunk in response:
# print(chunk)
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())
# # if __name__ == "__main__":
# # import asyncio
# # asyncio.run(main())
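The commented-out async test above suggests the equivalent standalone usage. A rough sketch, assuming a local Ollama server at http://localhost:11434 and that `acompletion(..., stream=True)` yields an async generator as in that test:
```python
import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="ollama/llama2",
        messages=[{"role": "user", "content": "what is litellm?"}],
        api_base="http://localhost:11434",
        stream=True,
    )
    # Mirrors the test: iterate the streamed chunks as they arrive.
    async for chunk in response:
        print(chunk)

if __name__ == "__main__":
    asyncio.run(main())
```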

View file

@@ -2081,24 +2081,28 @@ def modify_integration(integration_name, integration_params):
# custom prompt helper function
def register_prompt_template(model: str, roles: dict, initial_prompt_value: str = "", final_prompt_value: str = ""):
"""
Format the openai prompt to follow your custom format.
Example usage:
```
import litellm
litellm.register_prompt_template(
model="llama-2",
initial_prompt_value="You are a good assistant" # [OPTIONAL]
roles={
"system": {
"pre_message": "[INST] <<SYS>>\n",
"post_message": "\n<</SYS>>\n [/INST]\n"
"pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
"post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
},
"user": { # follow this format https://github.com/facebookresearch/llama/blob/77062717054710e352a99add63d160274ce670c6/llama/generation.py#L348
"pre_message": "[INST] ",
"post_message": " [/INST]\n"
"user": {
"pre_message": "[INST] ", # [OPTIONAL]
"post_message": " [/INST]" # [OPTIONAL]
},
"assistant": {
"post_message": "\n" # follows this - https://replicate.com/blog/how-to-prompt-llama
"pre_message": "\n" # [OPTIONAL]
"post_message": "\n" # [OPTIONAL]
}
}
final_prompt_value="Now answer as best you can:" # [OPTIONAL]
)
```
"""