def table_index_builder(sql_database, table_context_dict: dict[str, str]) -> ObjectIndex:
    """Build a vector-backed object index over a set of SQL table schemas.

    One ``SQLTableSchema`` is created per entry of *table_context_dict*, and
    the schemas are indexed with ``VectorStoreIndex`` via
    ``ObjectIndex.from_objects``.

    Args:
        sql_database: Database wrapper handed to ``SQLTableNodeMapping``,
            which maps the schema objects to index nodes.
        table_context_dict: Maps each table name to a free-text description
            of that table; the text becomes the schema's ``context_str``.

    Returns:
        An ``ObjectIndex`` containing one entry per table in
        *table_context_dict*.
    """
    table_node_mapping = SQLTableNodeMapping(sql_database)
    # One schema object per table; iterate items() to avoid a second lookup.
    table_schema_objs = [
        SQLTableSchema(table_name=table_name, context_str=context_str)
        for table_name, context_str in table_context_dict.items()
    ]
    return ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
def build_query_engine(
    sql_database: SQLDatabase,
    service_context: ServiceContext,
    table_context_dict: dict[str, str],
    *,
    similarity_top_k: int = 2,
) -> SQLTableRetrieverQueryEngine:
    """Create a SQL query engine that first retrieves the relevant tables.

    An object index over the tables in *table_context_dict* is built with
    ``table_index_builder``. Its retriever is used as the engine's
    ``table_retriever``, so only the top-matching table schemas are
    considered for each query.

    Args:
        sql_database: Database the engine queries; also used to build the
            table index.
        service_context: Forwarded unchanged to
            ``SQLTableRetrieverQueryEngine``.
        table_context_dict: Maps each table name to a free-text description
            of that table.
        similarity_top_k: Number of table schemas the retriever returns per
            query. Defaults to 2, the previously hard-coded value.

    Returns:
        The configured ``SQLTableRetrieverQueryEngine``.

    Raises:
        ValueError: If *similarity_top_k* is less than 1.
    """
    if similarity_top_k < 1:
        raise ValueError(
            f"similarity_top_k must be at least 1, got {similarity_top_k}"
        )
    obj_index = table_index_builder(sql_database, table_context_dict=table_context_dict)
    return SQLTableRetrieverQueryEngine(
        sql_database=sql_database,
        table_retriever=obj_index.as_retriever(similarity_top_k=similarity_top_k),
        service_context=service_context,
    )
def persist_object_index(obj_index: "ObjectIndex", persist_dir: str = "./storage") -> None:
    """Write the index wrapped by *obj_index* to *persist_dir*.

    Only the wrapped index's storage context is written here. That is why
    ``load_object_index`` takes the object-node mapping as an argument
    when rebuilding the ``ObjectIndex``.

    Args:
        obj_index: The object index whose underlying index should be saved.
        persist_dir: Directory to write to. Defaults to ``"./storage"``.
    """
    # NOTE(review): ``_index`` is a private ObjectIndex attribute; check
    # whether the installed llama_index version exposes a public accessor.
    obj_index._index.storage_context.persist(persist_dir=persist_dir)


def load_object_index(table_node_mapping, persist_dir: str = "./storage") -> "ObjectIndex":
    """Rebuild an ``ObjectIndex`` from an index saved in *persist_dir*.

    Args:
        table_node_mapping: The object-node mapping to pair with the loaded
            index (e.g. the ``SQLTableNodeMapping`` used when the index was
            built).
        persist_dir: Directory previously written by
            ``persist_object_index``. Defaults to ``"./storage"``.

    Returns:
        An ``ObjectIndex`` wrapping the loaded index and
        *table_node_mapping*.
    """
    # Imported here, as in the original snippet, so that importing this
    # module does not require these loaders.
    from llama_index import StorageContext, load_index_from_storage

    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    index = load_index_from_storage(storage_context)
    return ObjectIndex(index, table_node_mapping)