diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index e3f0ff4e8..1b380594f 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -1903,6 +1903,26 @@ def azure_text_pt(messages: list):
     return prompt
 
 
+###### AZURE AI #######
+def stringify_json_tool_call_content(messages: List) -> List:
+    """
+
+    - Check 'content' in tool-role messages -> if it is not valid JSON, wrap it in a dict -> stringify
+
+    Done for azure_ai/cohere calls to handle the results of a tool call
+    """
+
+    for m in messages:
+        if m["role"] == "tool" and isinstance(m["content"], str):
+            # check if content is a valid json object
+            try:
+                json.loads(m["content"])
+            except json.JSONDecodeError:
+                m["content"] = json.dumps({"result": m["content"]})
+
+    return messages
+
+
 ###### AMAZON BEDROCK #######
 from litellm.types.llms.bedrock import ContentBlock as BedrockContentBlock
 
diff --git a/litellm/main.py b/litellm/main.py
index 9b42b0d07..1da31abd7 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -113,6 +113,7 @@ from .llms.prompt_templates.factory import (
     function_call_prompt,
     map_system_message_pt,
     prompt_factory,
+    stringify_json_tool_call_content,
 )
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
@@ -1114,6 +1115,73 @@ def completion(
                     "api_base": api_base,
                 },
             )
+        elif custom_llm_provider == "azure_ai":
+            api_base = (
+                api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
+                or litellm.api_base
+                or get_secret("AZURE_AI_API_BASE")
+            )
+            # set API KEY
+            api_key = (
+                api_key
+                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
+                or litellm.openai_key
+                or get_secret("AZURE_AI_API_KEY")
+            )
+
+            headers = headers or litellm.headers
+
+            ## LOAD CONFIG - if set
+            config = litellm.OpenAIConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in optional_params
+                ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
+                    optional_params[k] = v
+
+            ## FOR COHERE
+            if "command-r" in model:  # make sure tool call results in messages are str
+                messages = stringify_json_tool_call_content(messages=messages)
+
+            ## COMPLETION CALL
+            try:
+                response = openai_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
+            except Exception as e:
+                ## LOGGING - log the original exception returned
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=str(e),
+                    additional_args={"headers": headers},
+                )
+                raise e
+
+            if optional_params.get("stream", False):
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={"headers": headers},
+                )
         elif (
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
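For context, a minimal sketch of what the new helper does to tool-result messages (the example messages are illustrative, not taken from the test suite):

```python
from litellm.llms.prompt_templates.factory import stringify_json_tool_call_content

messages = [
    {"role": "user", "content": "What's the weather in San Francisco?"},
    # Plain-string tool result: not valid JSON, so it is wrapped and stringified.
    {"role": "tool", "tool_call_id": "call_123", "content": "72F and sunny"},
    # Already-valid JSON passes through unchanged.
    {"role": "tool", "tool_call_id": "call_456", "content": '{"temp": "72F"}'},
]

messages = stringify_json_tool_call_content(messages=messages)
assert messages[1]["content"] == '{"result": "72F and sunny"}'
assert messages[2]["content"] == '{"temp": "72F"}'
```
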
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 72567e05d..40c15d06d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -410,17 +410,21 @@ def test_completion_claude_3_function_call(model):
 
 @pytest.mark.parametrize("sync_mode", [True])
 @pytest.mark.parametrize(
-    "model",
+    "model, api_key, api_base",
     [
-        "gpt-3.5-turbo",
-        "claude-3-opus-20240229",
-        "command-r",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
-        # "azure_ai/command-r-plus"
+        ("gpt-3.5-turbo", None, None),
+        ("claude-3-opus-20240229", None, None),
+        ("command-r", None, None),
+        ("anthropic.claude-3-sonnet-20240229-v1:0", None, None),
+        (
+            "azure_ai/command-r-plus",
+            os.getenv("AZURE_COHERE_API_KEY"),
+            os.getenv("AZURE_COHERE_API_BASE"),
+        ),
     ],
 )
 @pytest.mark.asyncio
-async def test_model_function_invoke(model, sync_mode):
+async def test_model_function_invoke(model, sync_mode, api_key, api_base):
     try:
         litellm.set_verbose = True
 
@@ -445,7 +449,7 @@ async def test_model_function_invoke(model, sync_mode):
                     "index": 0,
                     "function": {
                         "name": "get_weather",
-                        "arguments": '{"location":"San Francisco, CA"}',
+                        "arguments": '{"location": "San Francisco, CA"}',
                     },
                 }
             ],
@@ -483,6 +487,8 @@ async def test_model_function_invoke(model, sync_mode):
             "model": model,
             "messages": messages,
             "tools": tools,
+            "api_key": api_key,
+            "api_base": api_base,
         }
         if sync_mode:
             response = litellm.completion(**data)
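The new test parametrization exercises the route end to end. A hedged usage sketch of what it enables, assuming AZURE_COHERE_API_KEY and AZURE_COHERE_API_BASE point at a deployed Cohere Command R+ endpoint on Azure AI:

```python
import os

import litellm

# Routes through the new azure_ai provider branch; api_key/api_base fall back
# to AZURE_AI_API_KEY/AZURE_AI_API_BASE if not passed explicitly.
response = litellm.completion(
    model="azure_ai/command-r-plus",
    messages=[{"role": "user", "content": "What's the weather in San Francisco?"}],
    api_key=os.getenv("AZURE_COHERE_API_KEY"),
    api_base=os.getenv("AZURE_COHERE_API_BASE"),
)
print(response.choices[0].message.content)
```

Because the model name contains "command-r", any tool-role messages passed in would first be run through stringify_json_tool_call_content before the call is made.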