# NOTE: CustomRetriever actually needs to be a plain object (tool-retriever) class, not a BaseRetriever subclass.
from llama_index import QueryBundle
from llama_index.schema import NodeWithScore
from llama_index.retrievers import BaseRetriever
from typing import List
class CustomRetriever:
    """Tool retriever that always includes one fixed tool in its results.

    Wraps another retriever (e.g. the one returned by
    ``obj_index.as_retriever(...)``) whose ``retrieve`` returns tool objects,
    and puts ``tool`` in front of whatever that retriever returns.

    This is deliberately a plain object class rather than a ``BaseRetriever``
    subclass. ``BaseRetriever``'s ``_retrieve`` hook is typed to return
    ``NodeWithScore`` lists, but this class returns tools. The agent it is
    passed to presumably only calls ``retrieve`` on it; confirm this against
    the installed llama_index version.
    """

    def __init__(self, retriever, tool) -> None:
        """Store the wrapped retriever and the always-included tool.

        Args:
            retriever: Object with a ``retrieve(query)`` method returning a
                list of tools.
            tool: Tool that is always returned first, regardless of the query.
        """
        self._base_retriever = retriever
        self._always_tool = tool

    def retrieve(self, str_or_query_bundle) -> list:
        """Return the always-included tool followed by the retrieved tools.

        Args:
            str_or_query_bundle: Query string or ``QueryBundle``; it is passed
                unchanged to the wrapped retriever.

        Returns:
            A new list containing the always-included tool first, then the
            wrapped retriever's results in their original order.
        """
        retrieved = self._base_retriever.retrieve(str_or_query_bundle)
        # The wrapped retriever may also return the always-included tool;
        # drop that copy so the agent is not offered the same tool twice.
        others = [t for t in retrieved if t is not self._always_tool]
        return [self._always_tool, *others]

    def _retrieve(self, query_bundle) -> list:
        """Backward-compatible alias for :meth:`retrieve`."""
        return self.retrieve(query_bundle)
# Build the function-calling agent. Its tool retriever returns the
# query-engine tool on every call, followed by the 5 tools that the object
# index ranks most similar to the message.
agent = FnRetrieverOpenAIAgent.from_retriever(
    retriever=CustomRetriever(
        retriever=obj_index.as_retriever(similarity_top_k=5),
        tool=query_engine_tool,  # always offered to the agent
    ),
    llm=llm,
    callback_manager=callback_manager,
    memory=_memory,
    system_prompt=TEXT_QA_SYSTEM_PROMPT.content,
    verbose=True,
)