Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Merge branch 'main' into litellm_proxy_support_all_providers
This commit is contained in: commit 079a41fbe1
33 changed files with 1329 additions and 350 deletions
@@ -330,6 +330,18 @@ class Rules:
 
 ####### CLIENT ###################
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
+def custom_llm_setup():
+    """
+    Add custom_llm provider to provider list
+    """
+    for custom_llm in litellm.custom_provider_map:
+        if custom_llm["provider"] not in litellm.provider_list:
+            litellm.provider_list.append(custom_llm["provider"])
+
+        if custom_llm["provider"] not in litellm._custom_providers:
+            litellm._custom_providers.append(custom_llm["provider"])
+
+
 def function_setup(
     original_function: str, rules_obj, start_time, *args, **kwargs
 ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
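For context on how this hook gets used: custom_llm_setup() walks litellm.custom_provider_map, which callers populate with their own handlers before making requests. A minimal sketch of such a registration follows, assuming litellm's CustomLLM base class; the provider name "my-custom-llm" and the MyCustomLLM handler are illustrative, not part of this diff.

    # Sketch: registering a custom provider so custom_llm_setup() can pick it up.
    # "my-custom-llm" and MyCustomLLM are illustrative names, not from this commit.
    import litellm
    from litellm import CustomLLM


    class MyCustomLLM(CustomLLM):
        def completion(self, *args, **kwargs) -> litellm.ModelResponse:
            # A real handler would call its own backend; this one returns a mock.
            return litellm.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "hello"}],
                mock_response="Hi!",
            )


    # Each entry pairs a provider name with a handler instance.
    litellm.custom_provider_map = [
        {"provider": "my-custom-llm", "custom_handler": MyCustomLLM()}
    ]

Once custom_llm_setup() has run (function_setup() calls it on every request, per the next hunk), "my-custom-llm" appears in litellm.provider_list and litellm._custom_providers, so calls such as litellm.completion(model="my-custom-llm/some-model", ...) can route to the registered handler.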
@@ -341,6 +353,10 @@ def function_setup(
     try:
         global callback_list, add_breadcrumb, user_logger_fn, Logging
+
+        ## CUSTOM LLM SETUP ##
+        custom_llm_setup()
+
         ## LOGGING SETUP
         function_id = kwargs["id"] if "id" in kwargs else None
 
         if len(litellm.callbacks) > 0:
@@ -3121,7 +3137,19 @@ def get_optional_params(
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
-        if "ai21" in model:
+        if model in litellm.BEDROCK_CONVERSE_MODELS:
+            _check_valid_arg(supported_params=supported_params)
+            optional_params = litellm.AmazonConverseConfig().map_openai_params(
+                model=model,
+                non_default_params=non_default_params,
+                optional_params=optional_params,
+                drop_params=(
+                    drop_params
+                    if drop_params is not None and isinstance(drop_params, bool)
+                    else False
+                ),
+            )
+        elif "ai21" in model:
             _check_valid_arg(supported_params=supported_params)
             # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
             # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
@@ -3143,17 +3171,6 @@ def get_optional_params(
                             optional_params=optional_params,
                         )
                     )
-            elif model in litellm.BEDROCK_CONVERSE_MODELS:
-                optional_params = litellm.AmazonConverseConfig().map_openai_params(
-                    model=model,
-                    non_default_params=non_default_params,
-                    optional_params=optional_params,
-                    drop_params=(
-                        drop_params
-                        if drop_params is not None and isinstance(drop_params, bool)
-                        else False
-                    ),
-                )
             else:
                 optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
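The two hunks above reorder the Bedrock branch in get_optional_params: the BEDROCK_CONVERSE_MODELS check now runs before the substring check on "ai21", and the old Converse branch lower in the chain is dropped. The ordering appears to matter for model IDs that satisfy both conditions (for example, a Converse-capable ai21.jamba* ID also contains "ai21"). A small self-contained sketch of the dispatch order; the list contents and model IDs below are illustrative, not litellm's actual BEDROCK_CONVERSE_MODELS.

    # Illustrative only: why checking the Converse list before the "ai21"
    # substring matters. The list below is a stand-in, not litellm's real one.
    BEDROCK_CONVERSE_MODELS = [
        "ai21.jamba-instruct-v1:0",
        "anthropic.claude-3-sonnet-20240229-v1:0",
    ]


    def pick_bedrock_param_mapping(model: str) -> str:
        if model in BEDROCK_CONVERSE_MODELS:
            return "AmazonConverseConfig"  # new first branch
        elif "ai21" in model:
            return "legacy j2 params (maxTokens/topP)"  # old first branch
        return "other bedrock mapping"


    assert pick_bedrock_param_mapping("ai21.jamba-instruct-v1:0") == "AmazonConverseConfig"
    assert pick_bedrock_param_mapping("ai21.j2-ultra-v1") == "legacy j2 params (maxTokens/topP)"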
@@ -8825,21 +8842,6 @@ class CustomStreamWrapper:
                 if str_line.choices[0].finish_reason:
                     is_finished = True
                     finish_reason = str_line.choices[0].finish_reason
-                    if finish_reason == "content_filter":
-                        if hasattr(str_line.choices[0], "content_filter_result"):
-                            error_message = json.dumps(
-                                str_line.choices[0].content_filter_result
-                            )
-                        else:
-                            error_message = "{} Response={}".format(
-                                self.custom_llm_provider, str(dict(str_line))
-                            )
-
-                        raise litellm.ContentPolicyViolationError(
-                            message=error_message,
-                            llm_provider=self.custom_llm_provider,
-                            model=self.model,
-                        )
 
                 # checking for logprobs
                 if (
@@ -9248,7 +9250,10 @@ class CustomStreamWrapper:
         try:
             # return this for all models
             completion_obj = {"content": ""}
-            if self.custom_llm_provider and self.custom_llm_provider == "anthropic":
+            if self.custom_llm_provider and (
+                self.custom_llm_provider == "anthropic"
+                or self.custom_llm_provider in litellm._custom_providers
+            ):
                 from litellm.types.utils import GenericStreamingChunk as GChunk
 
                 if self.received_finish_reason is not None:
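This branch now treats providers registered in litellm._custom_providers like the anthropic path and parses chunks as GenericStreamingChunk. As a rough sketch (the exact field set is an assumption based on litellm.types.utils.GenericStreamingChunk, not copied from this diff), the dicts a custom handler's streaming generator yields look like this:

    # Sketch of a GenericStreamingChunk-shaped dict; field names are assumed
    # from litellm.types.utils.GenericStreamingChunk.
    from litellm.types.utils import GenericStreamingChunk

    chunk: GenericStreamingChunk = {
        "text": "Hello ",      # incremental text for this chunk
        "tool_use": None,      # no tool call in this chunk
        "is_finished": False,  # True only on the final chunk
        "finish_reason": "",   # e.g. "stop" on the final chunk
        "usage": None,         # optional usage block, typically on the last chunk
        "index": 0,            # choice index
    }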
@@ -10115,6 +10120,7 @@ class CustomStreamWrapper:
         try:
             if self.completion_stream is None:
                 await self.fetch_stream()
+
             if (
                 self.custom_llm_provider == "openai"
                 or self.custom_llm_provider == "azure"
@@ -10139,6 +10145,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "triton"
                 or self.custom_llm_provider == "watsonx"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
+                or self.custom_llm_provider in litellm._custom_providers
             ):
                 async for chunk in self.completion_stream:
                     print_verbose(f"value of async chunk: {chunk}")
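With litellm._custom_providers added to this condition, the async path iterates self.completion_stream directly, so a custom handler is expected to expose its stream as an async generator of GenericStreamingChunk dicts. A hedged sketch follows, assuming CustomLLM offers an astreaming() hook; none of it is from the diff, and the yielded values are illustrative.

    # Sketch, assuming CustomLLM exposes an astreaming() hook; the yielded
    # chunks are illustrative values, not from this commit.
    from typing import AsyncIterator

    from litellm import CustomLLM
    from litellm.types.utils import GenericStreamingChunk


    class MyCustomLLM(CustomLLM):
        async def astreaming(
            self, *args, **kwargs
        ) -> AsyncIterator[GenericStreamingChunk]:
            # One content chunk, then a final chunk carrying the finish reason.
            yield {
                "text": "Hello world",
                "tool_use": None,
                "is_finished": False,
                "finish_reason": "",
                "usage": None,
                "index": 0,
            }
            yield {
                "text": "",
                "tool_use": None,
                "is_finished": True,
                "finish_reason": "stop",
                "usage": None,
                "index": 0,
            }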
@@ -10967,3 +10974,8 @@ class ModelResponseIterator:
             raise StopAsyncIteration
         self.is_done = True
         return self.model_response
+
+
+class CustomModelResponseIterator(Iterable):
+    def __init__(self) -> None:
+        super().__init__()
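The committed CustomModelResponseIterator is only a stub. One plausible way it gets fleshed out, mirroring the ModelResponseIterator above it, is to wrap a provider's chunk generator so CustomStreamWrapper can consume it with either a plain for loop or an async for loop; the body below is a sketch under that assumption, not the library's implementation.

    # Sketch only: wrapping a generator of GenericStreamingChunk dicts so it
    # can be iterated both synchronously and asynchronously.
    from typing import Iterator

    from litellm.types.utils import GenericStreamingChunk


    class CustomModelResponseIteratorSketch:
        def __init__(self, streaming_response: Iterator[GenericStreamingChunk]) -> None:
            self.streaming_response = streaming_response

        # Sync path: hand chunks straight through.
        def __iter__(self) -> Iterator[GenericStreamingChunk]:
            return iter(self.streaming_response)

        # Async path: expose the same chunks to async-for consumers.
        def __aiter__(self) -> "CustomModelResponseIteratorSketch":
            self._iter = iter(self.streaming_response)
            return self

        async def __anext__(self) -> GenericStreamingChunk:
            try:
                return next(self._iter)
            except StopIteration:
                raise StopAsyncIteration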