fix(vertex_httpx.py): support async completion calls

This commit is contained in:
Krrish Dholakia 2024-06-12 20:15:03 -07:00
parent 3955b058ed
commit 995631bd39
3 changed files with 142 additions and 22 deletions

View file

@ -275,13 +275,89 @@ class VertexLLM(BaseLLM):
async def async_streaming(
self,
):
pass
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="vertex_ai_beta",
logging_obj=logging_obj,
)
return streaming_response
async def async_completion(
self,
):
pass
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, json=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")
return self._process_response(
model=model,
response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
@ -344,6 +420,27 @@ class VertexLLM(BaseLLM):
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data, # type: ignore
api_base=url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client, # type: ignore
)
## SYNC STREAMING CALL ##
if stream is not None and stream is True:
streaming_response = CustomStreamWrapper(
@ -359,7 +456,7 @@ class VertexLLM(BaseLLM):
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
custom_llm_provider="vertex_ai_beta",
logging_obj=logging_obj,
)

View file

@ -329,6 +329,7 @@ async def acompletion(
or custom_llm_provider == "ollama_chat"
or custom_llm_provider == "replicate"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "vertex_ai_beta"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"

View file

@ -503,28 +503,50 @@ async def test_async_vertexai_streaming_response():
# asyncio.run(test_async_vertexai_streaming_response())
def test_gemini_pro_vision():
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_gemini_pro_vision(provider, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
litellm.num_retries = 3
resp = litellm.completion(
model="vertex_ai/gemini-1.5-flash-preview-0514",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
if sync_mode:
resp = litellm.completion(
model="{}/gemini-1.5-flash-preview-0514".format(provider),
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
},
},
},
],
}
],
)
],
}
],
)
else:
resp = await litellm.acompletion(
model="{}/gemini-1.5-flash-preview-0514".format(provider),
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
},
},
],
}
],
)
print(resp)
prompt_tokens = resp.usage.prompt_tokens