diff --git a/docs/my-website/docs/providers/watsonx.md b/docs/my-website/docs/providers/watsonx.md
index 7a42a54edd..8665611fa7 100644
--- a/docs/my-website/docs/providers/watsonx.md
+++ b/docs/my-website/docs/providers/watsonx.md
@@ -31,14 +31,16 @@ from litellm import completion
 os.environ["WATSONX_URL"] = ""
 os.environ["WATSONX_APIKEY"] = ""
 
+## Call WATSONX `/text/chat` endpoint - supports function calling
 response = completion(
-  model="watsonx/ibm/granite-13b-chat-v2",
+  model="watsonx/meta-llama/llama-3-1-8b-instruct",
   messages=[{ "content": "what is your favorite colour?","role": "user"}],
   project_id="" # or pass with os.environ["WATSONX_PROJECT_ID"]
 )
 
+## Call WATSONX `/text/generation` endpoint - not all models support the /chat route
 response = completion(
-  model="watsonx/meta-llama/llama-3-8b-instruct",
+  model="watsonx/ibm/granite-13b-chat-v2",
   messages=[{ "content": "what is your favorite colour?","role": "user"}],
   project_id=""
 )
@@ -54,7 +56,7 @@ os.environ["WATSONX_APIKEY"] = ""
 os.environ["WATSONX_PROJECT_ID"] = ""
 
 response = completion(
-  model="watsonx/ibm/granite-13b-chat-v2",
+  model="watsonx/meta-llama/llama-3-1-8b-instruct",
   messages=[{ "content": "what is your favorite colour?","role": "user"}],
   stream=True
 )
diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py
index c4228f0527..0ca2f4e262 100644
--- a/litellm/litellm_core_utils/exception_mapping_utils.py
+++ b/litellm/litellm_core_utils/exception_mapping_utils.py
@@ -656,6 +656,13 @@ def exception_type( # type: ignore # noqa: PLR0915
                     llm_provider=custom_llm_provider,
                     model=model,
                 )
+            elif "model_no_support_for_function" in error_str:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"{custom_llm_provider}Exception - Use 'watsonx_text' route instead. This model does not support the `/text/chat` endpoint. - {error_str}",
+                    llm_provider=custom_llm_provider,
+                    model=model,
+                )
             elif hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
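Reviewer note on the new exception mapping: a minimal sketch of how a caller might react to it, trying the `/text/chat` route first and falling back to `/text/generation` when the new `BadRequestError` fires. The helper name and model IDs are illustrative, not part of this PR.

```python
import litellm
from litellm import completion

def watsonx_completion_with_fallback(model_id: str, messages: list):
    """Try the watsonx chat route, fall back to the text-generation route."""
    try:
        # `watsonx/` routes to /text/chat (supports function calling).
        return completion(model=f"watsonx/{model_id}", messages=messages)
    except litellm.BadRequestError as e:
        # The mapping added above raises BadRequestError with a hint to use
        # the `watsonx_text` route for models without /text/chat support.
        if "watsonx_text" in str(e).lower():
            return completion(model=f"watsonx_text/{model_id}", messages=messages)
        raise

# e.g. watsonx_completion_with_fallback(
#     "ibm/granite-13b-chat-v2",
#     [{"role": "user", "content": "what is your favorite colour?"}],
# )
```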
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index e6bdbf7451..d534d4da3b 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -2584,6 +2584,15 @@ class StandardLoggingPayloadSetup:
             clean_metadata["user_api_key_hash"] = metadata.get(
                 "user_api_key"
             )  # this is the hash
+            _potential_requester_metadata = metadata.get(
+                "metadata", None
+            )  # check if user passed metadata in the sdk request - e.g. metadata for langsmith logging - https://docs.litellm.ai/docs/observability/langsmith_integration#set-langsmith-fields
+            if (
+                clean_metadata["requester_metadata"] is None
+                and _potential_requester_metadata is not None
+                and isinstance(_potential_requester_metadata, dict)
+            ):
+                clean_metadata["requester_metadata"] = _potential_requester_metadata
         return clean_metadata
 
     @staticmethod
diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py
index e8ddc5f328..b270f2d82b 100644
--- a/litellm/llms/watsonx/common_utils.py
+++ b/litellm/llms/watsonx/common_utils.py
@@ -23,6 +23,12 @@ class WatsonXAIError(BaseLLMException):
 iam_token_cache = InMemoryCache()
 
 
+def get_watsonx_iam_url():
+    return (
+        get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
+    )
+
+
 def generate_iam_token(api_key=None, **params) -> str:
     result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore
 
@@ -38,15 +44,14 @@ def generate_iam_token(api_key=None, **params) -> str:
             "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
             "apikey": api_key,
         }
+        iam_token_url = get_watsonx_iam_url()
         verbose_logger.debug(
             "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
-            "https://iam.cloud.ibm.com/identity/token",
+            iam_token_url,
             headers,
             data,
         )
-        response = httpx.post(
-            "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
-        )
+        response = httpx.post(iam_token_url, data=data, headers=headers)
         response.raise_for_status()
 
         json_data = response.json()
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index a66057ae30..a5fbf8c6a0 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -9,4 +9,7 @@ model_list:
     litellm_params:
       model: "*"
     model_info:
-      access_groups: ["default"]
\ No newline at end of file
+      access_groups: ["default"]
+
+litellm_settings:
+  success_callback: ["langsmith"]
\ No newline at end of file
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index a73a227db6..d7eb306bfb 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -3977,10 +3977,11 @@ def test_completion_deepseek():
 
 
 @pytest.mark.skip(reason="Account deleted by IBM.")
-def test_completion_watsonx():
+def test_completion_watsonx_error():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
-    try:
+
+    with pytest.raises(litellm.BadRequestError) as e:
         response = completion(
             model=model_name,
             messages=messages,
@@ -3989,12 +3990,8 @@
         )
         # Add any assertions here to check the response
         print(response)
-    except litellm.APIError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+
+    assert "use 'watsonx_text' route instead" in str(e).lower()
 
 
 @pytest.mark.skip(reason="Skip test. account deleted.")
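Reviewer note on the `get_watsonx_iam_url()` helper added in `litellm/llms/watsonx/common_utils.py` above: the IAM token endpoint is now overridable, which matters for non-default IBM Cloud regions or on-prem deployments. A minimal usage sketch, assuming env-var configuration; the override URL is illustrative.

```python
import os
from litellm import completion

# Optional override picked up by get_watsonx_iam_url(); defaults to
# https://iam.cloud.ibm.com/identity/token when unset.
os.environ["WATSONX_IAM_URL"] = "https://iam.example.internal/identity/token"

os.environ["WATSONX_URL"] = ""        # your watsonx.ai endpoint
os.environ["WATSONX_APIKEY"] = ""     # API key exchanged for an IAM token
os.environ["WATSONX_PROJECT_ID"] = ""

response = completion(
    model="watsonx/meta-llama/llama-3-1-8b-instruct",
    messages=[{"role": "user", "content": "what is your favorite colour?"}],
)
```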
diff --git a/tests/logging_callback_tests/test_standard_logging_payload.py b/tests/logging_callback_tests/test_standard_logging_payload.py
index 29dd1454bd..5bda02c562 100644
--- a/tests/logging_callback_tests/test_standard_logging_payload.py
+++ b/tests/logging_callback_tests/test_standard_logging_payload.py
@@ -448,3 +448,27 @@ def test_get_response_time():
 
     # For streaming, should return completion_start_time - start_time
     assert response_time == 2.0
+
+
+@pytest.mark.parametrize(
+    "metadata, expected_requester_metadata",
+    [
+        ({"metadata": {"test": "test2"}}, {"test": "test2"}),
+        ({"metadata": {"test": "test2"}, "model_id": "test-model"}, {"test": "test2"}),
+        (
+            {
+                "metadata": {
+                    "test": "test2",
+                },
+                "model_id": "test-model",
+                "requester_metadata": {"test": "test2"},
+            },
+            {"test": "test2"},
+        ),
+    ],
+)
+def test_standard_logging_metadata_requester_metadata(
+    metadata, expected_requester_metadata
+):
+    result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
+    assert result["requester_metadata"] == expected_requester_metadata
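Reviewer note on the `requester_metadata` fallback these parametrized cases cover: metadata passed via the SDK `metadata` kwarg is promoted to `requester_metadata` in the standard logging payload when the caller did not set one explicitly, so callbacks like Langsmith receive it. A minimal end-to-end sketch; it uses litellm's `mock_response` so no provider API key is needed, and the model name is illustrative.

```python
import litellm

litellm.success_callback = ["langsmith"]  # mirrors the config change above

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    metadata={"test": "test2"},  # surfaces as requester_metadata in the payload
    mock_response="ok",          # no network call; exercises logging only
)
```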