LiteLLM Minor Fixes & Improvements (10/16/2024) (#6265)

* fix(caching_handler.py): handle positional arguments in add cache logic

Fixes https://github.com/BerriAI/litellm/issues/6264

* feat(litellm_pre_call_utils.py): allow forwarding openai org id to backend client

https://github.com/BerriAI/litellm/issues/6237

* docs(configs.md): add 'forward_openai_org_id' to docs

* fix(proxy_server.py): return model info if user_model is set

Fixes https://github.com/BerriAI/litellm/issues/6233

* fix(hosted_vllm/chat/transformation.py): don't set tools unless non-none

* fix(openai.py): improve debug log for openai 'str' error

Addresses https://github.com/BerriAI/litellm/issues/6272

* fix(proxy_server.py): fix linting error

* fix(proxy_server.py): fix linting errors

* test: skip WIP test

* docs(openai.md): add docs on passing openai org id from client to openai
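
A minimal client-side sketch of the org-id forwarding flow described above (illustrative only, not part of this commit's diff): it assumes a LiteLLM proxy running at http://localhost:4000 with the new 'forward_openai_org_id: true' setting enabled in its proxy config, and relies on the standard OpenAI Python SDK, which transmits the 'organization' value as the 'OpenAI-Organization' header so the proxy can forward it to the upstream OpenAI backend. The base URL, API key, and model name below are placeholders.

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:4000",        # LiteLLM proxy endpoint (assumed)
    api_key="sk-1234",                       # placeholder proxy key
    organization="org-example-openai-org",   # org id for the proxy to forward upstream
)

response = client.chat.completions.create(
    model="gpt-4o",                          # placeholder model routed to OpenAI
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
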
Krish Dholakia, 2024-10-16 22:16:23 -07:00 (committed by GitHub)
commit 38a9a106d2, parent 43878bd2a0
14 changed files with 371 additions and 47 deletions

caching_handler.py

@@ -16,6 +16,7 @@ In each method it will call the appropriate method from caching.py
 import asyncio
 import datetime
+import inspect
 import threading
 from typing import (
     TYPE_CHECKING,
@@ -632,7 +633,7 @@ class LLMCachingHandler:
             logging_obj=logging_obj,
         )

-    async def _async_set_cache(
+    async def async_set_cache(
         self,
         result: Any,
         original_function: Callable,
@@ -653,7 +654,7 @@ class LLMCachingHandler:
         Raises:
             None
         """
-        args = args or ()
+        kwargs.update(convert_args_to_kwargs(result, original_function, kwargs, args))
         if litellm.cache is None:
             return
         # [OPTIONAL] ADD TO CACHE
@@ -675,24 +676,24 @@ class LLMCachingHandler:
                 )  # s3 doesn't support bulk writing. Exclude.
             ):
                 asyncio.create_task(
-                    litellm.cache.async_add_cache_pipeline(result, *args, **kwargs)
+                    litellm.cache.async_add_cache_pipeline(result, **kwargs)
                 )
             elif isinstance(litellm.cache.cache, S3Cache):
                 threading.Thread(
                     target=litellm.cache.add_cache,
-                    args=(result,) + args,
+                    args=(result,),
                     kwargs=kwargs,
                 ).start()
             else:
                 asyncio.create_task(
-                    litellm.cache.async_add_cache(result.json(), *args, **kwargs)
+                    litellm.cache.async_add_cache(
+                        result.model_dump_json(), **kwargs
+                    )
                 )
         else:
-            asyncio.create_task(
-                litellm.cache.async_add_cache(result, *args, **kwargs)
-            )
+            asyncio.create_task(litellm.cache.async_add_cache(result, **kwargs))

-    def _sync_set_cache(
+    def sync_set_cache(
         self,
         result: Any,
         kwargs: Dict[str, Any],
@@ -701,14 +702,16 @@ class LLMCachingHandler:
         """
         Sync internal method to add the result to the cache
         """
+        kwargs.update(
+            convert_args_to_kwargs(result, self.original_function, kwargs, args)
+        )
         if litellm.cache is None:
             return
-        args = args or ()
         if self._should_store_result_in_cache(
             original_function=self.original_function, kwargs=kwargs
         ):
-            litellm.cache.add_cache(result, *args, **kwargs)
+            litellm.cache.add_cache(result, **kwargs)

         return
@@ -772,7 +775,7 @@ class LLMCachingHandler:
         # if a complete_streaming_response is assembled, add it to the cache
         if complete_streaming_response is not None:
-            await self._async_set_cache(
+            await self.async_set_cache(
                 result=complete_streaming_response,
                 original_function=self.original_function,
                 kwargs=self.request_kwargs,
@@ -795,7 +798,7 @@ class LLMCachingHandler:
         # if a complete_streaming_response is assembled, add it to the cache
         if complete_streaming_response is not None:
-            self._sync_set_cache(
+            self.sync_set_cache(
                 result=complete_streaming_response,
                 kwargs=self.request_kwargs,
             )
@@ -849,3 +852,26 @@ class LLMCachingHandler:
             additional_args=None,
             stream=kwargs.get("stream", False),
         )
+
+
+def convert_args_to_kwargs(
+    result: Any,
+    original_function: Callable,
+    kwargs: Dict[str, Any],
+    args: Optional[Tuple[Any, ...]] = None,
+) -> Dict[str, Any]:
+    # Get the signature of the original function
+    signature = inspect.signature(original_function)
+
+    # Get parameter names in the order they appear in the original function
+    param_names = list(signature.parameters.keys())
+
+    # Create a mapping of positional arguments to parameter names
+    args_to_kwargs = {}
+    if args:
+        for index, arg in enumerate(args):
+            if index < len(param_names):
+                param_name = param_names[index]
+                args_to_kwargs[param_name] = arg
+
+    return args_to_kwargs
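
To make the helper's behavior concrete, here is a small standalone sketch (the helper body mirrors the function added above; the wrapped 'completion' function and the argument values are hypothetical). Mapping positional arguments onto parameter names is what lets the callers above replace '*args' with a plain 'kwargs.update(...)' before writing to the cache.

import inspect
from typing import Any, Callable, Dict, Optional, Tuple

def convert_args_to_kwargs(
    result: Any,
    original_function: Callable,
    kwargs: Dict[str, Any],
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Same logic as the helper introduced in this commit: map each positional
    # argument to the corresponding parameter name of the wrapped function.
    param_names = list(inspect.signature(original_function).parameters.keys())
    args_to_kwargs = {}
    if args:
        for index, arg in enumerate(args):
            if index < len(param_names):
                args_to_kwargs[param_names[index]] = arg
    return args_to_kwargs

# Hypothetical wrapped function, for illustration only.
def completion(model, messages=None, stream=False, **kwargs):
    ...

# A caller that passed 'model' and 'messages' positionally:
mapped = convert_args_to_kwargs(
    result=None,
    original_function=completion,
    kwargs={"stream": True},
    args=("gpt-4o", [{"role": "user", "content": "hi"}]),
)
print(mapped)
# -> {'model': 'gpt-4o', 'messages': [{'role': 'user', 'content': 'hi'}]}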