Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
feat(factory.py): option to add function details to prompt, if model doesn't support functions param
This commit is contained in:
parent f6f7c0b891
commit 704be9dcd1
8 changed files with 130 additions and 27 deletions
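In practice the new option is driven by a module-level flag; a minimal usage sketch, assuming litellm's standard completion() call shape (the model name below is only illustrative of a provider whose API has no native functions parameter):

import litellm
from litellm import completion

# Opt in: when the target API cannot accept a `functions` parameter,
# litellm folds the function definitions into the prompt text instead.
litellm.add_function_to_prompt = True

functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    }
]

response = completion(
    model="ollama/llama2",  # illustrative model without native function calling
    messages=[{"role": "user", "content": "What's the weather in Boston today?"}],
    functions=functions,
)
print(response)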
@@ -28,7 +28,7 @@ from litellm.utils import (
     completion_with_fallbacks,
     get_llm_provider,
     get_api_key,
-    mock_completion_streaming_obj,
+    mock_completion_streaming_obj
 )
 from .llms import (
     anthropic,
@@ -48,7 +48,7 @@ from .llms import (
     oobabooga,
     palm,
     vertex_ai)
-from .llms.prompt_templates.factory import prompt_factory, custom_prompt
+from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict
@@ -259,27 +259,33 @@ def completion(
             api_base = "https://proxy.litellm.ai"
             custom_llm_provider = "openai"
             api_key = model_api_key
 
         # check if user passed in any of the OpenAI optional params
         optional_params = get_optional_params(
             functions=functions,
             function_call=function_call,
             temperature=temperature,
             top_p=top_p,
             n=n,
             stream=stream,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
             logit_bias=logit_bias,
             user=user,
             request_timeout=request_timeout,
             deployment_id=deployment_id,
             # params to identify the model
             model=model,
             custom_llm_provider=custom_llm_provider,
             **non_default_params
         )
+
+        if litellm.add_function_to_prompt and optional_params.get("functions_unsupported_model", None): # if user opts to add it to prompt, when API doesn't support function calling
+            functions_unsupported_model = optional_params.pop("functions_unsupported_model")
+            messages = function_call_prompt(messages=messages, functions=functions_unsupported_model)
+
         # For logging - save the values of the litellm-specific params passed in
         litellm_params = get_litellm_params(
             return_async=return_async,
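The factory.py side of the change is not among the hunks shown here; presumably get_optional_params stashes the function definitions under the functions_unsupported_model key when the selected provider cannot accept them natively, and the new branch above pops that key and rewrites messages. A plausible sketch of a helper matching the function_call_prompt(messages=..., functions=...) signature used in the diff (serialize the schemas into a system message so the model sees them as plain text), not the actual implementation:

import json
from typing import Dict, List

def function_call_prompt(messages: List[Dict], functions: List[Dict]) -> List[Dict]:
    # Render each function schema as readable text for the model.
    function_prompt = "The following functions are available to you:\n"
    for fn in functions:
        function_prompt += json.dumps(fn, indent=2) + "\n"

    # Prefer extending an existing system message; otherwise add one up front.
    for message in messages:
        if message.get("role") == "system":
            message["content"] += "\n" + function_prompt
            return messages
    messages.insert(0, {"role": "system", "content": function_prompt})
    return messages

Whatever the exact formatting, the diff only reassigns messages, so the rest of completion() proceeds unchanged.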