Litellm dev 12 13 2024 p1 (#7219)

* fix(litellm_logging.py): pass user metadata to langsmith on sdk calls

* fix(litellm_logging.py): pass nested user metadata to logging integration - e.g. langsmith (see the SDK sketch after this list)

* fix(exception_mapping_utils.py): catch and clarify the watsonx error raised when a model does not support the `/text/chat` endpoint

Closes https://github.com/BerriAI/litellm/issues/7213

* fix(watsonx/common_utils.py): accept new 'WATSONX_IAM_URL' env var

allows users to point IAM token generation at a local watsonx deployment

Fixes https://github.com/BerriAI/litellm/issues/4991

* fix(litellm_logging.py): cleanup unused function

* test: skip bad ibm test
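
A minimal SDK-side sketch of the Langsmith metadata flow fixed above; the model, API keys, and the `run_name`/`project_name` fields are illustrative and follow the Langsmith fields documented at https://docs.litellm.ai/docs/observability/langsmith_integration#set-langsmith-fields:

import os
import litellm

os.environ["LANGSMITH_API_KEY"] = "<langsmith-api-key>"
litellm.success_callback = ["langsmith"]

# metadata passed on the SDK call is what now reaches the logging integration
# (it is surfaced as `requester_metadata` in the standard logging payload -
# see the litellm_logging.py hunk and the new tests below)
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "what is your favorite colour?"}],
    metadata={
        "run_name": "my-langsmith-run",          # illustrative Langsmith field
        "project_name": "my-langsmith-project",  # illustrative Langsmith field
    },
)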
Krish Dholakia 2024-12-13 19:01:28 -08:00 committed by GitHub
commit b150faff90 (parent 30e147a315)
7 changed files with 63 additions and 16 deletions


@@ -31,14 +31,16 @@ from litellm import completion
 os.environ["WATSONX_URL"] = ""
 os.environ["WATSONX_APIKEY"] = ""

+## Call WATSONX `/text/chat` endpoint - supports function calling
 response = completion(
-  model="watsonx/ibm/granite-13b-chat-v2",
+  model="watsonx/meta-llama/llama-3-1-8b-instruct",
   messages=[{ "content": "what is your favorite colour?","role": "user"}],
   project_id="<my-project-id>" # or pass with os.environ["WATSONX_PROJECT_ID"]
 )

+## Call WATSONX `/text/generation` endpoint - not all models support /chat route.
 response = completion(
-  model="watsonx/meta-llama/llama-3-8b-instruct",
+  model="watsonx/ibm/granite-13b-chat-v2",
   messages=[{ "content": "what is your favorite colour?","role": "user"}],
   project_id="<my-project-id>"
 )
@@ -54,7 +56,7 @@ os.environ["WATSONX_APIKEY"] = ""
 os.environ["WATSONX_PROJECT_ID"] = ""

 response = completion(
-  model="watsonx/ibm/granite-13b-chat-v2",
+  model="watsonx/meta-llama/llama-3-1-8b-instruct",
   messages=[{ "content": "what is your favorite colour?","role": "user"}],
   stream=True
 )


@@ -656,6 +656,13 @@ def exception_type(  # type: ignore  # noqa: PLR0915
                     llm_provider=custom_llm_provider,
                     model=model,
                 )
+            elif "model_no_support_for_function" in error_str:
+                exception_mapping_worked = True
+                raise BadRequestError(
+                    message=f"{custom_llm_provider}Exception - Use 'watsonx_text' route instead. IBM WatsonX does not support `/text/chat` endpoint. - {error_str}",
+                    llm_provider=custom_llm_provider,
+                    model=model,
+                )
             elif hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 500:
                     exception_mapping_worked = True
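
When a watsonx model rejects the chat route, the new branch surfaces a BadRequestError that points callers at the text-generation route. A minimal caller-side sketch, assuming `granite-13b-chat-v2` is one of the affected models and that the `watsonx_text/` prefix selects the `/text/generation` route, as the error message suggests:

import litellm
from litellm import completion

messages = [{"role": "user", "content": "what is your favorite colour?"}]

try:
    # /text/chat route - raises BadRequestError if the model has no chat support
    response = completion(model="watsonx/ibm/granite-13b-chat-v2", messages=messages)
except litellm.BadRequestError:
    # retry on the text-generation route, per the clarified error message
    response = completion(model="watsonx_text/ibm/granite-13b-chat-v2", messages=messages)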


@@ -2584,6 +2584,15 @@ class StandardLoggingPayloadSetup:
                 clean_metadata["user_api_key_hash"] = metadata.get(
                     "user_api_key"
                 )  # this is the hash
+            _potential_requester_metadata = metadata.get(
+                "metadata", None
+            )  # check if user passed metadata in the sdk request - e.g. metadata for langsmith logging - https://docs.litellm.ai/docs/observability/langsmith_integration#set-langsmith-fields
+            if (
+                clean_metadata["requester_metadata"] is None
+                and _potential_requester_metadata is not None
+                and isinstance(_potential_requester_metadata, dict)
+            ):
+                clean_metadata["requester_metadata"] = _potential_requester_metadata
         return clean_metadata

     @staticmethod
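
A self-contained rerun of the fallback added above, using the same dict shape as the new tests at the bottom of this commit (the field name is illustrative):

# nested SDK metadata is promoted to requester_metadata only when nothing set it already
clean_metadata = {"requester_metadata": None}
metadata = {"metadata": {"run_name": "my-langsmith-run"}}  # as received from the SDK request

_potential_requester_metadata = metadata.get("metadata", None)
if (
    clean_metadata["requester_metadata"] is None
    and _potential_requester_metadata is not None
    and isinstance(_potential_requester_metadata, dict)
):
    clean_metadata["requester_metadata"] = _potential_requester_metadata

print(clean_metadata)  # {'requester_metadata': {'run_name': 'my-langsmith-run'}}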


@@ -23,6 +23,12 @@ class WatsonXAIError(BaseLLMException):
 iam_token_cache = InMemoryCache()


+def get_watsonx_iam_url():
+    return (
+        get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
+    )
+
+
 def generate_iam_token(api_key=None, **params) -> str:
     result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore
@@ -38,15 +44,14 @@ def generate_iam_token(api_key=None, **params) -> str:
         "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
         "apikey": api_key,
     }
+    iam_token_url = get_watsonx_iam_url()
     verbose_logger.debug(
         "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
-        "https://iam.cloud.ibm.com/identity/token",
+        iam_token_url,
         headers,
         data,
     )
-    response = httpx.post(
-        "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
-    )
+    response = httpx.post(iam_token_url, data=data, headers=headers)
     response.raise_for_status()
     json_data = response.json()
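
With the override above, pointing IAM token generation at a local or on-prem watsonx deployment is a single env var. A sketch with placeholder URLs and credentials:

import os
from litellm import completion

# placeholder values for a local/on-prem deployment
os.environ["WATSONX_IAM_URL"] = "https://my-local-watsonx.example.com/identity/token"
os.environ["WATSONX_URL"] = "https://my-local-watsonx.example.com"
os.environ["WATSONX_APIKEY"] = "<api-key>"
os.environ["WATSONX_PROJECT_ID"] = "<project-id>"

# generate_iam_token() now posts to WATSONX_IAM_URL instead of the hard-coded
# https://iam.cloud.ibm.com/identity/token default
response = completion(
    model="watsonx/meta-llama/llama-3-1-8b-instruct",
    messages=[{"role": "user", "content": "what is your favorite colour?"}],
)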


@@ -10,3 +10,6 @@ model_list:
       model: "*"
     model_info:
       access_groups: ["default"]
+
+litellm_settings:
+  success_callback: ["langsmith"]
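
The config above turns on the Langsmith success callback for the proxy used in testing. A hedged sketch of exercising it from a client, assuming the proxy is running locally on port 4000 and that per-request Langsmith fields are passed through the request body's `metadata` (model name, key, and fields are illustrative):

import openai

# assumes a proxy started with this config is serving on localhost:4000
client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "hi"}],
    extra_body={
        "metadata": {
            "run_name": "proxy-langsmith-run",  # illustrative Langsmith field
        }
    },
)
print(response.choices[0].message.content)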


@@ -3977,10 +3977,11 @@ def test_completion_deepseek():
 @pytest.mark.skip(reason="Account deleted by IBM.")
-def test_completion_watsonx():
+def test_completion_watsonx_error():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"
-    try:
+    with pytest.raises(litellm.BadRequestError) as e:
         response = completion(
             model=model_name,
             messages=messages,
@@ -3989,12 +3990,8 @@ def test_completion_watsonx():
         )
         # Add any assertions here to check the response
         print(response)
-    except litellm.APIError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+    assert "use 'watsonx_text' route instead" in str(e).lower()


 @pytest.mark.skip(reason="Skip test. account deleted.")


@@ -448,3 +448,27 @@ def test_get_response_time():
     # For streaming, should return completion_start_time - start_time
     assert response_time == 2.0
+
+
+@pytest.mark.parametrize(
+    "metadata, expected_requester_metadata",
+    [
+        ({"metadata": {"test": "test2"}}, {"test": "test2"}),
+        ({"metadata": {"test": "test2"}, "model_id": "test-model"}, {"test": "test2"}),
+        (
+            {
+                "metadata": {
+                    "test": "test2",
+                },
+                "model_id": "test-model",
+                "requester_metadata": {"test": "test2"},
+            },
+            {"test": "test2"},
+        ),
+    ],
+)
+def test_standard_logging_metadata_requester_metadata(
+    metadata, expected_requester_metadata
+):
+    result = StandardLoggingPayloadSetup.get_standard_logging_metadata(metadata)
+    assert result["requester_metadata"] == expected_requester_metadata