Merge pull request #3029 from BerriAI/litellm_add_groq_tool_calling

Feat - add groq tool calling + testing
Ishaan Jaff 2024-04-15 08:49:52 -07:00 committed by GitHub
commit a15f61cc05
5 changed files with 143 additions and 8 deletions
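
In short, this PR lets OpenAI-style `tools` / `tool_choice` parameters pass through to Groq models. A minimal usage sketch, distilled from the test added below (assumes a GROQ_API_KEY in the environment; the model name is the one used in the test):

import litellm

response = litellm.completion(
    model="groq/llama2-70b-4096",
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)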


@@ -714,7 +714,8 @@
     "input_cost_per_token": 0.00000070,
     "output_cost_per_token": 0.00000080,
     "litellm_provider": "groq",
-    "mode": "chat"
+    "mode": "chat",
+    "supports_function_calling": true
   },
   "groq/mixtral-8x7b-32768": {
     "max_tokens": 32768,
@@ -723,7 +724,8 @@
     "input_cost_per_token": 0.00000027,
     "output_cost_per_token": 0.00000027,
     "litellm_provider": "groq",
-    "mode": "chat"
+    "mode": "chat",
+    "supports_function_calling": true
   },
   "groq/gemma-7b-it": {
     "max_tokens": 8192,
@@ -732,7 +734,8 @@
     "input_cost_per_token": 0.00000010,
     "output_cost_per_token": 0.00000010,
     "litellm_provider": "groq",
-    "mode": "chat"
+    "mode": "chat",
+    "supports_function_calling": true
   },
   "claude-instant-1.2": {
     "max_tokens": 8191,


@@ -219,3 +219,94 @@ def test_parallel_function_call_stream():
 # test_parallel_function_call_stream()
+
+
+def test_groq_parallel_function_call():
+    litellm.set_verbose = True
+    try:
+        # Step 1: send the conversation and available functions to the model
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a function calling LLM that uses the data extracted from get_current_weather to answer questions about the weather in San Francisco.",
+            },
+            {
+                "role": "user",
+                "content": "What's the weather like in San Francisco?",
+            },
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        response = litellm.completion(
+            model="groq/llama2-70b-4096",
+            messages=messages,
+            tools=tools,
+            tool_choice="auto",  # auto is default, but we'll be explicit
+        )
+        print("Response\n", response)
+        response_message = response.choices[0].message
+        tool_calls = response_message.tool_calls
+        assert isinstance(response.choices[0].message.tool_calls[0].function.name, str)
+        assert isinstance(
+            response.choices[0].message.tool_calls[0].function.arguments, str
+        )
+        print("length of tool calls", len(tool_calls))
+        # Step 2: check if the model wanted to call a function
+        if tool_calls:
+            # Step 3: call the function
+            # Note: the JSON response may not always be valid; be sure to handle errors
+            available_functions = {
+                "get_current_weather": get_current_weather,
+            }  # only one function in this example, but you can have multiple
+            messages.append(
+                response_message
+            )  # extend conversation with assistant's reply
+            print("Response message\n", response_message)
+            # Step 4: send the info for each function call and function response to the model
+            for tool_call in tool_calls:
+                function_name = tool_call.function.name
+                function_to_call = available_functions[function_name]
+                function_args = json.loads(tool_call.function.arguments)
+                function_response = function_to_call(
+                    location=function_args.get("location"),
+                    unit=function_args.get("unit"),
+                )
+                messages.append(
+                    {
+                        "tool_call_id": tool_call.id,
+                        "role": "tool",
+                        "name": function_name,
+                        "content": function_response,
+                    }
+                )  # extend conversation with function response
+            print(f"messages: {messages}")
+            second_response = litellm.completion(
+                model="groq/llama2-70b-4096", messages=messages
+            )  # get a new response from the model where it can see the function response
+            print("second response\n", second_response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
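
The test relies on a `get_current_weather` helper defined earlier in the test module (outside this diff). A minimal sketch of what such a stub might look like, matching the tool schema above; the body here is an assumption, not necessarily the repo's actual helper:

import json

def get_current_weather(location, unit="fahrenheit"):
    # Hypothetical stub matching the tool schema above; returns a JSON string,
    # since the "content" of a "tool" message must be a string.
    if "san francisco" in location.lower():
        return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
    return json.dumps({"location": location, "temperature": "unknown"})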


@@ -223,6 +223,7 @@ def test_validate_environment_ollama():
         assert kv["keys_in_environment"]
         assert kv["missing_keys"] == []
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_validate_environment_ollama_failed():
     for provider in ["ollama", "ollama_chat"]:
@@ -230,6 +231,7 @@ def test_validate_environment_ollama_failed():
         assert not kv["keys_in_environment"]
         assert kv["missing_keys"] == ["OLLAMA_API_BASE"]
+
 def test_function_to_dict():
     print("testing function to dict for get current weather")
@@ -338,6 +340,7 @@ def test_supports_function_calling():
     assert (
         litellm.supports_function_calling(model="azure/gpt-4-1106-preview") == True
     )
+    assert litellm.supports_function_calling(model="groq/gemma-7b-it") == True
     assert (
         litellm.supports_function_calling(model="anthropic.claude-instant-v1")
         == False


@@ -4523,6 +4523,7 @@ def get_optional_params(
         and custom_llm_provider != "vertex_ai"
         and custom_llm_provider != "anyscale"
         and custom_llm_provider != "together_ai"
+        and custom_llm_provider != "groq"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
         and custom_llm_provider != "cohere_chat"
@@ -5222,6 +5223,29 @@ def get_optional_params(
         optional_params["extra_body"] = (
             extra_body  # openai client supports `extra_body` param
         )
+    elif custom_llm_provider == "groq":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        if temperature is not None:
+            optional_params["temperature"] = temperature
+        if max_tokens is not None:
+            optional_params["max_tokens"] = max_tokens
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if stream is not None:
+            optional_params["stream"] = stream
+        if stop is not None:
+            optional_params["stop"] = stop
+        if tools is not None:
+            optional_params["tools"] = tools
+        if tool_choice is not None:
+            optional_params["tool_choice"] = tool_choice
+        if response_format is not None:
+            optional_params["response_format"] = response_format
     elif custom_llm_provider == "openrouter":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -5426,6 +5450,17 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
         "tools",
         "tool_choice",
     ]
+    elif custom_llm_provider == "groq":
+        return [
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stream",
+            "stop",
+            "tools",
+            "tool_choice",
+            "response_format",
+        ]
     elif custom_llm_provider == "cohere":
         return [
             "stream",


@@ -714,7 +714,8 @@
     "input_cost_per_token": 0.00000070,
     "output_cost_per_token": 0.00000080,
     "litellm_provider": "groq",
-    "mode": "chat"
+    "mode": "chat",
+    "supports_function_calling": true
   },
   "groq/mixtral-8x7b-32768": {
     "max_tokens": 32768,
@@ -723,7 +724,8 @@
     "input_cost_per_token": 0.00000027,
     "output_cost_per_token": 0.00000027,
     "litellm_provider": "groq",
-    "mode": "chat"
+    "mode": "chat",
+    "supports_function_calling": true
   },
   "groq/gemma-7b-it": {
     "max_tokens": 8192,
@@ -732,7 +734,8 @@
     "input_cost_per_token": 0.00000010,
     "output_cost_per_token": 0.00000010,
     "litellm_provider": "groq",
-    "mode": "chat"
+    "mode": "chat",
+    "supports_function_calling": true
   },
   "claude-instant-1.2": {
     "max_tokens": 8191,