forked from phoenix/litellm-mirror
fix(vertex_httpx.py): support async completion calls
This commit is contained in:
parent
3955b058ed
commit
995631bd39
3 changed files with 142 additions and 22 deletions
|
@ -275,13 +275,89 @@ class VertexLLM(BaseLLM):
|
|||
|
||||
async def async_streaming(
|
||||
self,
|
||||
):
|
||||
pass
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=None,
|
||||
make_call=partial(
|
||||
make_call,
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
):
|
||||
pass
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, json=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise VertexAIError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self._process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
|
@ -344,6 +420,27 @@ class VertexLLM(BaseLLM):
|
|||
},
|
||||
)
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data, # type: ignore
|
||||
api_base=url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client, # type: ignore
|
||||
)
|
||||
|
||||
## SYNC STREAMING CALL ##
|
||||
if stream is not None and stream is True:
|
||||
streaming_response = CustomStreamWrapper(
|
||||
|
@ -359,7 +456,7 @@ class VertexLLM(BaseLLM):
|
|||
logging_obj=logging_obj,
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
custom_llm_provider="vertex_ai_beta",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
|
|
|
@ -329,6 +329,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "ollama_chat"
|
||||
or custom_llm_provider == "replicate"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "vertex_ai_beta"
|
||||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
|
|
|
@ -503,28 +503,50 @@ async def test_async_vertexai_streaming_response():
|
|||
# asyncio.run(test_async_vertexai_streaming_response())
|
||||
|
||||
|
||||
def test_gemini_pro_vision():
|
||||
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_pro_vision(provider, sync_mode):
|
||||
try:
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
litellm.num_retries = 3
|
||||
resp = litellm.completion(
|
||||
model="vertex_ai/gemini-1.5-flash-preview-0514",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Whats in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||
if sync_mode:
|
||||
resp = litellm.completion(
|
||||
model="{}/gemini-1.5-flash-preview-0514".format(provider),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Whats in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
else:
|
||||
resp = await litellm.acompletion(
|
||||
model="{}/gemini-1.5-flash-preview-0514".format(provider),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Whats in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
print(resp)
|
||||
|
||||
prompt_tokens = resp.usage.prompt_tokens
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue