forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (11/04/2024) (#6572)
* feat: initial commit for watsonx chat endpoint support Closes https://github.com/BerriAI/litellm/issues/6562 * feat(watsonx/chat/handler.py): support tool calling for watsonx Closes https://github.com/BerriAI/litellm/issues/6562 * fix(streaming_utils.py): return empty chunk instead of failing if streaming value is invalid dict ensures streaming works for ibm watsonx * fix(openai_like/chat/handler.py): ensure asynchttphandler is passed correctly for openai like calls * fix: ensure exception mapping works well for watsonx calls * fix(openai_like/chat/handler.py): handle async streaming correctly * feat(main.py): Make it clear when a user is passing an invalid message add validation for user content message Closes https://github.com/BerriAI/litellm/issues/6565 * fix: cleanup * fix(utils.py): loosen validation check, to just make sure content types are valid make litellm robust to future content updates * fix: fix linting erro * fix: fix linting errors * fix(utils.py): make validation check more flexible * test: handle langfuse list index out of range error * Litellm dev 11 02 2024 (#6561) * fix(dual_cache.py): update in-memory check for redis batch get cache Fixes latency delay for async_batch_redis_cache * fix(service_logger.py): fix race condition causing otel service logging to be overwritten if service_callbacks set * feat(user_api_key_auth.py): add parent otel component for auth allows us to isolate how much latency is added by auth checks * perf(parallel_request_limiter.py): move async_set_cache_pipeline (from max parallel request limiter) out of execution path (background task) reduces latency by 200ms * feat(user_api_key_auth.py): have user api key auth object return user tpm/rpm limits - reduces redis calls in downstream task (parallel_request_limiter) Reduces latency by 400-800ms * fix(parallel_request_limiter.py): use batch get cache to reduce user/key/team usage object calls reduces latency by 50-100ms * fix: fix linting error * fix(_service_logger.py): fix import * fix(user_api_key_auth.py): fix service logging * fix(dual_cache.py): don't pass 'self' * fix: fix python3.8 error * fix: fix init] * bump: version 1.51.4 → 1.51.5 * build(deps): bump cookie and express in /docs/my-website (#6566) Bumps [cookie](https://github.com/jshttp/cookie) and [express](https://github.com/expressjs/express). These dependencies needed to be updated together. Updates `cookie` from 0.6.0 to 0.7.1 - [Release notes](https://github.com/jshttp/cookie/releases) - [Commits](https://github.com/jshttp/cookie/compare/v0.6.0...v0.7.1) Updates `express` from 4.20.0 to 4.21.1 - [Release notes](https://github.com/expressjs/express/releases) - [Changelog](https://github.com/expressjs/express/blob/4.21.1/History.md) - [Commits](https://github.com/expressjs/express/compare/4.20.0...4.21.1) --- updated-dependencies: - dependency-name: cookie dependency-type: indirect - dependency-name: express dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * docs(virtual_keys.md): update Dockerfile reference (#6554) Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> * (proxy fix) - call connect on prisma client when running setup (#6534) * critical fix - call connect on prisma client when running setup * fix test_proxy_server_prisma_setup * fix test_proxy_server_prisma_setup * Add 3.5 haiku (#6588) * feat: add claude-3-5-haiku-20241022 entries * feat: add claude-3-5-haiku-20241022 and vertex_ai/claude-3-5-haiku@20241022 models * add missing entries, remove vision * remove image token costs * Litellm perf improvements 3 (#6573) * perf: move writing key to cache, to background task * perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils adds 200ms on calls with pgdb connected * fix(litellm_pre_call_utils.py'): rename call_type to actual call used * perf(proxy_server.py): remove db logic from _get_config_from_file was causing db calls to occur on every llm request, if team_id was set on key * fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db reduces latency/call by ~100ms * fix(proxy_server.py): minor fix on existing_settings not incl alerting * fix(exception_mapping_utils.py): map databricks exception string * fix(auth_checks.py): fix auth check logic * test: correctly mark flaky test * fix(utils.py): handle auth token error for tokenizers.from_pretrained * build: fix map * build: fix map * build: fix json for model map * Litellm dev 11 02 2024 (#6561) * fix(dual_cache.py): update in-memory check for redis batch get cache Fixes latency delay for async_batch_redis_cache * fix(service_logger.py): fix race condition causing otel service logging to be overwritten if service_callbacks set * feat(user_api_key_auth.py): add parent otel component for auth allows us to isolate how much latency is added by auth checks * perf(parallel_request_limiter.py): move async_set_cache_pipeline (from max parallel request limiter) out of execution path (background task) reduces latency by 200ms * feat(user_api_key_auth.py): have user api key auth object return user tpm/rpm limits - reduces redis calls in downstream task (parallel_request_limiter) Reduces latency by 400-800ms * fix(parallel_request_limiter.py): use batch get cache to reduce user/key/team usage object calls reduces latency by 50-100ms * fix: fix linting error * fix(_service_logger.py): fix import * fix(user_api_key_auth.py): fix service logging * fix(dual_cache.py): don't pass 'self' * fix: fix python3.8 error * fix: fix init] * Litellm perf improvements 3 (#6573) * perf: move writing key to cache, to background task * perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils adds 200ms on calls with pgdb connected * fix(litellm_pre_call_utils.py'): rename call_type to actual call used * perf(proxy_server.py): remove db logic from _get_config_from_file was causing db calls to occur on every llm request, if team_id was set on key * fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db reduces latency/call by ~100ms * fix(proxy_server.py): minor fix on existing_settings not incl alerting * fix(exception_mapping_utils.py): map databricks exception string * fix(auth_checks.py): fix auth check logic * test: correctly mark flaky test * fix(utils.py): handle auth token error for tokenizers.from_pretrained * fix ImageObject conversion (#6584) * (fix) litellm.text_completion raises a non-blocking error on simple usage (#6546) * unit test test_huggingface_text_completion_logprobs * fix return TextCompletionHandler convert_chat_to_text_completion * fix hf rest api * fix test_huggingface_text_completion_logprobs * fix linting errors * fix importLiteLLMResponseObjectHandler * fix test for LiteLLMResponseObjectHandler * fix test text completion * fix allow using 15 seconds for premium license check * testing fix bedrock deprecated cohere.command-text-v14 * (feat) add `Predicted Outputs` for OpenAI (#6594) * bump openai to openai==1.54.0 * add 'prediction' param * testing fix bedrock deprecated cohere.command-text-v14 * test test_openai_prediction_param.py * test_openai_prediction_param_with_caching * doc Predicted Outputs * doc Predicted Output * (fix) Vertex Improve Performance when using `image_url` (#6593) * fix transformation vertex * test test_process_gemini_image * test_image_completion_request * testing fix - bedrock has deprecated cohere.command-text-v14 * fix vertex pdf * bump: version 1.51.5 → 1.52.0 * fix(lowest_tpm_rpm_routing.py): fix parallel rate limit check (#6577) * fix(lowest_tpm_rpm_routing.py): fix parallel rate limit check * fix(lowest_tpm_rpm_v2.py): return headers in correct format * test: update test * build(deps): bump cookie and express in /docs/my-website (#6566) Bumps [cookie](https://github.com/jshttp/cookie) and [express](https://github.com/expressjs/express). These dependencies needed to be updated together. Updates `cookie` from 0.6.0 to 0.7.1 - [Release notes](https://github.com/jshttp/cookie/releases) - [Commits](https://github.com/jshttp/cookie/compare/v0.6.0...v0.7.1) Updates `express` from 4.20.0 to 4.21.1 - [Release notes](https://github.com/expressjs/express/releases) - [Changelog](https://github.com/expressjs/express/blob/4.21.1/History.md) - [Commits](https://github.com/expressjs/express/compare/4.20.0...4.21.1) --- updated-dependencies: - dependency-name: cookie dependency-type: indirect - dependency-name: express dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * docs(virtual_keys.md): update Dockerfile reference (#6554) Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> * (proxy fix) - call connect on prisma client when running setup (#6534) * critical fix - call connect on prisma client when running setup * fix test_proxy_server_prisma_setup * fix test_proxy_server_prisma_setup * Add 3.5 haiku (#6588) * feat: add claude-3-5-haiku-20241022 entries * feat: add claude-3-5-haiku-20241022 and vertex_ai/claude-3-5-haiku@20241022 models * add missing entries, remove vision * remove image token costs * Litellm perf improvements 3 (#6573) * perf: move writing key to cache, to background task * perf(litellm_pre_call_utils.py): add otel tracing for pre-call utils adds 200ms on calls with pgdb connected * fix(litellm_pre_call_utils.py'): rename call_type to actual call used * perf(proxy_server.py): remove db logic from _get_config_from_file was causing db calls to occur on every llm request, if team_id was set on key * fix(auth_checks.py): add check for reducing db calls if user/team id does not exist in db reduces latency/call by ~100ms * fix(proxy_server.py): minor fix on existing_settings not incl alerting * fix(exception_mapping_utils.py): map databricks exception string * fix(auth_checks.py): fix auth check logic * test: correctly mark flaky test * fix(utils.py): handle auth token error for tokenizers.from_pretrained * build: fix map * build: fix map * build: fix json for model map * test: remove eol model * fix(proxy_server.py): fix db config loading logic * fix(proxy_server.py): fix order of config / db updates, to ensure fields not overwritten * test: skip test if required env var is missing * test: fix test --------- Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: paul-gauthier <69695708+paul-gauthier@users.noreply.github.com> * test: mark flaky test * test: handle anthropic api instability * test: update test * test: bump num retries on langfuse tests - their api is quite bad --------- Signed-off-by: dependabot[bot] <support@github.com> Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: paul-gauthier <69695708+paul-gauthier@users.noreply.github.com>
This commit is contained in:
parent
0fe8cde7c7
commit
5c55270740
24 changed files with 1510 additions and 554 deletions
|
@ -137,6 +137,8 @@ safe_memory_mode: bool = False
|
|||
enable_azure_ad_token_refresh: Optional[bool] = False
|
||||
### DEFAULT AZURE API VERSION ###
|
||||
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
|
||||
### DEFAULT WATSONX API VERSION ###
|
||||
WATSONX_DEFAULT_API_VERSION = "2024-03-13"
|
||||
### COHERE EMBEDDINGS DEFAULT TYPE ###
|
||||
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
|
||||
### GUARDRAILS ###
|
||||
|
@ -282,7 +284,9 @@ priority_reservation: Optional[Dict[str, float]] = None
|
|||
#### RELIABILITY ####
|
||||
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
||||
request_timeout: float = 6000 # time in seconds
|
||||
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
|
||||
module_level_aclient = AsyncHTTPHandler(
|
||||
timeout=request_timeout, client_alias="module level aclient"
|
||||
)
|
||||
module_level_client = HTTPHandler(timeout=request_timeout)
|
||||
num_retries: Optional[int] = None # per model endpoint
|
||||
max_fallbacks: Optional[int] = None
|
||||
|
@ -527,7 +531,11 @@ openai_text_completion_compatible_providers: List = (
|
|||
"hosted_vllm",
|
||||
]
|
||||
)
|
||||
|
||||
_openai_like_providers: List = [
|
||||
"predibase",
|
||||
"databricks",
|
||||
"watsonx",
|
||||
] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
|
||||
# well supported replicate llms
|
||||
replicate_models: List = [
|
||||
# llama replicate supported LLMs
|
||||
|
@ -1040,7 +1048,8 @@ from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
|
|||
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
|
||||
from .llms.perplexity.chat.transformation import PerplexityChatConfig
|
||||
from .llms.AzureOpenAI.chat.o1_transformation import AzureOpenAIO1Config
|
||||
from .llms.watsonx import IBMWatsonXAIConfig
|
||||
from .llms.watsonx.completion.handler import IBMWatsonXAIConfig
|
||||
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
|
||||
from .main import * # type: ignore
|
||||
from .integrations import *
|
||||
from .exceptions import (
|
||||
|
|
|
@ -612,19 +612,7 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
url="https://api.replicate.com/v1/deployments",
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
if "token_quota_reached" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise RateLimitError(
|
||||
message=f"WatsonxException: Rate Limit Errror - {error_str}",
|
||||
llm_provider="watsonx",
|
||||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif (
|
||||
custom_llm_provider == "predibase"
|
||||
or custom_llm_provider == "databricks"
|
||||
):
|
||||
elif custom_llm_provider in litellm._openai_like_providers:
|
||||
if "authorization denied for" in error_str:
|
||||
exception_mapping_worked = True
|
||||
|
||||
|
@ -646,6 +634,14 @@ def exception_type( # type: ignore # noqa: PLR0915
|
|||
response=original_exception.response,
|
||||
litellm_debug_info=extra_information,
|
||||
)
|
||||
elif "token_quota_reached" in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise RateLimitError(
|
||||
message=f"{custom_llm_provider}Exception: Rate Limit Errror - {error_str}",
|
||||
llm_provider=custom_llm_provider,
|
||||
model=model,
|
||||
response=original_exception.response,
|
||||
)
|
||||
elif (
|
||||
"The server received an invalid response from an upstream server."
|
||||
in error_str
|
||||
|
|
288
litellm/litellm_core_utils/get_supported_openai_params.py
Normal file
288
litellm/litellm_core_utils/get_supported_openai_params.py
Normal file
|
@ -0,0 +1,288 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
|
||||
|
||||
def get_supported_openai_params( # noqa: PLR0915
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
request_type: Literal["chat_completion", "embeddings"] = "chat_completion",
|
||||
) -> Optional[list]:
|
||||
"""
|
||||
Returns the supported openai params for a given model + provider
|
||||
|
||||
Example:
|
||||
```
|
||||
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
|
||||
```
|
||||
|
||||
Returns:
|
||||
- List if custom_llm_provider is mapped
|
||||
- None if unmapped
|
||||
"""
|
||||
if not custom_llm_provider:
|
||||
try:
|
||||
custom_llm_provider = litellm.get_llm_provider(model=model)[1]
|
||||
except BadRequestError:
|
||||
return None
|
||||
if custom_llm_provider == "bedrock":
|
||||
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ollama":
|
||||
return litellm.OllamaConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "ollama_chat":
|
||||
return litellm.OllamaChatConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "anthropic":
|
||||
return litellm.AnthropicConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
if request_type == "embeddings":
|
||||
return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.FireworksAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "nvidia_nim":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
|
||||
elif request_type == "embeddings":
|
||||
return litellm.nvidiaNimEmbeddingConfig.get_supported_openai_params()
|
||||
elif custom_llm_provider == "cerebras":
|
||||
return litellm.CerebrasConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "xai":
|
||||
return litellm.XAIChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ai21_chat":
|
||||
return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "volcengine":
|
||||
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "groq":
|
||||
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "deepseek":
|
||||
return [
|
||||
# https://platform.deepseek.com/api-docs/api/create-chat-completion
|
||||
"frequency_penalty",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
"stop",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
elif custom_llm_provider == "cohere":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"extra_headers",
|
||||
]
|
||||
elif custom_llm_provider == "cohere_chat":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
elif custom_llm_provider == "maritalk":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "openai":
|
||||
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "azure":
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
return litellm.AzureOpenAIO1Config().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.AzureOpenAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "openrouter":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"repetition_penalty",
|
||||
"seed",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"stop",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
||||
# mistal and codestral api have the exact same params
|
||||
if request_type == "chat_completion":
|
||||
return litellm.MistralConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "text-completion-codestral":
|
||||
return litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "replicate":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"functions",
|
||||
"function_call",
|
||||
]
|
||||
elif custom_llm_provider == "huggingface":
|
||||
return litellm.HuggingfaceConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "together_ai":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
]
|
||||
elif custom_llm_provider == "ai21":
|
||||
return [
|
||||
"stream",
|
||||
"n",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "databricks":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.DatabricksConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
|
||||
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
if request_type == "chat_completion":
|
||||
if model.startswith("meta/"):
|
||||
return litellm.VertexAILlama3Config().get_supported_openai_params()
|
||||
if model.startswith("mistral"):
|
||||
return litellm.MistralConfig().get_supported_openai_params()
|
||||
if model.startswith("codestral"):
|
||||
return (
|
||||
litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
)
|
||||
if model.startswith("claude"):
|
||||
return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
|
||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai_beta":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.VertexGeminiConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||
elif custom_llm_provider == "aleph_alpha":
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"top_p",
|
||||
"temperature",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "cloudflare":
|
||||
return ["max_tokens", "stream"]
|
||||
elif custom_llm_provider == "nlp_cloud":
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "petals":
|
||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
||||
elif custom_llm_provider == "deepinfra":
|
||||
return litellm.DeepInfraConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "perplexity":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "anyscale":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "watsonx":
|
||||
return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "custom_openai" or "text-completion-openai":
|
||||
return [
|
||||
"functions",
|
||||
"function_call",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"max_retries",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"extra_headers",
|
||||
]
|
||||
return None
|
|
@ -34,12 +34,14 @@ class AsyncHTTPHandler:
|
|||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
|
||||
concurrent_limit=1000,
|
||||
client_alias: Optional[str] = None, # name for client in logs
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.event_hooks = event_hooks
|
||||
self.client = self.create_client(
|
||||
timeout=timeout, concurrent_limit=concurrent_limit, event_hooks=event_hooks
|
||||
)
|
||||
self.client_alias = client_alias
|
||||
|
||||
def create_client(
|
||||
self,
|
||||
|
@ -112,6 +114,7 @@ class AsyncHTTPHandler:
|
|||
try:
|
||||
if timeout is None:
|
||||
timeout = self.timeout
|
||||
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@ import json
|
|||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionDeltaChunk,
|
||||
ChatCompletionResponseMessage,
|
||||
|
@ -109,7 +110,17 @@ class ModelResponseIterator:
|
|||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
verbose_logger.debug(
|
||||
f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
|
||||
)
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
|
@ -123,6 +134,8 @@ class ModelResponseIterator:
|
|||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
chunk = chunk.replace("data:", "")
|
||||
|
@ -144,4 +157,14 @@ class ModelResponseIterator:
|
|||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
verbose_logger.debug(
|
||||
f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
|
||||
)
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
|
372
litellm/llms/openai_like/chat/handler.py
Normal file
372
litellm/llms/openai_like/chat/handler.py
Normal file
|
@ -0,0 +1,372 @@
|
|||
"""
|
||||
OpenAI-like chat completion handler
|
||||
|
||||
For handling OpenAI-like chat completions, like IBM WatsonX, etc.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.databricks.streaming_utils import ModelResponseIterator
|
||||
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper, EmbeddingResponse
|
||||
|
||||
from ..common_utils import OpenAILikeBase, OpenAILikeError
|
||||
|
||||
|
||||
async def make_call(
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
):
|
||||
if client is None:
|
||||
client = litellm.module_level_aclient
|
||||
|
||||
response = await client.post(api_base, headers=headers, data=data, stream=True)
|
||||
|
||||
if streaming_decoder is not None:
|
||||
completion_stream: Any = streaming_decoder.aiter_bytes(
|
||||
response.aiter_bytes(chunk_size=1024)
|
||||
)
|
||||
else:
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.aiter_lines(), sync_stream=False
|
||||
)
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=completion_stream, # Pass the completion stream for logging
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
):
|
||||
if client is None:
|
||||
client = litellm.module_level_client # Create a new client if none provided
|
||||
|
||||
response = client.post(api_base, headers=headers, data=data, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OpenAILikeError(status_code=response.status_code, message=response.read())
|
||||
|
||||
if streaming_decoder is not None:
|
||||
completion_stream = streaming_decoder.iter_bytes(
|
||||
response.iter_bytes(chunk_size=1024)
|
||||
)
|
||||
else:
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.iter_lines(), sync_stream=True
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
class OpenAILikeChatHandler(OpenAILikeBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def acompletion_stream_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_llm_provider: str,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
data: dict,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
|
||||
data["stream"] = True
|
||||
completion_stream = await make_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
return streamwrapper
|
||||
|
||||
async def acompletion_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
custom_llm_provider: str,
|
||||
print_verbose: Callable,
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
data: dict,
|
||||
base_model: Optional[str],
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> ModelResponse:
|
||||
if timeout is None:
|
||||
timeout = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
if client is None:
|
||||
client = litellm.module_level_aclient
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
api_base, headers=headers, data=json.dumps(data), timeout=timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise OpenAILikeError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise OpenAILikeError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise OpenAILikeError(status_code=500, message=str(e))
|
||||
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=response_json,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
response = ModelResponse(**response_json)
|
||||
|
||||
response.model = custom_llm_provider + "/" + (response.model or "")
|
||||
|
||||
if base_model is not None:
|
||||
response._hidden_params["model"] = base_model
|
||||
return response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: Optional[str],
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
streaming_decoder: Optional[
|
||||
CustomStreamingDecoder
|
||||
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
|
||||
):
|
||||
custom_endpoint = custom_endpoint or optional_params.pop(
|
||||
"custom_endpoint", None
|
||||
)
|
||||
base_model: Optional[str] = optional_params.pop("base_model", None)
|
||||
api_base, headers = self._validate_environment(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
endpoint_type="chat_completions",
|
||||
custom_endpoint=custom_endpoint,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
stream: bool = optional_params.get("stream", None) or False
|
||||
optional_params["stream"] = stream
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
if acompletion is True:
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
client = None
|
||||
if (
|
||||
stream is True
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
data["stream"] = stream
|
||||
return self.acompletion_stream_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
client=client,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
||||
else:
|
||||
return self.acompletion_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
base_model=base_model,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
## COMPLETION CALL
|
||||
if stream is True:
|
||||
completion_stream = make_sync_call(
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, HTTPHandler)
|
||||
else None
|
||||
),
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
||||
# completion_stream.__iter__()
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
else:
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(timeout=timeout) # type: ignore
|
||||
try:
|
||||
response = client.post(
|
||||
api_base, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise OpenAILikeError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
raise OpenAILikeError(
|
||||
status_code=408, message="Timeout error occurred."
|
||||
)
|
||||
except Exception as e:
|
||||
raise OpenAILikeError(status_code=500, message=str(e))
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=response_json,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
response = ModelResponse(**response_json)
|
||||
|
||||
response.model = custom_llm_provider + "/" + (response.model or "")
|
||||
|
||||
if base_model is not None:
|
||||
response._hidden_params["model"] = base_model
|
||||
|
||||
return response
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Literal, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
|
@ -10,3 +12,43 @@ class OpenAILikeError(Exception):
|
|||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class OpenAILikeBase:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
headers: Optional[dict],
|
||||
custom_endpoint: Optional[bool],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if not custom_endpoint:
|
||||
if endpoint_type == "chat_completions":
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings":
|
||||
api_base = "{}/embeddings".format(api_base)
|
||||
return api_base, headers
|
||||
|
|
|
@ -23,46 +23,13 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
)
|
||||
from litellm.utils import EmbeddingResponse
|
||||
|
||||
from ..common_utils import OpenAILikeError
|
||||
from ..common_utils import OpenAILikeBase, OpenAILikeError
|
||||
|
||||
|
||||
class OpenAILikeEmbeddingHandler:
|
||||
class OpenAILikeEmbeddingHandler(OpenAILikeBase):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _validate_environment(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
endpoint_type: Literal["chat_completions", "embeddings"],
|
||||
headers: Optional[dict],
|
||||
) -> Tuple[str, dict]:
|
||||
if api_key is None and headers is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise OpenAILikeError(
|
||||
status_code=400,
|
||||
message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
|
||||
)
|
||||
|
||||
if headers is None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if api_key is not None:
|
||||
headers.update({"Authorization": "Bearer {}".format(api_key)})
|
||||
|
||||
if endpoint_type == "chat_completions":
|
||||
api_base = "{}/chat/completions".format(api_base)
|
||||
elif endpoint_type == "embeddings":
|
||||
api_base = "{}/embeddings".format(api_base)
|
||||
return api_base, headers
|
||||
|
||||
async def aembedding(
|
||||
self,
|
||||
input: list,
|
||||
|
@ -133,6 +100,7 @@ class OpenAILikeEmbeddingHandler:
|
|||
model_response: Optional[litellm.utils.EmbeddingResponse] = None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> EmbeddingResponse:
|
||||
api_base, headers = self._validate_environment(
|
||||
|
@ -140,6 +108,7 @@ class OpenAILikeEmbeddingHandler:
|
|||
api_key=api_key,
|
||||
endpoint_type="embeddings",
|
||||
headers=headers,
|
||||
custom_endpoint=custom_endpoint,
|
||||
)
|
||||
model = model
|
||||
data = {"model": model, "input": input, **optional_params}
|
||||
|
|
123
litellm/llms/watsonx/chat/handler.py
Normal file
123
litellm/llms/watsonx/chat/handler.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
from typing import Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
|
||||
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
|
||||
|
||||
from ...openai_like.chat.handler import OpenAILikeChatHandler
|
||||
from ..common_utils import WatsonXAIError, _get_api_params
|
||||
|
||||
|
||||
class WatsonXChatHandler(OpenAILikeChatHandler):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare_url(
|
||||
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
|
||||
) -> str:
|
||||
if model.startswith("deployment/"):
|
||||
if api_params.get("space_id") is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=api_params["url"],
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
deployment_id = "/".join(model.split("/")[1:])
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
||||
if stream is True
|
||||
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
|
||||
)
|
||||
endpoint = endpoint.format(deployment_id=deployment_id)
|
||||
else:
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.CHAT_STREAM.value
|
||||
if stream is True
|
||||
else WatsonXAIEndpoint.CHAT.value
|
||||
)
|
||||
base_url = httpx.URL(api_params["url"])
|
||||
base_url = base_url.join(endpoint)
|
||||
full_url = str(
|
||||
base_url.copy_add_param(key="version", value=api_params["api_version"])
|
||||
)
|
||||
|
||||
return full_url
|
||||
|
||||
def _prepare_payload(
|
||||
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
|
||||
) -> dict:
|
||||
payload: dict = {}
|
||||
if model.startswith("deployment/"):
|
||||
return payload
|
||||
payload["model_id"] = model
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
return payload
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: Optional[str],
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
streaming_decoder: Optional[
|
||||
CustomStreamingDecoder
|
||||
] = None, # if openai-compatible api needs custom stream decoder - e.g. sagemaker
|
||||
):
|
||||
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
|
||||
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_params['token']}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
stream: Optional[bool] = optional_params.get("stream", False)
|
||||
|
||||
## get api url and payload
|
||||
api_base = self._prepare_url(model=model, api_params=api_params, stream=stream)
|
||||
watsonx_auth_payload = self._prepare_payload(
|
||||
model=model, api_params=api_params, stream=stream
|
||||
)
|
||||
optional_params.update(watsonx_auth_payload)
|
||||
|
||||
return super().completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
acompletion=acompletion,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
custom_endpoint=True,
|
||||
streaming_decoder=streaming_decoder,
|
||||
)
|
82
litellm/llms/watsonx/chat/transformation.py
Normal file
82
litellm/llms/watsonx/chat/transformation.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
"""
|
||||
Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.
|
||||
|
||||
Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
|
||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||
from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class IBMWatsonXChatConfig(OpenAIGPTConfig):
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return [
|
||||
"temperature", # equivalent to temperature
|
||||
"max_tokens", # equivalent to max_new_tokens
|
||||
"top_p", # equivalent to top_p
|
||||
"frequency_penalty", # equivalent to repetition_penalty
|
||||
"stop", # equivalent to stop_sequences
|
||||
"seed", # equivalent to random_seed
|
||||
"stream", # equivalent to stream
|
||||
"tools",
|
||||
"tool_choice", # equivalent to tool_choice + tool_choice_options
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
|
||||
if tool_choice is None:
|
||||
return False
|
||||
if isinstance(tool_choice, str):
|
||||
return tool_choice in ["auto", "none", "required"]
|
||||
return False
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
## TOOLS ##
|
||||
_tools = non_default_params.pop("tools", None)
|
||||
if _tools is not None:
|
||||
# remove 'additionalProperties' from tools
|
||||
_tools = _remove_additional_properties(_tools)
|
||||
# remove 'strict' from tools
|
||||
_tools = _remove_strict_from_schema(_tools)
|
||||
if _tools is not None:
|
||||
non_default_params["tools"] = _tools
|
||||
|
||||
## TOOL CHOICE ##
|
||||
|
||||
_tool_choice = non_default_params.pop("tool_choice", None)
|
||||
if self.is_tool_choice_option(_tool_choice):
|
||||
optional_params["tool_choice_options"] = _tool_choice
|
||||
elif _tool_choice is not None:
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
return super().map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE") # type: ignore
|
||||
dynamic_api_key = (
|
||||
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
|
||||
) # vllm does not require an api key
|
||||
return api_base, dynamic_api_key
|
172
litellm/llms/watsonx/common_utils.py
Normal file
172
litellm/llms/watsonx/common_utils.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
from typing import Callable, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.watsonx import WatsonXAPIParams
|
||||
|
||||
|
||||
class WatsonXAIError(Exception):
|
||||
def __init__(self, status_code, message, url: Optional[str] = None):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
url = url or "https://https://us-south.ml.cloud.ibm.com"
|
||||
self.request = httpx.Request(method="POST", url=url)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
iam_token_cache = InMemoryCache()
|
||||
|
||||
|
||||
def generate_iam_token(api_key=None, **params) -> str:
|
||||
result: Optional[str] = iam_token_cache.get_cache(api_key) # type: ignore
|
||||
|
||||
if result is None:
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError("API key is required")
|
||||
headers["Accept"] = "application/json"
|
||||
data = {
|
||||
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||
"apikey": api_key,
|
||||
}
|
||||
verbose_logger.debug(
|
||||
"calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
|
||||
"https://iam.cloud.ibm.com/identity/token",
|
||||
headers,
|
||||
data,
|
||||
)
|
||||
response = httpx.post(
|
||||
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
json_data = response.json()
|
||||
|
||||
result = json_data["access_token"]
|
||||
iam_token_cache.set_cache(
|
||||
key=api_key,
|
||||
value=result,
|
||||
ttl=json_data["expires_in"] - 10, # leave some buffer
|
||||
)
|
||||
|
||||
return cast(str, result)
|
||||
|
||||
|
||||
def _get_api_params(
|
||||
params: dict,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
generate_token: Optional[bool] = True,
|
||||
) -> WatsonXAPIParams:
|
||||
"""
|
||||
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
|
||||
"""
|
||||
# Load auth variables from params
|
||||
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
|
||||
api_key = params.pop("apikey", None)
|
||||
token = params.pop("token", None)
|
||||
project_id = params.pop(
|
||||
"project_id", params.pop("watsonx_project", None)
|
||||
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
|
||||
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
|
||||
region_name = params.pop("region_name", params.pop("region", None))
|
||||
if region_name is None:
|
||||
region_name = params.pop(
|
||||
"watsonx_region_name", params.pop("watsonx_region", None)
|
||||
) # consistent with how vertex ai + aws regions are accepted
|
||||
wx_credentials = params.pop(
|
||||
"wx_credentials",
|
||||
params.pop(
|
||||
"watsonx_credentials", None
|
||||
), # follow {provider}_credentials, same as vertex ai
|
||||
)
|
||||
api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
|
||||
# Load auth variables from environment variables
|
||||
if url is None:
|
||||
url = (
|
||||
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
|
||||
or get_secret_str("WATSONX_URL")
|
||||
or get_secret_str("WX_URL")
|
||||
or get_secret_str("WML_URL")
|
||||
)
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
get_secret_str("WATSONX_APIKEY")
|
||||
or get_secret_str("WATSONX_API_KEY")
|
||||
or get_secret_str("WX_API_KEY")
|
||||
)
|
||||
if token is None:
|
||||
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
|
||||
if project_id is None:
|
||||
project_id = (
|
||||
get_secret_str("WATSONX_PROJECT_ID")
|
||||
or get_secret_str("WX_PROJECT_ID")
|
||||
or get_secret_str("PROJECT_ID")
|
||||
)
|
||||
if region_name is None:
|
||||
region_name = (
|
||||
get_secret_str("WATSONX_REGION")
|
||||
or get_secret_str("WX_REGION")
|
||||
or get_secret_str("REGION")
|
||||
)
|
||||
if space_id is None:
|
||||
space_id = (
|
||||
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
|
||||
or get_secret_str("WATSONX_SPACE_ID")
|
||||
or get_secret_str("WX_SPACE_ID")
|
||||
or get_secret_str("SPACE_ID")
|
||||
)
|
||||
|
||||
# credentials parsing
|
||||
if wx_credentials is not None:
|
||||
url = wx_credentials.get("url", url)
|
||||
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
|
||||
token = wx_credentials.get(
|
||||
"token",
|
||||
wx_credentials.get(
|
||||
"watsonx_token", token
|
||||
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
|
||||
)
|
||||
|
||||
# verify that all required credentials are present
|
||||
if url is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
if token is None and api_key is not None and generate_token:
|
||||
# generate the auth token
|
||||
if print_verbose is not None:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=url,
|
||||
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
|
||||
)
|
||||
if project_id is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=url,
|
||||
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
return WatsonXAPIParams(
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
token=cast(str, token),
|
||||
project_id=project_id,
|
||||
space_id=space_id,
|
||||
region_name=region_name,
|
||||
api_version=api_version,
|
||||
)
|
|
@ -26,22 +26,12 @@ import requests # type: ignore
|
|||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint
|
||||
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
|
||||
|
||||
from .base import BaseLLM
|
||||
from .prompt_templates import factory as ptf
|
||||
|
||||
|
||||
class WatsonXAIError(Exception):
|
||||
def __init__(self, status_code, message, url: Optional[str] = None):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
url = url or "https://https://us-south.ml.cloud.ibm.com"
|
||||
self.request = httpx.Request(method="POST", url=url)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
from ...base import BaseLLM
|
||||
from ...prompt_templates import factory as ptf
|
||||
from ..common_utils import WatsonXAIError, _get_api_params, generate_iam_token
|
||||
|
||||
|
||||
class IBMWatsonXAIConfig:
|
||||
|
@ -140,6 +130,29 @@ class IBMWatsonXAIConfig:
|
|||
and v is not None
|
||||
}
|
||||
|
||||
def is_watsonx_text_param(self, param: str) -> bool:
|
||||
"""
|
||||
Determine if user passed in a watsonx.ai text generation param
|
||||
"""
|
||||
text_generation_params = [
|
||||
"decoding_method",
|
||||
"max_new_tokens",
|
||||
"min_new_tokens",
|
||||
"length_penalty",
|
||||
"stop_sequences",
|
||||
"top_k",
|
||||
"repetition_penalty",
|
||||
"truncate_input_tokens",
|
||||
"include_stop_sequences",
|
||||
"return_options",
|
||||
"random_seed",
|
||||
"moderations",
|
||||
"decoding_method",
|
||||
"min_tokens",
|
||||
]
|
||||
|
||||
return param in text_generation_params
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"temperature", # equivalent to temperature
|
||||
|
@ -151,6 +164,44 @@ class IBMWatsonXAIConfig:
|
|||
"stream", # equivalent to stream
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
extra_body = {}
|
||||
for k, v in non_default_params.items():
|
||||
if k == "max_tokens":
|
||||
optional_params["max_new_tokens"] = v
|
||||
elif k == "stream":
|
||||
optional_params["stream"] = v
|
||||
elif k == "temperature":
|
||||
optional_params["temperature"] = v
|
||||
elif k == "top_p":
|
||||
optional_params["top_p"] = v
|
||||
elif k == "frequency_penalty":
|
||||
optional_params["repetition_penalty"] = v
|
||||
elif k == "seed":
|
||||
optional_params["random_seed"] = v
|
||||
elif k == "stop":
|
||||
optional_params["stop_sequences"] = v
|
||||
elif k == "decoding_method":
|
||||
extra_body["decoding_method"] = v
|
||||
elif k == "min_tokens":
|
||||
extra_body["min_new_tokens"] = v
|
||||
elif k == "top_k":
|
||||
extra_body["top_k"] = v
|
||||
elif k == "truncate_input_tokens":
|
||||
extra_body["truncate_input_tokens"] = v
|
||||
elif k == "length_penalty":
|
||||
extra_body["length_penalty"] = v
|
||||
elif k == "time_limit":
|
||||
extra_body["time_limit"] = v
|
||||
elif k == "return_options":
|
||||
extra_body["return_options"] = v
|
||||
|
||||
if extra_body:
|
||||
optional_params["extra_body"] = extra_body
|
||||
return optional_params
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
"""
|
||||
Common auth params across bedrock/vertex_ai/azure/watsonx
|
||||
|
@ -212,18 +263,6 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) ->
|
|||
return prompt
|
||||
|
||||
|
||||
class WatsonXAIEndpoint(str, Enum):
|
||||
TEXT_GENERATION = "/ml/v1/text/generation"
|
||||
TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
|
||||
DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
|
||||
DEPLOYMENT_TEXT_GENERATION_STREAM = (
|
||||
"/ml/v1/deployments/{deployment_id}/text/generation_stream"
|
||||
)
|
||||
EMBEDDINGS = "/ml/v1/text/embeddings"
|
||||
PROMPTS = "/ml/v1/prompts"
|
||||
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
|
||||
|
||||
|
||||
class IBMWatsonXAI(BaseLLM):
|
||||
"""
|
||||
Class to interface with IBM watsonx.ai API for text generation and embeddings.
|
||||
|
@ -247,10 +286,10 @@ class IBMWatsonXAI(BaseLLM):
|
|||
"""
|
||||
Get the request parameters for text generation.
|
||||
"""
|
||||
api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
|
||||
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
|
||||
# build auth headers
|
||||
api_token = api_params.get("token")
|
||||
|
||||
self.token = api_token
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
|
@ -294,118 +333,6 @@ class IBMWatsonXAI(BaseLLM):
|
|||
method="POST", url=url, headers=headers, json=payload, params=request_params
|
||||
)
|
||||
|
||||
def _get_api_params(
|
||||
self,
|
||||
params: dict,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
generate_token: Optional[bool] = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
|
||||
"""
|
||||
# Load auth variables from params
|
||||
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
|
||||
api_key = params.pop("apikey", None)
|
||||
token = params.pop("token", None)
|
||||
project_id = params.pop(
|
||||
"project_id", params.pop("watsonx_project", None)
|
||||
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
|
||||
space_id = params.pop("space_id", None) # watsonx.ai deployment space_id
|
||||
region_name = params.pop("region_name", params.pop("region", None))
|
||||
if region_name is None:
|
||||
region_name = params.pop(
|
||||
"watsonx_region_name", params.pop("watsonx_region", None)
|
||||
) # consistent with how vertex ai + aws regions are accepted
|
||||
wx_credentials = params.pop(
|
||||
"wx_credentials",
|
||||
params.pop(
|
||||
"watsonx_credentials", None
|
||||
), # follow {provider}_credentials, same as vertex ai
|
||||
)
|
||||
api_version = params.pop("api_version", IBMWatsonXAI.api_version)
|
||||
# Load auth variables from environment variables
|
||||
if url is None:
|
||||
url = (
|
||||
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
|
||||
or get_secret_str("WATSONX_URL")
|
||||
or get_secret_str("WX_URL")
|
||||
or get_secret_str("WML_URL")
|
||||
)
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
get_secret_str("WATSONX_APIKEY")
|
||||
or get_secret_str("WATSONX_API_KEY")
|
||||
or get_secret_str("WX_API_KEY")
|
||||
)
|
||||
if token is None:
|
||||
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
|
||||
if project_id is None:
|
||||
project_id = (
|
||||
get_secret_str("WATSONX_PROJECT_ID")
|
||||
or get_secret_str("WX_PROJECT_ID")
|
||||
or get_secret_str("PROJECT_ID")
|
||||
)
|
||||
if region_name is None:
|
||||
region_name = (
|
||||
get_secret_str("WATSONX_REGION")
|
||||
or get_secret_str("WX_REGION")
|
||||
or get_secret_str("REGION")
|
||||
)
|
||||
if space_id is None:
|
||||
space_id = (
|
||||
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
|
||||
or get_secret_str("WATSONX_SPACE_ID")
|
||||
or get_secret_str("WX_SPACE_ID")
|
||||
or get_secret_str("SPACE_ID")
|
||||
)
|
||||
|
||||
# credentials parsing
|
||||
if wx_credentials is not None:
|
||||
url = wx_credentials.get("url", url)
|
||||
api_key = wx_credentials.get(
|
||||
"apikey", wx_credentials.get("api_key", api_key)
|
||||
)
|
||||
token = wx_credentials.get(
|
||||
"token",
|
||||
wx_credentials.get(
|
||||
"watsonx_token", token
|
||||
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
|
||||
)
|
||||
|
||||
# verify that all required credentials are present
|
||||
if url is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
|
||||
)
|
||||
if token is None and api_key is not None and generate_token:
|
||||
# generate the auth token
|
||||
if print_verbose is not None:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = self.generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=url,
|
||||
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
|
||||
)
|
||||
if project_id is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
url=url,
|
||||
message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
return {
|
||||
"url": url,
|
||||
"api_key": api_key,
|
||||
"token": token,
|
||||
"project_id": project_id,
|
||||
"space_id": space_id,
|
||||
"region_name": region_name,
|
||||
"api_version": api_version,
|
||||
}
|
||||
|
||||
def _process_text_gen_response(
|
||||
self, json_resp: dict, model_response: Union[ModelResponse, None] = None
|
||||
) -> ModelResponse:
|
||||
|
@ -616,9 +543,10 @@ class IBMWatsonXAI(BaseLLM):
|
|||
input = [input]
|
||||
if api_key is not None:
|
||||
optional_params["api_key"] = api_key
|
||||
api_params = self._get_api_params(optional_params)
|
||||
api_params = _get_api_params(optional_params)
|
||||
# build auth headers
|
||||
api_token = api_params.get("token")
|
||||
self.token = api_token
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
|
@ -664,29 +592,9 @@ class IBMWatsonXAI(BaseLLM):
|
|||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
|
||||
def generate_iam_token(self, api_key=None, **params):
|
||||
headers = {}
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError("API key is required")
|
||||
headers["Accept"] = "application/json"
|
||||
data = {
|
||||
"grant_type": "urn:ibm:params:oauth:grant-type:apikey",
|
||||
"apikey": api_key,
|
||||
}
|
||||
response = httpx.post(
|
||||
"https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
json_data = response.json()
|
||||
iam_access_token = json_data["access_token"]
|
||||
self.token = iam_access_token
|
||||
return iam_access_token
|
||||
|
||||
def get_available_models(self, *, ids_only: bool = True, **params):
|
||||
api_params = self._get_api_params(params)
|
||||
api_params = _get_api_params(params)
|
||||
self.token = api_params["token"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_params['token']}",
|
||||
"Content-Type": "application/json",
|
|
@ -77,6 +77,7 @@ from litellm.utils import (
|
|||
read_config_args,
|
||||
supports_httpx_timeout,
|
||||
token_counter,
|
||||
validate_chat_completion_user_messages,
|
||||
)
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
@ -157,7 +158,8 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
|||
from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
|
||||
VertexEmbedding,
|
||||
)
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .llms.watsonx.chat.handler import WatsonXChatHandler
|
||||
from .llms.watsonx.completion.handler import IBMWatsonXAI
|
||||
from .types.llms.openai import (
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionAudioParam,
|
||||
|
@ -221,6 +223,7 @@ vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
|||
vertex_text_to_speech = VertexTextToSpeechAPI()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
sagemaker_llm = SagemakerLLM()
|
||||
watsonx_chat_completion = WatsonXChatHandler()
|
||||
openai_like_embedding = OpenAILikeEmbeddingHandler()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
@ -921,6 +924,9 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
"aws_region_name", None
|
||||
) # support region-based pricing for bedrock
|
||||
|
||||
### VALIDATE USER MESSAGES ###
|
||||
validate_chat_completion_user_messages(messages=messages)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
@ -2615,6 +2621,26 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
## RESPONSE OBJECT
|
||||
response = response
|
||||
elif custom_llm_provider == "watsonx":
|
||||
response = watsonx_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
encoding=encoding,
|
||||
custom_llm_provider="watsonx",
|
||||
)
|
||||
elif custom_llm_provider == "watsonx_text":
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
response = watsonxai.completion(
|
||||
model=model,
|
||||
|
|
|
@ -20,6 +20,9 @@ from openai.types.beta.threads.message_content import MessageContent
|
|||
from openai.types.beta.threads.run import Run
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_audio_param import ChatCompletionAudioParam
|
||||
from openai.types.chat.chat_completion_content_part_input_audio_param import (
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_modality import ChatCompletionModality
|
||||
from openai.types.chat.chat_completion_prediction_content_param import (
|
||||
ChatCompletionPredictionContentParam,
|
||||
|
@ -355,8 +358,19 @@ class ChatCompletionImageObject(TypedDict):
|
|||
image_url: Union[str, ChatCompletionImageUrlObject]
|
||||
|
||||
|
||||
class ChatCompletionAudioObject(ChatCompletionContentPartInputAudioParam):
|
||||
pass
|
||||
|
||||
|
||||
OpenAIMessageContent = Union[
|
||||
str, Iterable[Union[ChatCompletionTextObject, ChatCompletionImageObject]]
|
||||
str,
|
||||
Iterable[
|
||||
Union[
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionAudioObject,
|
||||
]
|
||||
],
|
||||
]
|
||||
|
||||
# The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
|
@ -412,6 +426,12 @@ class ChatCompletionSystemMessage(OpenAIChatCompletionSystemMessage, total=False
|
|||
cache_control: ChatCompletionCachedContent
|
||||
|
||||
|
||||
ValidUserMessageContentTypes = [
|
||||
"text",
|
||||
"image_url",
|
||||
"input_audio",
|
||||
] # used for validating user messages. Prevent users from accidentally sending anthropic messages.
|
||||
|
||||
AllMessageValues = Union[
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionAssistantMessage,
|
||||
|
|
31
litellm/types/llms/watsonx.py
Normal file
31
litellm/types/llms/watsonx.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, TypedDict, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WatsonXAPIParams(TypedDict):
|
||||
url: str
|
||||
api_key: Optional[str]
|
||||
token: str
|
||||
project_id: str
|
||||
space_id: Optional[str]
|
||||
region_name: Optional[str]
|
||||
api_version: str
|
||||
|
||||
|
||||
class WatsonXAIEndpoint(str, Enum):
|
||||
TEXT_GENERATION = "/ml/v1/text/generation"
|
||||
TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream"
|
||||
CHAT = "/ml/v1/text/chat"
|
||||
CHAT_STREAM = "/ml/v1/text/chat_stream"
|
||||
DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation"
|
||||
DEPLOYMENT_TEXT_GENERATION_STREAM = (
|
||||
"/ml/v1/deployments/{deployment_id}/text/generation_stream"
|
||||
)
|
||||
DEPLOYMENT_CHAT = "/ml/v1/deployments/{deployment_id}/text/chat"
|
||||
DEPLOYMENT_CHAT_STREAM = "/ml/v1/deployments/{deployment_id}/text/chat_stream"
|
||||
EMBEDDINGS = "/ml/v1/text/embeddings"
|
||||
PROMPTS = "/ml/v1/prompts"
|
||||
AVAILABLE_MODELS = "/ml/v1/foundation_model_specs"
|
399
litellm/utils.py
399
litellm/utils.py
|
@ -69,6 +69,9 @@ from litellm.litellm_core_utils.get_llm_provider_logic import (
|
|||
_is_non_openai_azure_model,
|
||||
get_llm_provider,
|
||||
)
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
LiteLLMResponseObjectHandler,
|
||||
|
@ -962,9 +965,10 @@ def client(original_function): # noqa: PLR0915
|
|||
result._hidden_params["additional_headers"] = process_response_headers(
|
||||
result._hidden_params.get("additional_headers") or {}
|
||||
) # GUARANTEE OPENAI HEADERS IN RESPONSE
|
||||
result._response_ms = (
|
||||
end_time - start_time
|
||||
).total_seconds() * 1000 # return response latency in ms like openai
|
||||
if result is not None:
|
||||
result._response_ms = (
|
||||
end_time - start_time
|
||||
).total_seconds() * 1000 # return response latency in ms like openai
|
||||
return result
|
||||
except Exception as e:
|
||||
call_type = original_function.__name__
|
||||
|
@ -3622,43 +3626,30 @@ def get_optional_params( # noqa: PLR0915
|
|||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
if max_tokens is not None:
|
||||
optional_params["max_new_tokens"] = max_tokens
|
||||
if stream:
|
||||
optional_params["stream"] = stream
|
||||
if temperature is not None:
|
||||
optional_params["temperature"] = temperature
|
||||
if top_p is not None:
|
||||
optional_params["top_p"] = top_p
|
||||
if frequency_penalty is not None:
|
||||
optional_params["repetition_penalty"] = frequency_penalty
|
||||
if seed is not None:
|
||||
optional_params["random_seed"] = seed
|
||||
if stop is not None:
|
||||
optional_params["stop_sequences"] = stop
|
||||
|
||||
# WatsonX-only parameters
|
||||
extra_body = {}
|
||||
if "decoding_method" in passed_params:
|
||||
extra_body["decoding_method"] = passed_params.pop("decoding_method")
|
||||
if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
|
||||
extra_body["min_new_tokens"] = passed_params.pop(
|
||||
"min_tokens", passed_params.pop("min_new_tokens")
|
||||
)
|
||||
if "top_k" in passed_params:
|
||||
extra_body["top_k"] = passed_params.pop("top_k")
|
||||
if "truncate_input_tokens" in passed_params:
|
||||
extra_body["truncate_input_tokens"] = passed_params.pop(
|
||||
"truncate_input_tokens"
|
||||
)
|
||||
if "length_penalty" in passed_params:
|
||||
extra_body["length_penalty"] = passed_params.pop("length_penalty")
|
||||
if "time_limit" in passed_params:
|
||||
extra_body["time_limit"] = passed_params.pop("time_limit")
|
||||
if "return_options" in passed_params:
|
||||
extra_body["return_options"] = passed_params.pop("return_options")
|
||||
optional_params["extra_body"] = (
|
||||
extra_body # openai client supports `extra_body` param
|
||||
optional_params = litellm.IBMWatsonXChatConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=(
|
||||
drop_params
|
||||
if drop_params is not None and isinstance(drop_params, bool)
|
||||
else False
|
||||
),
|
||||
)
|
||||
# WatsonX-text param check
|
||||
for param in passed_params.keys():
|
||||
if litellm.IBMWatsonXAIConfig().is_watsonx_text_param(param):
|
||||
raise ValueError(
|
||||
f"LiteLLM now defaults to Watsonx's `/text/chat` endpoint. Please use the `watsonx_text` provider instead, to call the `/text/generation` endpoint. Param: {param}"
|
||||
)
|
||||
elif custom_llm_provider == "watsonx_text":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = litellm.IBMWatsonXAIConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
elif custom_llm_provider == "openai":
|
||||
supported_params = get_supported_openai_params(
|
||||
|
@ -4160,290 +4151,6 @@ def get_first_chars_messages(kwargs: dict) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def get_supported_openai_params( # noqa: PLR0915
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
request_type: Literal["chat_completion", "embeddings"] = "chat_completion",
|
||||
) -> Optional[list]:
|
||||
"""
|
||||
Returns the supported openai params for a given model + provider
|
||||
|
||||
Example:
|
||||
```
|
||||
get_supported_openai_params(model="anthropic.claude-3", custom_llm_provider="bedrock")
|
||||
```
|
||||
|
||||
Returns:
|
||||
- List if custom_llm_provider is mapped
|
||||
- None if unmapped
|
||||
"""
|
||||
if not custom_llm_provider:
|
||||
try:
|
||||
custom_llm_provider = litellm.get_llm_provider(model=model)[1]
|
||||
except BadRequestError:
|
||||
return None
|
||||
if custom_llm_provider == "bedrock":
|
||||
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ollama":
|
||||
return litellm.OllamaConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "ollama_chat":
|
||||
return litellm.OllamaChatConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "anthropic":
|
||||
return litellm.AnthropicConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "fireworks_ai":
|
||||
if request_type == "embeddings":
|
||||
return litellm.FireworksAIEmbeddingConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.FireworksAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "nvidia_nim":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.nvidiaNimConfig.get_supported_openai_params(model=model)
|
||||
elif request_type == "embeddings":
|
||||
return litellm.nvidiaNimEmbeddingConfig.get_supported_openai_params()
|
||||
elif custom_llm_provider == "cerebras":
|
||||
return litellm.CerebrasConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "xai":
|
||||
return litellm.XAIChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ai21_chat":
|
||||
return litellm.AI21ChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "volcengine":
|
||||
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "groq":
|
||||
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "deepseek":
|
||||
return [
|
||||
# https://platform.deepseek.com/api-docs/api/create-chat-completion
|
||||
"frequency_penalty",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
"stop",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
elif custom_llm_provider == "cohere":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"extra_headers",
|
||||
]
|
||||
elif custom_llm_provider == "cohere_chat":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
elif custom_llm_provider == "maritalk":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "openai":
|
||||
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "azure":
|
||||
if litellm.AzureOpenAIO1Config().is_o1_model(model=model):
|
||||
return litellm.AzureOpenAIO1Config().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
else:
|
||||
return litellm.AzureOpenAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "openrouter":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"repetition_penalty",
|
||||
"seed",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"stop",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
|
||||
# mistal and codestral api have the exact same params
|
||||
if request_type == "chat_completion":
|
||||
return litellm.MistralConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.MistralEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "text-completion-codestral":
|
||||
return litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "replicate":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"functions",
|
||||
"function_call",
|
||||
]
|
||||
elif custom_llm_provider == "huggingface":
|
||||
return litellm.HuggingfaceConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "together_ai":
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
]
|
||||
elif custom_llm_provider == "ai21":
|
||||
return [
|
||||
"stream",
|
||||
"n",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "databricks":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.DatabricksConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
|
||||
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
if request_type == "chat_completion":
|
||||
if model.startswith("meta/"):
|
||||
return litellm.VertexAILlama3Config().get_supported_openai_params()
|
||||
if model.startswith("mistral"):
|
||||
return litellm.MistralConfig().get_supported_openai_params()
|
||||
if model.startswith("codestral"):
|
||||
return (
|
||||
litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
)
|
||||
if model.startswith("claude"):
|
||||
return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
|
||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai_beta":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.VertexGeminiConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||
elif custom_llm_provider == "aleph_alpha":
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"top_p",
|
||||
"temperature",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "cloudflare":
|
||||
return ["max_tokens", "stream"]
|
||||
elif custom_llm_provider == "nlp_cloud":
|
||||
return [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
elif custom_llm_provider == "petals":
|
||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
||||
elif custom_llm_provider == "deepinfra":
|
||||
return litellm.DeepInfraConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "perplexity":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "anyscale":
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "watsonx":
|
||||
return litellm.IBMWatsonXAIConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "custom_openai" or "text-completion-openai":
|
||||
return [
|
||||
"functions",
|
||||
"function_call",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"max_retries",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"extra_headers",
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
def _count_characters(text: str) -> int:
|
||||
# Remove white spaces and count characters
|
||||
filtered_text = "".join(char for char in text if not char.isspace())
|
||||
|
@ -8640,3 +8347,47 @@ def add_dummy_tool(custom_llm_provider: str) -> List[ChatCompletionToolParam]:
|
|||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionAudioObject,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionUserMessage,
|
||||
OpenAIMessageContent,
|
||||
ValidUserMessageContentTypes,
|
||||
)
|
||||
|
||||
|
||||
def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
|
||||
"""
|
||||
Ensures all user messages are valid OpenAI chat completion messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
message_content_type: Type to validate content against
|
||||
|
||||
Returns:
|
||||
List[dict]: The validated messages
|
||||
|
||||
Raises:
|
||||
ValueError: If any message is invalid
|
||||
"""
|
||||
for idx, m in enumerate(messages):
|
||||
try:
|
||||
if m["role"] == "user":
|
||||
user_content = m.get("content")
|
||||
if user_content is not None:
|
||||
if isinstance(user_content, str):
|
||||
continue
|
||||
elif isinstance(user_content, list):
|
||||
for item in user_content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") not in ValidUserMessageContentTypes:
|
||||
raise Exception("invalid content type")
|
||||
except Exception:
|
||||
raise Exception(
|
||||
f"Invalid user message={m} at index {idx}. Please ensure all user messages are valid OpenAI chat completion messages."
|
||||
)
|
||||
|
||||
return messages
|
||||
|
|
|
@ -233,7 +233,7 @@ def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
|
|||
with pytest.raises(BadRequestError) as exc:
|
||||
litellm.completion(
|
||||
model="databricks/dbrx-instruct-071224",
|
||||
messages={"role": "user", "content": "How are you?"},
|
||||
messages=[{"role": "user", "content": "How are you?"}],
|
||||
)
|
||||
assert err_msg in str(exc)
|
||||
|
||||
|
|
|
@ -905,3 +905,19 @@ def test_vertex_schema_field():
|
|||
"$schema"
|
||||
not in optional_params["tools"][0]["function_declarations"][0]["parameters"]
|
||||
)
|
||||
|
||||
|
||||
def test_watsonx_tool_choice():
|
||||
optional_params = get_optional_params(
|
||||
model="gemini-1.5-pro", custom_llm_provider="watsonx", tool_choice="auto"
|
||||
)
|
||||
print(optional_params)
|
||||
assert optional_params["tool_choice_options"] == "auto"
|
||||
|
||||
|
||||
def test_watsonx_text_top_k():
|
||||
optional_params = get_optional_params(
|
||||
model="gemini-1.5-pro", custom_llm_provider="watsonx_text", top_k=10
|
||||
)
|
||||
print(optional_params)
|
||||
assert optional_params["top_k"] == 10
|
||||
|
|
|
@ -203,7 +203,7 @@ def create_async_task(**completion_kwargs):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("stream", [False, True])
|
||||
@pytest.mark.flaky(retries=6, delay=1)
|
||||
@pytest.mark.flaky(retries=12, delay=2)
|
||||
async def test_langfuse_logging_without_request_response(stream, langfuse_client):
|
||||
try:
|
||||
import uuid
|
||||
|
@ -232,6 +232,12 @@ async def test_langfuse_logging_without_request_response(stream, langfuse_client
|
|||
|
||||
_trace_data = trace.data
|
||||
|
||||
if (
|
||||
len(_trace_data) == 0
|
||||
): # prevent infrequent list index out of range error from langfuse api
|
||||
return
|
||||
|
||||
print(f"_trace_data: {_trace_data}")
|
||||
assert _trace_data[0].input == {
|
||||
"messages": [{"content": "redacted-by-litellm", "role": "user"}]
|
||||
}
|
||||
|
@ -256,7 +262,7 @@ audio_file = open(file_path, "rb")
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.flaky(retries=12, delay=2)
|
||||
async def test_langfuse_logging_audio_transcriptions(langfuse_client):
|
||||
"""
|
||||
Test that creates a trace with masked input and output
|
||||
|
@ -291,7 +297,7 @@ async def test_langfuse_logging_audio_transcriptions(langfuse_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=5, delay=1)
|
||||
@pytest.mark.flaky(retries=12, delay=2)
|
||||
async def test_langfuse_masked_input_output(langfuse_client):
|
||||
"""
|
||||
Test that creates a trace with masked input and output
|
||||
|
@ -344,7 +350,7 @@ async def test_langfuse_masked_input_output(langfuse_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@pytest.mark.flaky(retries=12, delay=2)
|
||||
async def test_aaalangfuse_logging_metadata(langfuse_client):
|
||||
"""
|
||||
Test that creates multiple traces, with a varying number of generations and sets various metadata fields
|
||||
|
|
|
@ -775,7 +775,7 @@ def test_litellm_predibase_exception():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["predibase", "vertex_ai_beta", "anthropic", "databricks"]
|
||||
"provider", ["predibase", "vertex_ai_beta", "anthropic", "databricks", "watsonx"]
|
||||
)
|
||||
def test_exception_mapping(provider):
|
||||
"""
|
||||
|
|
|
@ -12,7 +12,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
import litellm
|
||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||
|
||||
|
@ -619,3 +619,62 @@ def test_passing_tool_result_as_list(model):
|
|||
|
||||
if model == "claude-3-5-sonnet-20241022":
|
||||
assert resp.usage.prompt_tokens_details.cached_tokens > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_watsonx_tool_choice(sync_mode):
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||
import json
|
||||
from litellm import acompletion, completion
|
||||
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
messages = [{"role": "user", "content": "What is the weather in San Francisco?"}]
|
||||
|
||||
client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
|
||||
with patch.object(client, "post", return_value=MagicMock()) as mock_completion:
|
||||
|
||||
if sync_mode:
|
||||
resp = completion(
|
||||
model="watsonx/meta-llama/llama-3-1-8b-instruct",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
resp = await acompletion(
|
||||
model="watsonx/meta-llama/llama-3-1-8b-instruct",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
client=client,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
print(mock_completion.call_args.kwargs)
|
||||
json_data = json.loads(mock_completion.call_args.kwargs["data"])
|
||||
json_data["tool_choice_options"] == "auto"
|
||||
|
|
|
@ -1917,25 +1917,31 @@ def test_completion_sagemaker_stream():
|
|||
|
||||
|
||||
@pytest.mark.skip(reason="Account deleted by IBM.")
|
||||
def test_completion_watsonx_stream():
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_watsonx_stream():
|
||||
litellm.set_verbose = True
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
try:
|
||||
response = completion(
|
||||
model="watsonx/ibm/granite-13b-chat-v2",
|
||||
response = await acompletion(
|
||||
model="watsonx/meta-llama/llama-3-1-8b-instruct",
|
||||
messages=messages,
|
||||
temperature=0.5,
|
||||
max_tokens=20,
|
||||
stream=True,
|
||||
# client=client
|
||||
)
|
||||
complete_response = ""
|
||||
has_finish_reason = False
|
||||
# Add any assertions here to check the response
|
||||
for idx, chunk in enumerate(response):
|
||||
idx = 0
|
||||
async for chunk in response:
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
has_finish_reason = finished
|
||||
if finished:
|
||||
break
|
||||
complete_response += chunk
|
||||
idx += 1
|
||||
if has_finish_reason is False:
|
||||
raise Exception("finish reason not set for last chunk")
|
||||
if complete_response.strip() == "":
|
||||
|
|
|
@ -891,3 +891,55 @@ def test_is_base64_encoded_2():
|
|||
)
|
||||
|
||||
assert is_base64_encoded(s="Dog") is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"messages, expected_bool",
|
||||
[
|
||||
([{"role": "user", "content": "hi"}], True),
|
||||
([{"role": "user", "content": [{"type": "text", "text": "hi"}]}], True),
|
||||
(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "url": "https://example.com/image.png"}
|
||||
],
|
||||
}
|
||||
],
|
||||
True,
|
||||
),
|
||||
(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "hi"},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "1234",
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_chat_completion_user_messages(messages, expected_bool):
|
||||
from litellm.utils import validate_chat_completion_user_messages
|
||||
|
||||
if expected_bool:
|
||||
## Valid message
|
||||
validate_chat_completion_user_messages(messages=messages)
|
||||
else:
|
||||
## Invalid message
|
||||
with pytest.raises(Exception):
|
||||
validate_chat_completion_user_messages(messages=messages)
|
||||
|
|
|
@ -93,7 +93,9 @@ async def test_datadog_llm_obs_logging():
|
|||
|
||||
for _ in range(2):
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-4o", messages=["Hello testing dd llm obs!"], mock_response="hi"
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "Hello testing dd llm obs!"}],
|
||||
mock_response="hi",
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue