adding custom prompt templates to ollama

Krrish Dholakia 2023-10-05 10:47:51 -07:00
parent 966ad27662
commit ed31860206
7 changed files with 164 additions and 85 deletions


@@ -29,22 +29,24 @@ from litellm.utils import (
     get_api_key,
     mock_completion_streaming_obj,
 )
-from .llms import anthropic
-from .llms import together_ai
-from .llms import ai21
-from .llms import sagemaker
-from .llms import bedrock
-from .llms import huggingface_restapi
-from .llms import replicate
-from .llms import aleph_alpha
-from .llms import nlp_cloud
-from .llms import baseten
-from .llms import vllm
-from .llms import ollama
-from .llms import cohere
-from .llms import petals
-from .llms import oobabooga
-from .llms import palm
+from .llms import (
+    anthropic,
+    together_ai,
+    ai21,
+    sagemaker,
+    bedrock,
+    huggingface_restapi,
+    replicate,
+    aleph_alpha,
+    nlp_cloud,
+    baseten,
+    vllm,
+    ollama,
+    cohere,
+    petals,
+    oobabooga,
+    palm)
+from .llms.prompt_templates.factory import prompt_factory, custom_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
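
The new import pulls prompt_factory and custom_prompt out of .llms.prompt_templates.factory. That module is not part of this commit view, so the following is only a rough sketch of what custom_prompt plausibly does, inferred from the keyword arguments it receives in the next hunk; the pre_message/post_message shape of role_dict is an assumption, not something shown in the diff.

# Rough sketch only -- inferred from the call site below, not copied from litellm.
def custom_prompt(role_dict, initial_prompt_value, final_prompt_value, messages):
    # Assumed shape: role_dict maps a role name to text wrapped around that
    # role's content, e.g. {"user": {"pre_message": "...", "post_message": "..."}}.
    prompt = initial_prompt_value
    for message in messages:
        wrapper = role_dict.get(message["role"], {})
        prompt += wrapper.get("pre_message", "")
        prompt += message["content"]
        prompt += wrapper.get("post_message", "")
    return prompt + final_prompt_value

prompt_factory remains the fallback for models with no registered template, as the else branch in the next hunk shows.
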
@@ -1040,13 +1042,25 @@ def completion(
             response = model_response
         elif custom_llm_provider == "ollama":
             endpoint = (
-                litellm.api_base if litellm.api_base is not None else api_base
+                litellm.api_base
+                or api_base
+                or "http://localhost:11434"
             )
-            prompt = " ".join([message["content"] for message in messages])
+            if model in litellm.custom_prompt_dict:
+                # check if the model has a registered custom prompt
+                model_prompt_details = litellm.custom_prompt_dict[model]
+                prompt = custom_prompt(
+                    role_dict=model_prompt_details["roles"],
+                    initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                    final_prompt_value=model_prompt_details["final_prompt_value"],
+                    messages=messages
+                )
+            else:
+                prompt = prompt_factory(model=model, messages=messages)

             ## LOGGING
             logging.pre_call(
-                input=prompt, api_key=None, additional_args={"endpoint": endpoint}
+                input=prompt, api_key=None, additional_args={"endpoint": endpoint, "custom_prompt_dict": litellm.custom_prompt_dict}
             )
             if kwargs.get('acompletion', False) == True:
                 async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
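
Taken together, the two hunks let a caller attach a per-model prompt template before the ollama path builds its prompt. A minimal usage sketch, not taken from the repo: the model name, template text, and dict key are illustrative, and whether the lookup key should carry the "ollama/" prefix depends on how model has been normalized by the time this branch runs.

import litellm

# Register a template for one ollama model. The top-level keys mirror exactly
# what the new code reads: "roles", "initial_prompt_value", "final_prompt_value".
# The "roles" sub-dict shape follows the custom_prompt sketch above, which is
# itself an assumption; the key "llama2" assumes the provider prefix has been
# stripped by the time the lookup runs.
litellm.custom_prompt_dict["llama2"] = {
    "roles": {
        "system": {"pre_message": "<<SYS>>\n", "post_message": "\n<</SYS>>\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
        "assistant": {"pre_message": "", "post_message": "\n"},
    },
    "initial_prompt_value": "",
    "final_prompt_value": "",
}

# The "ollama/" prefix routes the call to the branch patched above; api_base
# can be omitted now that the endpoint falls back to http://localhost:11434.
response = litellm.completion(
    model="ollama/llama2",
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Why is the sky blue?"},
    ],
)
print(response)
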