Litellm dev 01 06 2025 p2 (#7597)

* test(test_amazing_vertex_completion.py): fix test

* test: initial working code gecko test

* fix(vertex_ai_non_gemini.py): support vertex ai code gecko fake streaming

Fixes https://github.com/BerriAI/litellm/issues/7360

* test(test_get_model_info.py): add test for getting custom provider model info

Covers https://github.com/BerriAI/litellm/issues/7575

* fix(utils.py): fix get_provider_model_info check

Handle custom llm provider scenario

Fixes https://github.com/BerriAI/litellm/issues/7575
Krish Dholakia 2025-01-06 21:04:49 -08:00 committed by GitHub
parent b397dc1497
commit 0c3fef24cd
4 changed files with 76 additions and 12 deletions
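
The headline fix is fake streaming for Vertex AI code gecko models: with this change, `stream=True` no longer errors out; litellm fetches the complete response and replays it as stream chunks. A usage sketch mirroring the new test below (assumes Vertex AI credentials and project are already configured; the prompt is illustrative):

import litellm

# code-gecko cannot stream natively on Vertex AI; after this fix litellm
# runs the normal request and fakes the stream client-side.
response = litellm.completion(
    model="vertex_ai/code-gecko@002",
    messages=[{"role": "user", "content": "Write a function that reverses a string."}],
    stream=True,
)
for chunk in response:
    print(chunk)  # OpenAI-style streaming chunks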

vertex_ai_non_gemini.py

@@ -7,6 +7,7 @@ import httpx
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.bedrock.common_utils import ModelResponseIterator
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -197,6 +198,7 @@ def completion(  # noqa: PLR0915
         client_options = {
             "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
         }
+        fake_stream = False
         if (
             model in litellm.vertex_language_models
             or model in litellm.vertex_vision_models
@@ -220,6 +222,7 @@ def completion(  # noqa: PLR0915
             )
             mode = "text"
             request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
+            fake_stream = True
         elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
             llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
             mode = "chat"
@@ -275,17 +278,22 @@ def completion(  # noqa: PLR0915
             return async_completion(**data)
         completion_response = None
+        stream = optional_params.pop(
+            "stream", None
+        )  # See note above on handling streaming for vertex ai
         if mode == "chat":
             chat = llm_model.start_chat()
             request_str += "chat = llm_model.start_chat()\n"
-            if "stream" in optional_params and optional_params["stream"] is True:
+            if fake_stream is not True and stream is True:
                 # NOTE: VertexAI does not accept stream=True as a param and raises an error,
                 # we handle this by removing 'stream' from optional params and sending the request
                 # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                 optional_params.pop(
                     "stream", None
                 )  # vertex ai raises an error when passing stream in optional params
                 request_str += (
                     f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                 )
@@ -298,6 +306,7 @@ def completion(  # noqa: PLR0915
                         "request_str": request_str,
                     },
                 )
                 model_response = chat.send_message_streaming(prompt, **optional_params)
                 return model_response
@@ -314,10 +323,8 @@ def completion(  # noqa: PLR0915
             )
             completion_response = chat.send_message(prompt, **optional_params).text
         elif mode == "text":
-            if "stream" in optional_params and optional_params["stream"] is True:
-                optional_params.pop(
-                    "stream", None
-                )  # See note above on handling streaming for vertex ai
+            if fake_stream is not True and stream is True:
                 request_str += (
                     f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                 )
@@ -384,7 +391,7 @@ def completion(  # noqa: PLR0915
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
-            if "stream" in optional_params and optional_params["stream"] is True:
+            if stream is True:
                 response = TextStreamer(completion_response)
                 return response
         elif mode == "private":
@@ -413,7 +420,7 @@ def completion(  # noqa: PLR0915
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
-            if "stream" in optional_params and optional_params["stream"] is True:
+            if stream is True:
                 response = TextStreamer(completion_response)
                 return response
@@ -465,6 +472,9 @@ def completion(  # noqa: PLR0915
             total_tokens=prompt_tokens + completion_tokens,
         )
         setattr(model_response, "usage", usage)
+        if fake_stream is True and stream is True:
+            return ModelResponseIterator(model_response)
         return model_response
     except Exception as e:
         if isinstance(e, VertexAIError):
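
For context, the fake-streaming pattern added above is simple: run the non-streaming request, then return an object that yields the finished ModelResponse as if it were a stream (that is what the imported ModelResponseIterator is used for here). The snippet below is a minimal illustration of that idea, not litellm's actual implementation:

from typing import Any, Iterator


class FakeStream:
    """Wrap an already-complete response so callers can iterate it like a stream."""

    def __init__(self, model_response: Any):
        self.model_response = model_response
        self.is_done = False

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response  # a single "chunk" carrying the whole response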

litellm/utils.py

@@ -4224,6 +4224,7 @@ def _get_model_info_helper(  # noqa: PLR0915
     _model_info: Optional[Dict[str, Any]] = None
     key: Optional[str] = None
     provider_config: Optional[BaseLLMModelInfo] = None
     if combined_model_name in litellm.model_cost:
         key = combined_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
@@ -4263,7 +4264,10 @@ def _get_model_info_helper(  # noqa: PLR0915
         ):
             _model_info = None
-        if custom_llm_provider:
+        if custom_llm_provider and custom_llm_provider in [
+            provider.value for provider in LlmProviders
+        ]:
+            # Check if the provider string exists in LlmProviders enum
             provider_config = ProviderConfigManager.get_provider_model_info(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
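
The utils.py change tightens the provider check: previously any non-empty custom_llm_provider string was passed straight to LlmProviders(...), which fails for providers registered only via litellm.custom_provider_map (the scenario in issue 7575, reproduced by the new test below as "Exception: This model isn't mapped yet."). A stand-alone illustration of the failure mode and the guard, using a stand-in enum rather than the real LlmProviders:

from enum import Enum


# Stand-in for litellm's LlmProviders enum, only to show the failure mode.
class Providers(str, Enum):
    OPENAI = "openai"
    VERTEX_AI = "vertex_ai"


custom_llm_provider = "my-custom-llm"  # registered only via custom_provider_map

# Old check: constructing the enum from an unknown value raises ValueError.
try:
    Providers(custom_llm_provider)
except ValueError as e:
    print(e)  # 'my-custom-llm' is not a valid Providers

# New check: construct the enum only when the string is a known member value;
# otherwise skip the provider-config lookup and fall through to other sources.
if custom_llm_provider in [p.value for p in Providers]:
    provider = Providers(custom_llm_provider)
else:
    provider = None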

test_amazing_vertex_completion.py

@@ -930,7 +930,7 @@ from test_completion import response_format_tests
         "vertex_ai/mistral-large@2407",
         "vertex_ai/mistral-nemo@2407",
         "vertex_ai/codestral@2405",
-        "vertex_ai/meta/llama3-405b-instruct-maas",
+        # "vertex_ai/meta/llama3-405b-instruct-maas",
     ],  #
 )  # "vertex_ai",
 @pytest.mark.parametrize(
@@ -960,7 +960,6 @@ async def test_partner_models_httpx(model, sync_mode):
             "model": model,
             "messages": messages,
             "timeout": 10,
-            "mock_response": "Hello, how are you?",
         }
         if sync_mode:
             response = litellm.completion(**data)
@@ -993,7 +992,8 @@ async def test_partner_models_httpx(model, sync_mode):
     "model",
     [
         "vertex_ai/mistral-large@2407",
-        "vertex_ai/meta/llama3-405b-instruct-maas",
+        # "vertex_ai/meta/llama3-405b-instruct-maas",
+        "vertex_ai/codestral@2405",
     ],  #
 )  # "vertex_ai",
 @pytest.mark.parametrize(
@@ -1023,7 +1023,6 @@ async def test_partner_models_httpx_streaming(model, sync_mode):
             "model": model,
             "messages": messages,
             "stream": True,
-            "mock_response": "Hello, how are you?",
         }
         if sync_mode:
             response = litellm.completion(**data)
@@ -3193,3 +3192,16 @@ async def test_vertexai_model_garden_model_completion(
     assert response.usage.completion_tokens == 109
     assert response.usage.prompt_tokens == 63
     assert response.usage.total_tokens == 172
+
+
+def test_vertexai_code_gecko():
+    litellm.set_verbose = True
+    load_vertex_ai_credentials()
+    response = completion(
+        model="vertex_ai/code-gecko@002",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+    for chunk in response:
+        print(chunk)

test_get_model_info.py

@@ -247,3 +247,41 @@ def test_model_info_bedrock_converse_enforcement(monkeypatch):
         )
     except FileNotFoundError as e:
         pytest.skip("whitelisted_bedrock_models.txt not found")
+
+
+def test_get_model_info_custom_provider():
+    # Custom provider example copied from https://docs.litellm.ai/docs/providers/custom_llm_server
+    import litellm
+    from litellm import CustomLLM, completion, get_llm_provider
+
+    class MyCustomLLM(CustomLLM):
+        def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+            return litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello world"}],
+                mock_response="Hi!",
+            )  # type: ignore
+
+    my_custom_llm = MyCustomLLM()
+
+    litellm.custom_provider_map = [  # 👈 KEY STEP - REGISTER HANDLER
+        {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
+    ]
+
+    resp = completion(
+        model="my-custom-llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+    )
+
+    assert resp.choices[0].message.content == "Hi!"
+
+    # Register model info
+    model_info = {"my-custom-llm/my-fake-model": {"max_tokens": 2048}}
+    litellm.register_model(model_info)
+
+    # Get registered model info
+    from litellm import get_model_info
+
+    get_model_info(
+        model="my-custom-llm/my-fake-model"
+    )  # 💥 "Exception: This model isn't mapped yet." in v1.56.10
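
A hedged follow-up, not part of the commit: with the guard in utils.py in place, this get_model_info call should now return the metadata registered via register_model instead of raising. Assuming the registered max_tokens value is passed through unchanged:

info = get_model_info(model="my-custom-llm/my-fake-model")
print(info.get("max_tokens"))  # expected: 2048, per the register_model call above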