diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 370668afb0..e168c23244 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -77,7 +77,7 @@ def test_completion_claude_stream(): def test_completion_cohere(): try: response = completion( - model="command-nightly", messages=messages, max_tokens=100 + model="command-nightly", messages=messages, max_tokens=100, logit_bias={40: 10} ) # Add any assertions here to check the response print(response) @@ -91,7 +91,6 @@ def test_completion_cohere(): except Exception as e: pytest.fail(f"Error occurred: {e}") - def test_completion_cohere_stream(): try: messages = [ diff --git a/litellm/utils.py b/litellm/utils.py index 290e64ddba..5346ce62a1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -452,6 +452,8 @@ def get_optional_params( optional_params["temperature"] = temperature if max_tokens != float("inf"): optional_params["max_tokens"] = max_tokens + if logit_bias != {}: + optional_params["logit_bias"] = logit_bias return optional_params elif custom_llm_provider == "replicate": # any replicate models