mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 34s
* fix(utils.py): initial commit fixing custom cost tracking refactors out provider specific model info from `get_model_info` - this was causing custom costs to be registered incorrectly * fix(utils.py): cleanup `_supports_factory` to check provider info, if model info is None some providers support features like vision across all models * fix(utils.py): refactor to use _supports_factory * test: update testing * fix: fix linting errors * test: fix testing
107 lines
3.3 KiB
Python
107 lines
3.3 KiB
Python
import copy
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional, Type, Union
|
|
|
|
from openai.lib import _parsing, _pydantic
|
|
from pydantic import BaseModel
|
|
|
|
from litellm.types.utils import ProviderSpecificModelInfo
|
|
|
|
|
|
class BaseLLMModelInfo(ABC):
|
|
def get_provider_info(
|
|
self,
|
|
model: str,
|
|
) -> Optional[ProviderSpecificModelInfo]:
|
|
return None
|
|
|
|
@abstractmethod
|
|
def get_models(self) -> List[str]:
|
|
pass
|
|
|
|
@staticmethod
|
|
@abstractmethod
|
|
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
|
pass
|
|
|
|
@staticmethod
|
|
@abstractmethod
|
|
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
|
pass
|
|
|
|
|
|
def _dict_to_response_format_helper(
|
|
response_format: dict, ref_template: Optional[str] = None
|
|
) -> dict:
|
|
if ref_template is not None and response_format.get("type") == "json_schema":
|
|
# Deep copy to avoid modifying original
|
|
modified_format = copy.deepcopy(response_format)
|
|
schema = modified_format["json_schema"]["schema"]
|
|
|
|
# Update all $ref values in the schema
|
|
def update_refs(schema):
|
|
stack = [(schema, [])]
|
|
visited = set()
|
|
|
|
while stack:
|
|
obj, path = stack.pop()
|
|
obj_id = id(obj)
|
|
|
|
if obj_id in visited:
|
|
continue
|
|
visited.add(obj_id)
|
|
|
|
if isinstance(obj, dict):
|
|
if "$ref" in obj:
|
|
ref_path = obj["$ref"]
|
|
model_name = ref_path.split("/")[-1]
|
|
obj["$ref"] = ref_template.format(model=model_name)
|
|
|
|
for k, v in obj.items():
|
|
if isinstance(v, (dict, list)):
|
|
stack.append((v, path + [k]))
|
|
|
|
elif isinstance(obj, list):
|
|
for i, item in enumerate(obj):
|
|
if isinstance(item, (dict, list)):
|
|
stack.append((item, path + [i]))
|
|
|
|
update_refs(schema)
|
|
return modified_format
|
|
return response_format
|
|
|
|
|
|
def type_to_response_format_param(
|
|
response_format: Optional[Union[Type[BaseModel], dict]],
|
|
ref_template: Optional[str] = None,
|
|
) -> Optional[dict]:
|
|
"""
|
|
Re-implementation of openai's 'type_to_response_format_param' function
|
|
|
|
Used for converting pydantic object to api schema.
|
|
"""
|
|
if response_format is None:
|
|
return None
|
|
|
|
if isinstance(response_format, dict):
|
|
return _dict_to_response_format_helper(response_format, ref_template)
|
|
|
|
# type checkers don't narrow the negation of a `TypeGuard` as it isn't
|
|
# a safe default behaviour but we know that at this point the `response_format`
|
|
# can only be a `type`
|
|
if not _parsing._completions.is_basemodel_type(response_format):
|
|
raise TypeError(f"Unsupported response_format type - {response_format}")
|
|
|
|
if ref_template is not None:
|
|
schema = response_format.model_json_schema(ref_template=ref_template)
|
|
else:
|
|
schema = _pydantic.to_strict_json_schema(response_format)
|
|
|
|
return {
|
|
"type": "json_schema",
|
|
"json_schema": {
|
|
"schema": schema,
|
|
"name": response_format.__name__,
|
|
"strict": True,
|
|
},
|
|
}
|