Merge branch 'main' into litellm_proxy_support_all_providers

This commit is contained in:
Ishaan Jaff 2024-07-25 20:15:37 -07:00 committed by GitHub
commit 079a41fbe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 1329 additions and 350 deletions

View file

@ -330,6 +330,18 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
if custom_llm["provider"] not in litellm._custom_providers:
litellm._custom_providers.append(custom_llm["provider"])
def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -341,6 +353,10 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
## CUSTOM LLM SETUP ##
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
@ -3121,7 +3137,19 @@ def get_optional_params(
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
if "ai21" in model:
if model in litellm.BEDROCK_CONVERSE_MODELS:
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "ai21" in model:
_check_valid_arg(supported_params=supported_params)
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@ -3143,17 +3171,6 @@ def get_optional_params(
optional_params=optional_params,
)
)
elif model in litellm.BEDROCK_CONVERSE_MODELS:
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
@ -8825,21 +8842,6 @@ class CustomStreamWrapper:
if str_line.choices[0].finish_reason:
is_finished = True
finish_reason = str_line.choices[0].finish_reason
if finish_reason == "content_filter":
if hasattr(str_line.choices[0], "content_filter_result"):
error_message = json.dumps(
str_line.choices[0].content_filter_result
)
else:
error_message = "{} Response={}".format(
self.custom_llm_provider, str(dict(str_line))
)
raise litellm.ContentPolicyViolationError(
message=error_message,
llm_provider=self.custom_llm_provider,
model=self.model,
)
# checking for logprobs
if (
@ -9248,7 +9250,10 @@ class CustomStreamWrapper:
try:
# return this for all models
completion_obj = {"content": ""}
if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
if self.custom_llm_provider and (
self.custom_llm_provider == "anthropic"
or self.custom_llm_provider in litellm._custom_providers
):
from litellm.types.utils import GenericStreamingChunk as GChunk
if self.received_finish_reason is not None:
@ -10115,6 +10120,7 @@ class CustomStreamWrapper:
try:
if self.completion_stream is None:
await self.fetch_stream()
if (
self.custom_llm_provider == "openai"
or self.custom_llm_provider == "azure"
@ -10139,6 +10145,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "triton"
or self.custom_llm_provider == "watsonx"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
or self.custom_llm_provider in litellm._custom_providers
):
async for chunk in self.completion_stream:
print_verbose(f"value of async chunk: {chunk}")
@ -10967,3 +10974,8 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
class CustomModelResponseIterator(Iterable):
def __init__(self) -> None:
super().__init__()