Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-05-06 11:56:33 -05:00 committed by GitHub
commit c91c45756b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 752 additions and 188 deletions

View file

@ -139,6 +139,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -202,6 +204,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
@ -346,6 +350,8 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id

View file

@ -272,6 +272,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -302,6 +304,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice

View file

@ -382,7 +382,7 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:

View file

@ -118,27 +118,25 @@ async def content_from_doc(doc: RAGDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
return content_from_data(doc.content.uri)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
return r.text
elif isinstance(doc.content, str):
pattern = re.compile("^(https?://|file://|data:)")
if pattern.match(doc.content):
if doc.content.startswith("data:"):
return content_from_data(doc.content)
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
return interleaved_content_as_str(doc.content)
return r.text
return doc.content
else:
# will raise ValueError if the content is not List[InterleavedContent] or InterleavedContent
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> list[Chunk]: