forked from phoenix/litellm-mirror
fix(vertex_httpx.py): support async completion calls
This commit is contained in:
parent
3955b058ed
commit
995631bd39
3 changed files with 142 additions and 22 deletions
|
@ -275,13 +275,89 @@ class VertexLLM(BaseLLM):
|
||||||
|
|
||||||
async def async_streaming(
|
async def async_streaming(
|
||||||
self,
|
self,
|
||||||
):
|
model: str,
|
||||||
pass
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=None,
|
||||||
|
make_call=partial(
|
||||||
|
make_call,
|
||||||
|
client=client,
|
||||||
|
api_base=api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=data,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
),
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider="vertex_ai_beta",
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
async def async_completion(
|
async def async_completion(
|
||||||
self,
|
self,
|
||||||
):
|
model: str,
|
||||||
pass
|
messages: list,
|
||||||
|
api_base: str,
|
||||||
|
model_response: ModelResponse,
|
||||||
|
print_verbose: Callable,
|
||||||
|
data: str,
|
||||||
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
|
encoding,
|
||||||
|
logging_obj,
|
||||||
|
stream,
|
||||||
|
optional_params: dict,
|
||||||
|
litellm_params=None,
|
||||||
|
logger_fn=None,
|
||||||
|
headers={},
|
||||||
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
if client is None:
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post(api_base, headers=headers, json=data) # type: ignore
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as err:
|
||||||
|
error_code = err.response.status_code
|
||||||
|
raise VertexAIError(status_code=error_code, message=err.response.text)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise VertexAIError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return self._process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
optional_params=optional_params,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
|
@ -344,6 +420,27 @@ class VertexLLM(BaseLLM):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
if acompletion:
|
||||||
|
### ASYNC COMPLETION
|
||||||
|
return self.async_completion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
data=data, # type: ignore
|
||||||
|
api_base=url,
|
||||||
|
model_response=model_response,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
stream=stream,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
logger_fn=logger_fn,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
## SYNC STREAMING CALL ##
|
## SYNC STREAMING CALL ##
|
||||||
if stream is not None and stream is True:
|
if stream is not None and stream is True:
|
||||||
streaming_response = CustomStreamWrapper(
|
streaming_response = CustomStreamWrapper(
|
||||||
|
@ -359,7 +456,7 @@ class VertexLLM(BaseLLM):
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
),
|
),
|
||||||
model=model,
|
model=model,
|
||||||
custom_llm_provider="bedrock",
|
custom_llm_provider="vertex_ai_beta",
|
||||||
logging_obj=logging_obj,
|
logging_obj=logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -329,6 +329,7 @@ async def acompletion(
|
||||||
or custom_llm_provider == "ollama_chat"
|
or custom_llm_provider == "ollama_chat"
|
||||||
or custom_llm_provider == "replicate"
|
or custom_llm_provider == "replicate"
|
||||||
or custom_llm_provider == "vertex_ai"
|
or custom_llm_provider == "vertex_ai"
|
||||||
|
or custom_llm_provider == "vertex_ai_beta"
|
||||||
or custom_llm_provider == "gemini"
|
or custom_llm_provider == "gemini"
|
||||||
or custom_llm_provider == "sagemaker"
|
or custom_llm_provider == "sagemaker"
|
||||||
or custom_llm_provider == "anthropic"
|
or custom_llm_provider == "anthropic"
|
||||||
|
|
|
@ -503,28 +503,50 @@ async def test_async_vertexai_streaming_response():
|
||||||
# asyncio.run(test_async_vertexai_streaming_response())
|
# asyncio.run(test_async_vertexai_streaming_response())
|
||||||
|
|
||||||
|
|
||||||
def test_gemini_pro_vision():
|
@pytest.mark.parametrize("provider", ["vertex_ai", "vertex_ai_beta"])
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gemini_pro_vision(provider, sync_mode):
|
||||||
try:
|
try:
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
litellm.num_retries = 3
|
litellm.num_retries = 3
|
||||||
resp = litellm.completion(
|
if sync_mode:
|
||||||
model="vertex_ai/gemini-1.5-flash-preview-0514",
|
resp = litellm.completion(
|
||||||
messages=[
|
model="{}/gemini-1.5-flash-preview-0514".format(provider),
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": [
|
"role": "user",
|
||||||
{"type": "text", "text": "Whats in this image?"},
|
"content": [
|
||||||
{
|
{"type": "text", "text": "Whats in this image?"},
|
||||||
"type": "image_url",
|
{
|
||||||
"image_url": {
|
"type": "image_url",
|
||||||
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
"image_url": {
|
||||||
|
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
],
|
||||||
],
|
}
|
||||||
}
|
],
|
||||||
],
|
)
|
||||||
)
|
else:
|
||||||
|
resp = await litellm.acompletion(
|
||||||
|
model="{}/gemini-1.5-flash-preview-0514".format(provider),
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Whats in this image?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
print(resp)
|
print(resp)
|
||||||
|
|
||||||
prompt_tokens = resp.usage.prompt_tokens
|
prompt_tokens = resp.usage.prompt_tokens
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue