feat(factory.py): option to add function details to prompt, if model doesn't support functions param

This commit is contained in:
Krrish Dholakia 2023-10-09 09:53:31 -07:00
parent f6f7c0b891
commit 704be9dcd1
8 changed files with 130 additions and 27 deletions

View file

@ -28,7 +28,7 @@ from litellm.utils import (
completion_with_fallbacks,
get_llm_provider,
get_api_key,
mock_completion_streaming_obj,
mock_completion_streaming_obj
)
from .llms import (
anthropic,
@ -48,7 +48,7 @@ from .llms import (
oobabooga,
palm,
vertex_ai)
from .llms.prompt_templates.factory import prompt_factory, custom_prompt
from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@ -259,27 +259,33 @@ def completion(
api_base = "https://proxy.litellm.ai"
custom_llm_provider = "openai"
api_key = model_api_key
# check if user passed in any of the OpenAI optional params
optional_params = get_optional_params(
functions=functions,
function_call=function_call,
temperature=temperature,
top_p=top_p,
n=n,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
user=user,
request_timeout=request_timeout,
deployment_id=deployment_id,
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
**non_default_params
)
functions=functions,
function_call=function_call,
temperature=temperature,
top_p=top_p,
n=n,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
user=user,
request_timeout=request_timeout,
deployment_id=deployment_id,
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
**non_default_params
)
if litellm.add_function_to_prompt and optional_params.get("functions_unsupported_model", None): # if user opts to add it to prompt, when API doesn't support function calling
functions_unsupported_model = optional_params.pop("functions_unsupported_model")
messages = function_call_prompt(messages=messages, functions=functions_unsupported_model)
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
return_async=return_async,