def draw_bar_graph(self, title, x_label, y_label, x_data, y_data):
    """Draw a bar graph and save it as a PNG file under ``src/assets``.

    Args:
        self: Not used by this function.
        title: Title of the graph.
        x_label: Label of the x-axis.
        y_label: Label of the y-axis.
        x_data: Values for the x-axis that label the bars.
        y_data: Heights of the bars.

    Returns:
        The file name (not the full path) of the saved image, ``"<uuid4>.png"``.
    """
    output_dir = 'src/assets'
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    image_name = f"{uuid.uuid4()}.png"
    image_path = os.path.join(output_dir, image_name)
    # Draw on a dedicated figure instead of pyplot's implicit "current" figure:
    # otherwise every call keeps adding bars to the same axes (earlier graphs
    # leak into later images) and figures are never released.
    fig, ax = plt.subplots()
    try:
        ax.bar(x_data, y_data)
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        fig.savefig(image_path, format='png')
    finally:
        plt.close(fig)
    return image_name
class GraphTool(BaseTool):
    """Open Graph tool spec.

    Provides ``draw_bar_graph``, which renders a bar chart to a PNG file under
    ``src/assets`` and returns the file name wrapped in a ``Document``.
    """

    # Method names listed for the tool; presumably read by the BaseTool / tool-spec
    # machinery to decide which methods to expose — confirm against BaseTool.
    spec_functions = ["draw_bar_graph"]

    def __init__(
        self,
        metadata: ToolMetadata,
    ) -> None:
        """Store the tool metadata.

        Args:
            metadata: Metadata for this tool; kept on ``self._metadata``.
        """
        self._metadata = metadata

    def draw_bar_graph(
        self,
        title: str,
        x_label: str,
        y_label: str,
        x_data: List[str],
        y_data: List[float],
    ) -> List[Document]:
        """
        Draw a bar graph with the provided values.

        Args:
            title (str):
                The title of the graph.
            x_label (str):
                The label of the x-axis of the graph.
            y_label (str):
                The label of the y-axis of the graph.
            x_data (List[str]):
                The list of values for the x-axis that labels the bars.
            y_data (List[float]):
                The data for the y-axis containing numerical values representing the heights or values of the bars.

        Returns:
            List[Document]:
                A single Document whose text is the saved image's file name
                ("<uuid4>.png", located in src/assets) and whose metadata
                records the graph title.

        Raises:
            ValueError or TypeError:
                If an entry of y_data cannot be converted to a float.
        """
        # Coerce heights to numbers: string values such as "3.5" would otherwise
        # be treated by matplotlib as categorical labels rather than bar heights.
        heights = [float(value) for value in y_data]

        output_dir = 'src/assets'
        # savefig does not create missing directories, so make sure it exists.
        os.makedirs(output_dir, exist_ok=True)
        image_name = f"{uuid.uuid4()}.png"
        image_path = os.path.join(output_dir, image_name)

        # Draw on a dedicated figure instead of pyplot's implicit "current" figure:
        # otherwise every call keeps adding bars to the same axes (earlier graphs
        # leak into later images) and figures are never released.
        fig, ax = plt.subplots()
        try:
            ax.bar(x_data, heights)
            ax.set_title(title)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            fig.savefig(image_path, format='png')
        finally:
            plt.close(fig)

        return [Document(text=image_name, metadata={"graph title": title})]