fix(utils.py): support together ai function calling

Krrish Dholakia 2024-02-05 15:30:30 -08:00
parent 006b5efef0
commit 77fe71ee08
4 changed files with 16 additions and 9 deletions

.gitignore vendored
View file

@@ -43,3 +43,4 @@ ui/litellm-dashboard/package-lock.json
 deploy/charts/litellm-helm/*.tgz
 deploy/charts/litellm-helm/charts/*
 deploy/charts/*.tgz
+litellm/proxy/vertex_key.json

View file

@@ -263,6 +263,7 @@ async def acompletion(
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
             or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
             if isinstance(init_response, dict) or isinstance(
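The `acompletion` change above sends any provider listed in `litellm.openai_compatible_providers` down the branch that can await a native async response. A minimal usage sketch, assuming `together_ai` is registered in that list; the model name is illustrative, not taken from this commit:

```python
import asyncio
import litellm

async def main():
    # Illustrative model name; any Together AI chat model would do.
    response = await litellm.acompletion(
        model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```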

View file

@ -758,9 +758,10 @@ async def _PROXY_track_cost_callback(
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"response_cost {response_cost}, for user_id {user_id}" f"response_cost {response_cost}, for user_id {user_id}"
) )
if user_api_key and ( verbose_proxy_logger.debug(
prisma_client is not None or custom_db_client is not None f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
): )
if user_api_key is not None:
await update_database( await update_database(
token=user_api_key, token=user_api_key,
response_cost=response_cost, response_cost=response_cost,
@ -770,6 +771,8 @@ async def _PROXY_track_cost_callback(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
) )
else:
raise Exception("User API key missing from custom callback.")
else: else:
if kwargs["stream"] != True or ( if kwargs["stream"] != True or (
kwargs["stream"] == True kwargs["stream"] == True
@@ -4067,7 +4070,6 @@ def _has_user_setup_sso():
 async def shutdown_event():
     global prisma_client, master_key, user_custom_auth, user_custom_key_generate
     if prisma_client:
         verbose_proxy_logger.debug("Disconnecting from Prisma")
         await prisma_client.disconnect()

View file

@ -3852,6 +3852,8 @@ def get_optional_params(
and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "text-completion-openai"
and custom_llm_provider != "azure" and custom_llm_provider != "azure"
and custom_llm_provider != "vertex_ai" and custom_llm_provider != "vertex_ai"
and custom_llm_provider != "anyscale"
and custom_llm_provider != "together_ai"
): ):
if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat": if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
# ollama actually supports json output # ollama actually supports json output
@ -3870,11 +3872,6 @@ def get_optional_params(
optional_params[ optional_params[
"functions_unsupported_model" "functions_unsupported_model"
] = non_default_params.pop("functions") ] = non_default_params.pop("functions")
elif (
custom_llm_provider == "anyscale"
and model == "mistralai/Mistral-7B-Instruct-v0.1"
): # anyscale just supports function calling with mistral
pass
elif ( elif (
litellm.add_function_to_prompt litellm.add_function_to_prompt
): # if user opts to add it to prompt instead ): # if user opts to add it to prompt instead
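Net effect of the two hunks above: `anyscale` and `together_ai` join the providers that skip the "functions unsupported" handling, so `functions`/`tools` are neither popped nor pushed into the prompt for them, and the old Mistral-7B-only anyscale carve-out goes away. A conceptual sketch of the gating with a hypothetical helper (the set lists only the providers visible in this hunk; the real check in `get_optional_params` excludes more):

```python
# Providers whose APIs accept function/tool definitions directly (as visible in this hunk).
NATIVE_FUNCTION_CALLING_PROVIDERS = {
    "openai",
    "text-completion-openai",
    "azure",
    "vertex_ai",
    "anyscale",       # previously limited to mistralai/Mistral-7B-Instruct-v0.1
    "together_ai",    # newly added by this commit
}


def needs_function_workaround(custom_llm_provider: str) -> bool:
    # Only providers outside the set fall into the branch that pops `functions`/`tools`
    # or serializes them into the prompt (litellm.add_function_to_prompt).
    return custom_llm_provider not in NATIVE_FUNCTION_CALLING_PROVIDERS


print(needs_function_workaround("together_ai"))  # False -> tools are passed through
print(needs_function_workaround("ai21"))         # True  -> falls back to the workaround
```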
@@ -4087,6 +4084,8 @@ def get_optional_params(
             "top_p",
             "stop",
             "frequency_penalty",
+            "tools",
+            "tool_choice",
         ]
         _check_valid_arg(supported_params=supported_params)
@@ -4104,6 +4103,10 @@ def get_optional_params(
             ] = frequency_penalty  # https://docs.together.ai/reference/inference
         if stop is not None:
             optional_params["stop"] = stop
+        if tools is not None:
+            optional_params["tools"] = tools
+        if tool_choice is not None:
+            optional_params["tool_choice"] = tool_choice
     elif custom_llm_provider == "ai21":
         ## check if unsupported param passed in
         supported_params = [
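With `tools` and `tool_choice` now in Together AI's `supported_params` and copied into `optional_params`, function calling can be requested through litellm's usual OpenAI-style interface. A hedged usage sketch (model name and tool schema are illustrative, not taken from this commit):

```python
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = litellm.completion(
    model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative model
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
print(response.choices[0].message)
```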