mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(main.py): fix azure ai cohere tool calling
This commit is contained in:
parent
a8e181369d
commit
fd25117b67
3 changed files with 102 additions and 8 deletions
|
@ -1903,6 +1903,26 @@ def azure_text_pt(messages: list):
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
###### AZURE AI #######


def stringify_json_tool_call_content(messages: List) -> List:
    """
    Ensure every tool-role message carries a JSON-object string as content.

    - Check 'content' in tool role -> if it is not already valid JSON,
      wrap it in a dict and stringify it.

    Done for azure_ai/cohere calls to handle results of a tool call.

    Note: mutates the message dicts in place and returns the same list.
    """
    for m in messages:
        # Use .get() so messages missing "role"/"content" are skipped
        # instead of raising KeyError; only string tool results need work.
        if m.get("role") == "tool" and isinstance(m.get("content"), str):
            # check if content is a valid json object
            try:
                json.loads(m["content"])
            except json.JSONDecodeError:
                # Plain-text tool result -> wrap so the provider receives
                # a stringified JSON object, e.g. '{"result": "..."}'
                m["content"] = json.dumps({"result": m["content"]})

    return messages
|
||||||
|
|
||||||
|
|
||||||
###### AMAZON BEDROCK #######
|
###### AMAZON BEDROCK #######
|
||||||
|
|
||||||
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
|
from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
|
||||||
|
|
|
@ -113,6 +113,7 @@ from .llms.prompt_templates.factory import (
|
||||||
function_call_prompt,
|
function_call_prompt,
|
||||||
map_system_message_pt,
|
map_system_message_pt,
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
|
stringify_json_tool_call_content,
|
||||||
)
|
)
|
||||||
from .llms.text_completion_codestral import CodestralTextCompletion
|
from .llms.text_completion_codestral import CodestralTextCompletion
|
||||||
from .llms.triton import TritonChatCompletion
|
from .llms.triton import TritonChatCompletion
|
||||||
|
@ -1114,6 +1115,73 @@ def completion(
|
||||||
"api_base": api_base,
|
"api_base": api_base,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure_ai":
|
||||||
|
api_base = (
|
||||||
|
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
|
||||||
|
or litellm.api_base
|
||||||
|
or get_secret("AZURE_AI_API_BASE")
|
||||||
|
)
|
||||||
|
# set API KEY
|
||||||
|
api_key = (
|
||||||
|
api_key
|
||||||
|
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
|
||||||
|
or litellm.openai_key
|
||||||
|
or get_secret("AZURE_AI_API_KEY")
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = headers or litellm.headers
|
||||||
|
|
||||||
|
## LOAD CONFIG - if set
|
||||||
|
config = litellm.OpenAIConfig.get_config()
|
||||||
|
for k, v in config.items():
|
||||||
|
if (
|
||||||
|
k not in optional_params
|
||||||
|
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||||
|
optional_params[k] = v
|
||||||
|
|
||||||
|
## FOR COHERE
|
||||||
|
if "command-r" in model: # make sure tool call in messages are str
|
||||||
|
messages = stringify_json_tool_call_content(messages=messages)
|
||||||
|
|
||||||
|
## COMPLETION CALL
|
||||||
|
try:
|
||||||
|
response = openai_chat_completions.completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
headers=headers,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
acompletion=acompletion,
|
||||||
|
logging_obj=logging,
|
||||||
|
optional_params=optional_params,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
timeout=timeout, # type: ignore
|
||||||
|
custom_prompt_dict=custom_prompt_dict,
|
||||||
|
client=client, # pass AsyncOpenAI, OpenAI client
|
||||||
|
organization=organization,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
## LOGGING - log the original exception returned
|
||||||
|
logging.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=str(e),
|
||||||
|
additional_args={"headers": headers},
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if optional_params.get("stream", False):
|
||||||
|
## LOGGING
|
||||||
|
logging.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response,
|
||||||
|
additional_args={"headers": headers},
|
||||||
|
)
|
||||||
elif (
|
elif (
|
||||||
custom_llm_provider == "text-completion-openai"
|
custom_llm_provider == "text-completion-openai"
|
||||||
or "ft:babbage-002" in model
|
or "ft:babbage-002" in model
|
||||||
|
|
|
@ -410,17 +410,21 @@ def test_completion_claude_3_function_call(model):
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True])
|
@pytest.mark.parametrize("sync_mode", [True])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model, api_key, api_base",
|
||||||
[
|
[
|
||||||
"gpt-3.5-turbo",
|
("gpt-3.5-turbo", None, None),
|
||||||
"claude-3-opus-20240229",
|
("claude-3-opus-20240229", None, None),
|
||||||
"command-r",
|
("command-r", None, None),
|
||||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
("anthropic.claude-3-sonnet-20240229-v1:0", None, None),
|
||||||
# "azure_ai/command-r-plus"
|
(
|
||||||
|
"azure_ai/command-r-plus",
|
||||||
|
os.getenv("AZURE_COHERE_API_KEY"),
|
||||||
|
os.getenv("AZURE_COHERE_API_BASE"),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_function_invoke(model, sync_mode):
|
async def test_model_function_invoke(model, sync_mode, api_key, api_base):
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
@ -483,6 +487,8 @@ async def test_model_function_invoke(model, sync_mode):
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"tools": tools,
|
"tools": tools,
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_base": api_base,
|
||||||
}
|
}
|
||||||
if sync_mode:
|
if sync_mode:
|
||||||
response = litellm.completion(**data)
|
response = litellm.completion(**data)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue