fix(main.py): support async streaming for text completions endpoint

This commit is contained in:
Krrish Dholakia 2023-12-14 13:56:32 -08:00
parent 7df9c8e4d8
commit 1608dd7e0b
7 changed files with 175 additions and 68 deletions

View file

@ -310,6 +310,45 @@ class Router:
return self.function_with_retries(**kwargs)
else:
raise e
async def atext_completion(self,
                           model: str,
                           prompt: str,
                           is_retry: Optional[bool] = False,
                           is_fallback: Optional[bool] = False,
                           is_async: Optional[bool] = False,
                           **kwargs):
    """Run an async text completion against a deployment of ``model``.

    A deployment is chosen with ``self.get_available_deployment``; its
    ``litellm_params`` are merged with ``self.default_litellm_params``
    (deployment-specific values win) and forwarded, together with
    ``prompt`` and any extra ``kwargs``, to ``litellm.atext_completion``.

    Args:
        model: Model group name; also recorded under
            ``kwargs["metadata"]["model_group"]``.
        prompt: Prompt text sent to the text-completion call.
        is_retry: Not read by this method.
        is_fallback: Not read by this method.
        is_async: Not read by this method.
        **kwargs: Extra call parameters. ``specific_deployment`` is popped
            and passed to deployment selection; the rest override the
            deployment params in the final call.

    Returns:
        The awaited ``litellm.atext_completion`` response, or whatever
        ``self.function_with_retries`` returns when the call fails and
        ``self.num_retries > 0``.

    Raises:
        Exception: The original error is re-raised when
            ``self.num_retries`` is not positive.
    """
    # Built before the try block so the except handler can always reference
    # it, even when the failure happens on the very first statement below.
    messages = [{"role": "user", "content": prompt}]
    try:
        kwargs.setdefault("metadata", {}).update({"model_group": model})
        # pick the one that is available (lowest TPM/RPM)
        deployment = self.get_available_deployment(
            model=model,
            messages=messages,
            specific_deployment=kwargs.pop("specific_deployment", None),
        )
        data = deployment["litellm_params"].copy()
        # prioritize model-specific params > default router params
        for k, v in self.default_litellm_params.items():
            data.setdefault(k, v)
        # Remove the "-ModelID-XXXX" suffix (everything from the first
        # "-ModelID" onwards); the string is unchanged when it is absent.
        data["model"] = data["model"].partition("-ModelID")[0]
        # call via litellm.atext_completion()
        response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs})  # type: ignore
        return response
    except Exception as e:
        if self.num_retries > 0:
            kwargs["model"] = model
            kwargs["messages"] = messages
            kwargs["original_exception"] = e
            # NOTE(review): the retry path hands off to the synchronous
            # function_with_retries with self.completion (chat-style
            # `messages`, no `prompt`) from inside an async text-completion
            # method. Confirm this is intended rather than an async
            # text-completion retry.
            kwargs["original_function"] = self.completion
            return self.function_with_retries(**kwargs)
        raise
def embedding(self,
model: str,