(v0) tool calling

This commit is contained in:
ishaan-jaff 2024-03-12 12:35:52 -07:00
parent 5172fb1de9
commit d136238f6f
3 changed files with 118 additions and 0 deletions

View file

@@ -22,6 +22,12 @@ class CohereError(Exception):
        ) # Call the base class constructor with the parameters it needs
def construct_cohere_tool(tools=None):
    """Wrap a list of tool definitions in the payload shape Cohere expects.

    Returns a dict of the form ``{"tools": [...]}``; when *tools* is omitted
    or ``None``, the tool list defaults to empty. The ``None`` default (rather
    than ``[]``) avoids the shared-mutable-default pitfall.
    """
    tool_list = tools if tools is not None else []
    return {"tools": tool_list}
class CohereConfig:
    """
    Reference: https://docs.cohere.com/reference/generate
@@ -145,6 +151,14 @@ def completion(
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
    optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_cohere_tool(
tools=optional_params["tools"]
)
optional_params["tools"] = tool_calling_system_prompt
data = {
    "model": model,
    "prompt": prompt,

View file

@@ -0,0 +1,103 @@
import sys, os
import traceback
from dotenv import load_dotenv
load_dotenv()
import os, io
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
litellm.num_retries = 3
# FYI - cohere_chat looks quite unstable, even when testing locally
def test_chat_completion_cohere():
    """Smoke-test a basic (non-streaming) completion against cohere_chat.

    Any exception — network, auth, provider error — is surfaced as a test
    failure rather than an error, matching the file's existing convention.
    """
    try:
        litellm.set_verbose = True
        conversation = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        result = completion(
            model="cohere_chat/command-r",
            messages=conversation,
            max_tokens=10,
        )
        print(result)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
def test_chat_completion_cohere_stream():
    """Smoke-test a streaming completion against cohere_chat.

    Consumes every chunk of the stream to make sure iteration itself does
    not raise; any exception is reported via pytest.fail.
    """
    try:
        litellm.set_verbose = False
        conversation = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        stream = completion(
            model="cohere_chat/command-r",
            messages=conversation,
            max_tokens=10,
            stream=True,
        )
        print(stream)
        # Drain the stream — a provider/transport error may only surface here.
        for piece in stream:
            print(piece)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
def test_chat_completion_cohere_tool_calling():
    """Smoke-test the tools= path for cohere_chat completions.

    Sends a single OpenAI-style function-tool definition; any exception is
    reported via pytest.fail, matching the sibling tests.
    """
    # OpenAI-style tool schema; litellm translates it for the Cohere API.
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        },
    }
    try:
        litellm.set_verbose = True
        conversation = [
            {"role": "system", "content": "You're a good bot"},
            {
                "role": "user",
                "content": "Hey",
            },
        ]
        result = completion(
            model="cohere_chat/command-r",
            messages=conversation,
            max_tokens=10,
            tools=[weather_tool],
        )
        print(result)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

View file

@@ -4269,6 +4269,7 @@ def get_optional_params(
    and custom_llm_provider != "together_ai"
    and custom_llm_provider != "mistral"
    and custom_llm_provider != "anthropic"
    and custom_llm_provider != "cohere_chat"
    and custom_llm_provider != "bedrock"
    and custom_llm_provider != "ollama_chat"
):