Merge branch 'main' into bedrock-llama3.1-405b

This commit is contained in:
Krish Dholakia 2024-07-25 19:29:10 -07:00 committed by GitHub
commit a5cea7929d
22 changed files with 888 additions and 159 deletions

View file

@ -330,6 +330,18 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
if custom_llm["provider"] not in litellm._custom_providers:
litellm._custom_providers.append(custom_llm["provider"])
def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -341,6 +353,10 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
## CUSTOM LLM SETUP ##
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
@ -9247,7 +9263,10 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
if self.custom_llm_provider and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None:
@ -10114,6 +10133,7 @@ class CustomStreamWrapper:
try:
if self.completion_stream is None:
await self.fetch_stream()
if (
self.custom_llm_provider == "openai"
or self.custom_llm_provider == "azure"
@ -10138,6 +10158,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
or self.custom_llm_provider in litellm._custom_providers
):
async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}")
@ -10966,3 +10987,8 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
class CustomModelResponseIterator(Iterable):
def __init__(self) -> None:
super().__init__()