feat(bedrock_httpx.py): working bedrock command-r sync+async streaming

Krrish Dholakia 2024-05-11 19:39:51 -07:00
parent 49ab1a1d3f
commit 64650c0279
6 changed files with 342 additions and 51 deletions

@@ -257,7 +257,7 @@ async def acompletion(
         - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    custom_llm_provider = None
+    custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
     completion_kwargs = {
         "model": model,
@@ -289,9 +289,10 @@
         "model_list": model_list,
         "acompletion": True, # assuming this is a required parameter
     }
-    _, custom_llm_provider, _, _ = get_llm_provider(
-        model=model, api_base=completion_kwargs.get("base_url", None)
-    )
+    if custom_llm_provider is None:
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=completion_kwargs.get("base_url", None)
+        )
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
@@ -300,9 +301,6 @@
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
-        _, custom_llm_provider, _, _ = get_llm_provider(
-            model=model, api_base=kwargs.get("api_base", None)
-        )
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
@@ -324,6 +322,7 @@
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
+            or (custom_llm_provider == "bedrock" and "cohere" in model)
             or custom_llm_provider in litellm.openai_compatible_providers
         ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
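
With `(custom_llm_provider == "bedrock" and "cohere" in model)` added to this branch, bedrock command-r/cohere calls are treated like other async-capable providers: the handler may hand back an awaitable that `acompletion` then awaits, rather than a finished sync response. A hedged sketch of the async streaming usage the commit title describes (the exact bedrock model id is an assumption, not from this diff):

```python
import asyncio
import litellm

async def stream_command_r():
    # stream=True on the async bedrock-cohere path yields an async
    # generator of OpenAI-style chunks.
    response = await litellm.acompletion(
        model="bedrock/cohere.command-r-v1:0",  # assumed model id
        messages=[{"role": "user", "content": "Write a haiku."}],
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(stream_command_r())
```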
@@ -1937,6 +1936,7 @@ def completion(
                 logging_obj=logging,
                 extra_headers=extra_headers,
                 timeout=timeout,
+                acompletion=acompletion,
             )
         else:
             response = bedrock.completion(
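
Forwarding the `acompletion` flag lets the bedrock chat handler choose between sync and async behavior itself, which is how the commit covers both halves of "sync+async streaming". The sync half would then be exercised like this (again, the model id is an assumption):

```python
import litellm

# Sync streaming: completion() receives acompletion=False, and the
# bedrock command-r path returns an iterable stream of chunks.
response = litellm.completion(
    model="bedrock/cohere.command-r-v1:0",  # assumed model id
    messages=[{"role": "user", "content": "One line on streaming."}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```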
@@ -1954,26 +1954,26 @@
                 timeout=timeout,
             )
-        if (
-            "stream" in optional_params
-            and optional_params["stream"] == True
-            and not isinstance(response, CustomStreamWrapper)
-        ):
-            # don't try to access stream object,
-            if "ai21" in model:
-                response = CustomStreamWrapper(
-                    response,
-                    model,
-                    custom_llm_provider="bedrock",
-                    logging_obj=logging,
-                )
-            else:
-                response = CustomStreamWrapper(
-                    iter(response),
-                    model,
-                    custom_llm_provider="bedrock",
-                    logging_obj=logging,
-                )
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and not isinstance(response, CustomStreamWrapper)
+            ):
+                # don't try to access stream object,
+                if "ai21" in model:
+                    response = CustomStreamWrapper(
+                        response,
+                        model,
+                        custom_llm_provider="bedrock",
+                        logging_obj=logging,
+                    )
+                else:
+                    response = CustomStreamWrapper(
+                        iter(response),
+                        model,
+                        custom_llm_provider="bedrock",
+                        logging_obj=logging,
+                    )
         if optional_params.get("stream", False):
             ## LOGGING