forked from phoenix/litellm-mirror
(feat) new function_to_dict litellm.util
This commit is contained in:
parent
02e97acefa
commit
7848f1b5b7
3 changed files with 169 additions and 3 deletions
|
@ -35,6 +35,7 @@ jobs:
|
||||||
pip install "boto3>=1.28.57"
|
pip install "boto3>=1.28.57"
|
||||||
pip install appdirs
|
pip install appdirs
|
||||||
pip install langchain
|
pip install langchain
|
||||||
|
pip install numpydoc
|
||||||
- save_cache:
|
- save_cache:
|
||||||
paths:
|
paths:
|
||||||
- ./venv
|
- ./venv
|
||||||
|
|
|
@ -10,7 +10,7 @@ sys.path.insert(
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest
|
import pytest
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.utils import trim_messages, get_token_count, get_valid_models, check_valid_key, validate_environment
|
from litellm.utils import trim_messages, get_token_count, get_valid_models, check_valid_key, validate_environment, function_to_dict
|
||||||
|
|
||||||
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
|
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
|
||||||
|
|
||||||
|
@ -101,4 +101,55 @@ def test_validate_environment_empty_model():
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
# test_validate_environment_empty_model()
|
# test_validate_environment_empty_model()
|
||||||
|
|
||||||
|
def test_function_to_dict():
|
||||||
|
print("testing function to dict for get current weather")
|
||||||
|
def get_current_weather(location: str, unit: str):
|
||||||
|
"""Get the current weather in a given location
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
location : str
|
||||||
|
The city and state, e.g. San Francisco, CA
|
||||||
|
unit : {'celsius', 'fahrenheit'}
|
||||||
|
Temperature unit
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
a sentence indicating the weather
|
||||||
|
"""
|
||||||
|
if location == "Boston, MA":
|
||||||
|
return "The weather is 12F"
|
||||||
|
function_json = litellm.utils.function_to_dict(get_current_weather)
|
||||||
|
print(function_json)
|
||||||
|
|
||||||
|
expected_output = {
|
||||||
|
'name': 'get_current_weather',
|
||||||
|
'description': 'Get the current weather in a given location',
|
||||||
|
'parameters': {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'},
|
||||||
|
'unit': {'type': 'string', 'description': 'Temperature unit', 'enum': "['fahrenheit', 'celsius']"}
|
||||||
|
},
|
||||||
|
'required': ['location', 'unit']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
print(expected_output)
|
||||||
|
|
||||||
|
assert function_json['name'] == expected_output["name"]
|
||||||
|
assert function_json["description"] == expected_output["description"]
|
||||||
|
assert function_json["parameters"]["type"] == expected_output["parameters"]["type"]
|
||||||
|
assert function_json["parameters"]["properties"]["location"] == expected_output["parameters"]["properties"]["location"]
|
||||||
|
|
||||||
|
# the enum can change it can be - which is why we don't assert on unit
|
||||||
|
# {'type': 'string', 'description': 'Temperature unit', 'enum': "['fahrenheit', 'celsius']"}
|
||||||
|
# {'type': 'string', 'description': 'Temperature unit', 'enum': "['celsius', 'fahrenheit']"}
|
||||||
|
|
||||||
|
assert function_json["parameters"]["required"] == expected_output["parameters"]["required"]
|
||||||
|
|
||||||
|
print("passed")
|
||||||
|
# test_function_to_dict()
|
||||||
|
|
||||||
|
|
116
litellm/utils.py
116
litellm/utils.py
|
@ -1615,7 +1615,121 @@ def get_max_tokens(model: str):
|
||||||
return litellm.model_cost[model]
|
return litellm.model_cost[model]
|
||||||
except:
|
except:
|
||||||
raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
|
raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
|
||||||
|
|
||||||
|
def json_schema_type(python_type_name: str):
|
||||||
|
"""Converts standard python types to json schema types
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
python_type_name : str
|
||||||
|
__name__ of type
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
a standard JSON schema type, "string" if not recognized.
|
||||||
|
"""
|
||||||
|
python_to_json_schema_types = {
|
||||||
|
str.__name__: "string",
|
||||||
|
int.__name__: "integer",
|
||||||
|
float.__name__: "number",
|
||||||
|
bool.__name__: "boolean",
|
||||||
|
list.__name__: "array",
|
||||||
|
dict.__name__: "object",
|
||||||
|
"NoneType": "null",
|
||||||
|
}
|
||||||
|
|
||||||
|
return python_to_json_schema_types.get(python_type_name, "string")
|
||||||
|
|
||||||
|
|
||||||
|
def function_to_dict(input_function): # noqa: C901
|
||||||
|
"""Using type hints and numpy-styled docstring,
|
||||||
|
produce a dictionnary usable for OpenAI function calling
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input_function : function
|
||||||
|
A function with a numpy-style docstring
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dictionnary
|
||||||
|
A dictionnary to add to the list passed to `functions` parameter of `litellm.completion`
|
||||||
|
"""
|
||||||
|
# Get function name and docstring
|
||||||
|
try:
|
||||||
|
import inspect
|
||||||
|
from numpydoc.docscrape import NumpyDocString
|
||||||
|
from ast import literal_eval
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
name = input_function.__name__
|
||||||
|
docstring = inspect.getdoc(input_function)
|
||||||
|
numpydoc = NumpyDocString(docstring)
|
||||||
|
description = "\n".join([s.strip() for s in numpydoc["Summary"]])
|
||||||
|
|
||||||
|
# Get function parameters and their types from annotations and docstring
|
||||||
|
parameters = {}
|
||||||
|
required_params = []
|
||||||
|
param_info = inspect.signature(input_function).parameters
|
||||||
|
|
||||||
|
for param_name, param in param_info.items():
|
||||||
|
if hasattr(param, "annotation"):
|
||||||
|
param_type = json_schema_type(param.annotation.__name__)
|
||||||
|
else:
|
||||||
|
param_type = None
|
||||||
|
param_description = None
|
||||||
|
param_enum = None
|
||||||
|
|
||||||
|
# Try to extract param description from docstring using numpydoc
|
||||||
|
for param_data in numpydoc["Parameters"]:
|
||||||
|
if param_data.name == param_name:
|
||||||
|
if hasattr(param_data, "type"):
|
||||||
|
# replace type from docstring rather than annotation
|
||||||
|
param_type = param_data.type
|
||||||
|
if "optional" in param_type:
|
||||||
|
param_type = param_type.split(",")[0]
|
||||||
|
elif "{" in param_type:
|
||||||
|
# may represent a set of acceptable values
|
||||||
|
# translating as enum for function calling
|
||||||
|
try:
|
||||||
|
param_enum = str(list(literal_eval(param_type)))
|
||||||
|
param_type = "string"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
param_type = json_schema_type(param_type)
|
||||||
|
param_description = "\n".join([s.strip() for s in param_data.desc])
|
||||||
|
|
||||||
|
param_dict = {
|
||||||
|
"type": param_type,
|
||||||
|
"description": param_description,
|
||||||
|
"enum": param_enum,
|
||||||
|
}
|
||||||
|
|
||||||
|
parameters[param_name] = dict(
|
||||||
|
[(k, v) for k, v in param_dict.items() if isinstance(v, str)]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the parameter has no default value (i.e., it's required)
|
||||||
|
if param.default == param.empty:
|
||||||
|
required_params.append(param_name)
|
||||||
|
|
||||||
|
# Create the dictionary
|
||||||
|
result = {
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add "required" key if there are required parameters
|
||||||
|
if required_params:
|
||||||
|
result["parameters"]["required"] = required_params
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def load_test_model(
|
def load_test_model(
|
||||||
model: str,
|
model: str,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue