Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
LiteLLM Minor Fixes & Improvements (10/07/2024) (#6101)
* fix(utils.py): support dropping temperature param for azure o1 models
* fix(main.py): handle azure o1 streaming requests. o1 doesn't support streaming, fake it to ensure code works as expected
* feat(utils.py): expose `hosted_vllm/` endpoint, with tool handling for vllm. Fixes https://github.com/BerriAI/litellm/issues/6088
* refactor(internal_user_endpoints.py): cleanup unused params + update docstring. Closes https://github.com/BerriAI/litellm/issues/6100
* fix(main.py): expose custom image generation api support. Fixes https://github.com/BerriAI/litellm/issues/6097
* fix: fix linting errors
* docs(custom_llm_server.md): add docs on custom api for image gen calls
* fix(types/utils.py): handle dict type
* fix(types/utils.py): fix linting errors
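The `hosted_vllm/` prefix mentioned above routes chat requests to a self-hosted vLLM server's OpenAI-compatible endpoint. A minimal usage sketch follows; the server URL, model name, and tool definition are illustrative placeholders, not taken from this commit:

    import litellm

    # Route a chat completion to a self-hosted vLLM server (OpenAI-compatible API).
    # The api_base and model name are placeholders for your own deployment.
    response = litellm.completion(
        model="hosted_vllm/facebook/opt-125m",
        api_base="http://localhost:8000/v1",
        messages=[{"role": "user", "content": "What's the weather in Boston?"}],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
    )
    print(response.choices[0].message.tool_calls)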
This commit is contained in: parent 5de69cb1b2, commit 6729c9ca7f
17 changed files with 643 additions and 76 deletions
litellm/main.py (138 lines changed)
@@ -42,6 +42,7 @@ from litellm import (  # type: ignore
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -89,6 +90,7 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.azure_text import AzureTextCompletion
 from .llms.AzureOpenAI.audio_transcriptions import AzureAudioTranscription
 from .llms.AzureOpenAI.azure import AzureChatCompletion, _check_dynamic_azure_params
+from .llms.AzureOpenAI.chat.o1_handler import AzureOpenAIO1ChatCompletion
 from .llms.bedrock import image_generation as bedrock_image_generation  # type: ignore
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
@@ -178,6 +180,7 @@ azure_ai_embedding = AzureAIEmbedding()
 anthropic_chat_completions = AnthropicChatCompletion()
 anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
+azure_o1_chat_completions = AzureOpenAIO1ChatCompletion()
 azure_text_completions = AzureTextCompletion()
 azure_audio_transcriptions = AzureAudioTranscription()
 huggingface = Huggingface()
@@ -1064,35 +1067,68 @@ def completion(  # type: ignore
             headers = headers or litellm.headers
 
-            ## LOAD CONFIG - if set
-            config = litellm.AzureOpenAIConfig.get_config()
-            for k, v in config.items():
-                if (
-                    k not in optional_params
-                ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
-                    optional_params[k] = v
-
-            ## COMPLETION CALL
-            response = azure_chat_completions.completion(
-                model=model,
-                messages=messages,
-                headers=headers,
-                api_key=api_key,
-                api_base=api_base,
-                api_version=api_version,
-                api_type=api_type,
-                dynamic_params=dynamic_params,
-                azure_ad_token=azure_ad_token,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                logging_obj=logging,
-                acompletion=acompletion,
-                timeout=timeout,  # type: ignore
-                client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
-            )
+            if (
+                litellm.enable_preview_features
+                and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
+            ):
+                ## LOAD CONFIG - if set
+                config = litellm.AzureOpenAIO1Config.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in optional_params
+                    ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+                        optional_params[k] = v
+
+                response = azure_o1_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    api_key=api_key,
+                    api_base=api_base,
+                    api_version=api_version,
+                    api_type=api_type,
+                    dynamic_params=dynamic_params,
+                    azure_ad_token=azure_ad_token,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    timeout=timeout,  # type: ignore
+                    client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+                )
+            else:
+                ## LOAD CONFIG - if set
+                config = litellm.AzureOpenAIConfig.get_config()
+                for k, v in config.items():
+                    if (
+                        k not in optional_params
+                    ):  # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
+                        optional_params[k] = v
+
+                ## COMPLETION CALL
+                response = azure_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    api_key=api_key,
+                    api_base=api_base,
+                    api_version=api_version,
+                    api_type=api_type,
+                    dynamic_params=dynamic_params,
+                    azure_ad_token=azure_ad_token,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    timeout=timeout,  # type: ignore
+                    client=client,  # pass AsyncAzureOpenAI, AzureOpenAI client
+                )
 
             if optional_params.get("stream", False):
                 ## LOGGING
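Exercising the new Azure o1 branch above requires opting in to preview features. A minimal sketch follows; the deployment name, endpoint, API version, and key are placeholders, and per the commit message the streamed response is faked from a single non-streaming o1 call:

    import litellm

    # Opt in to the preview-only AzureOpenAIO1 routing shown in the diff above.
    litellm.enable_preview_features = True

    response = litellm.completion(
        model="azure/o1-preview",  # placeholder deployment name containing "o1"
        messages=[{"role": "user", "content": "Outline a 3-step experiment."}],
        api_base="https://my-endpoint.openai.azure.com",
        api_version="2024-09-01-preview",
        api_key="my-azure-api-key",
        stream=True,  # o1 has no native streaming; litellm fakes the stream
    )

    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")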
@@ -4582,6 +4618,7 @@ def image_generation(
     Currently supports just Azure + OpenAI.
     """
     try:
+        args = locals()
         aimg_generation = kwargs.get("aimg_generation", False)
         litellm_call_id = kwargs.get("litellm_call_id", None)
         logger_fn = kwargs.get("logger_fn", None)
@@ -4787,6 +4824,51 @@ def image_generation(
                 vertex_credentials=vertex_credentials,
                 aimg_generation=aimg_generation,
             )
+        elif (
+            custom_llm_provider in litellm._custom_providers
+        ):  # Assume custom LLM provider
+            # Get the Custom Handler
+            custom_handler: Optional[CustomLLM] = None
+            for item in litellm.custom_provider_map:
+                if item["provider"] == custom_llm_provider:
+                    custom_handler = item["custom_handler"]
+
+            if custom_handler is None:
+                raise ValueError(
+                    f"Unable to map your input to a model. Check your input - {args}"
+                )
+
+            ## ROUTE LLM CALL ##
+            if aimg_generation is True:
+                async_custom_client: Optional[AsyncHTTPHandler] = None
+                if client is not None and isinstance(client, AsyncHTTPHandler):
+                    async_custom_client = client
+
+                ## CALL FUNCTION
+                model_response = custom_handler.aimage_generation(  # type: ignore
+                    model=model,
+                    prompt=prompt,
+                    model_response=model_response,
+                    optional_params=optional_params,
+                    logging_obj=litellm_logging_obj,
+                    timeout=timeout,
+                    client=async_custom_client,
+                )
+            else:
+                custom_client: Optional[HTTPHandler] = None
+                if client is not None and isinstance(client, HTTPHandler):
+                    custom_client = client
+
+                ## CALL FUNCTION
+                model_response = custom_handler.image_generation(
+                    model=model,
+                    prompt=prompt,
+                    model_response=model_response,
+                    optional_params=optional_params,
+                    logging_obj=litellm_logging_obj,
+                    timeout=timeout,
+                    client=custom_client,
+                )
 
         return model_response
     except Exception as e:
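The custom-provider branch added above calls a user-registered handler with the keyword arguments shown in the diff. A minimal sketch of registering such a handler, following the custom_llm_server docs referenced in the commit message; the provider name, model string, and returned URL are made up for illustration:

    import time
    from typing import Any, Optional

    import litellm
    from litellm import CustomLLM
    from litellm.types.utils import ImageObject, ImageResponse


    class MyImageLLM(CustomLLM):
        # main.py (diff above) invokes this with model, prompt, model_response,
        # optional_params, logging_obj, timeout, and client as keyword arguments.
        def image_generation(
            self,
            model: str,
            prompt: str,
            model_response: ImageResponse,
            optional_params: dict,
            logging_obj: Any,
            timeout: Optional[float] = None,
            client: Optional[Any] = None,
            **kwargs,
        ) -> ImageResponse:
            # A real handler would call an upstream image API here.
            return ImageResponse(
                created=int(time.time()),
                data=[ImageObject(url="https://example.com/generated.png")],
            )


    # Register the handler; "my-image-llm" is a made-up provider name.
    litellm.custom_provider_map = [
        {"provider": "my-image-llm", "custom_handler": MyImageLLM()}
    ]

    resp = litellm.image_generation(
        model="my-image-llm/anything",
        prompt="a sunset over the mountains",
    )
    print(resp.data[0].url)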