I'm trying to troubleshoot an issue with my query pipeline. I'm trying to run the join on the output of worker_llms (llms run in parallel) but the join always seems to run right after worker_query and before the worker_llms. When I visualize the DAG it looks correct.
def get_judge_engine() -> QueryPipeline:
# Setup LLMs
judge_llm = Settings.llm
worker_llms = {}
num_workers = int(os.getenv("NUM_WORKERS", "5"))
for i in range(num_workers):
worker_llms[str(i)] = Settings.llm
# Construct the query pipeline
p = QueryPipeline(verbose=True)
# Define the pipeline nodes.
module_dict = {
**worker_llms,
"worker_query": WORKER_PROMPT_TMPL,
"judge_query": JUDGE_PROMPT_TMPL,
"llm_judge": judge_llm,
"join": ArgPackComponent(),
}
p.add_modules(module_dict)
# Add links between nodes
for i in range(num_workers):
p.add_link("worker_query", str(i))
p.add_link(str(i), "join", dest_key=str(i))
p.add_link("join", "judge_query", dest_key="context_str")
p.add_link("judge_query", "llm_judge")
# Generate visualization
net = Network(directed=True)
net.from_nx(p.dag)
net.save_graph("rag_dag.html")
# Return the final pipeline.
return p