LlamaIndex offers easy integrations with some models, such as BLIP from Salesforce, that can understand images at least to some degree:
from dotenv import load_dotenv
load_dotenv()
from pathlib import Path
from llama_index import download_loader
from llama_index.readers.base import BaseReader
from typing import Dict, List, Optional
from dataclasses import dataclass
@dataclass
class Document:
    """Minimal container for a piece of extracted text and its metadata."""

    # The extracted text content of the document.
    text: str
    # Arbitrary extra information attached to the document.
    metadata: Dict
@dataclass
class ImageDocument(Document):
    """A ``Document`` that can additionally carry the encoded source image."""

    # Encoded image payload, or None when the image was not kept
    # (ImageCaptionReader.load_data passes None unless keep_image is set).
    image: Optional[str] = None
class ImageCaptionReader(BaseReader):
    """Reader that produces a text caption for an image file using BLIP.

    The caption is generated by a ``BlipForConditionalGeneration`` model and
    returned wrapped in an ``ImageDocument``.
    """

    def __init__(
        self,
        parser_config: Optional[Dict] = None,
        keep_image: bool = False,
        prompt: Optional[str] = None,
    ):
        """Set up the captioning model and reader options.

        Args:
            parser_config: Optional pre-built configuration. When given it is
                used as-is and must provide the keys ``"processor"``,
                ``"model"``, ``"device"`` and ``"dtype"`` (all four are read
                by ``load_data``). When ``None``, the
                ``Salesforce/blip-image-captioning-large`` checkpoint is
                loaded, on CUDA with float16 if available, otherwise on CPU
                with float32.
            keep_image: If true, the returned document also carries the image
                encoded via ``img_2_b64``.
            prompt: Optional text passed to the processor together with the
                image.

        Raises:
            ImportError: If ``parser_config`` is ``None`` and one of torch,
                transformers, sentencepiece or Pillow is not installed.
        """
        if parser_config is None:
            try:
                # sentencepiece and PIL are imported here only to verify that
                # they are installed; the names themselves are not used.
                import sentencepiece  # noqa: F401
                import torch
                from PIL import Image  # noqa: F401
                from transformers import BlipForConditionalGeneration, BlipProcessor
            except ImportError as err:
                # Chain the original error so the missing package stays visible.
                raise ImportError(
                    "Please install extra dependencies that are required for "
                    "the ImageCaptionReader: "
                    "`pip install torch transformers sentencepiece Pillow`"
                ) from err

            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            processor = BlipProcessor.from_pretrained(
                "Salesforce/blip-image-captioning-large"
            )
            model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-large", torch_dtype=dtype
            )
            parser_config = {
                "processor": processor,
                "model": model,
                "device": device,
                "dtype": dtype,
            }

        self._parser_config = parser_config
        self._keep_image = keep_image
        self._prompt = prompt

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None
    ) -> List[Document]:
        """Caption the image stored at ``file``.

        Args:
            file: Path of the image to caption.
            extra_info: Optional metadata stored on the returned document;
                an empty dict is used when omitted.

        Returns:
            A single-element list holding an ``ImageDocument`` whose ``text``
            is the generated caption and whose ``image`` is the encoded image
            when ``keep_image`` was set, otherwise ``None``.
        """
        from PIL import Image
        from llama_index.img_utils import img_2_b64

        # Open through a context manager so the underlying file handle is
        # released deterministically; convert() loads the pixel data and
        # returns an independent in-memory RGB image.
        with Image.open(file) as opened:
            image = opened.convert("RGB")

        image_str: Optional[str] = None
        if self._keep_image:
            image_str = img_2_b64(image)

        model = self._parser_config["model"]
        processor = self._parser_config["processor"]
        device = self._parser_config["device"]
        dtype = self._parser_config["dtype"]

        model.to(device)
        # The inputs are moved to the same device and dtype as the model.
        inputs = processor(image, self._prompt, return_tensors="pt").to(device, dtype)
        out = model.generate(**inputs)
        text_str = processor.decode(out[0], skip_special_tokens=True)

        return [
            ImageDocument(
                text=text_str,
                image=image_str,
                metadata=extra_info or {},
            )
        ]
# NOTE: this rebinds the name ImageCaptionReader to whatever download_loader
# returns, replacing the class of the same name defined earlier in this file.
ImageCaptionReader = download_loader("ImageCaptionReader")
loader = ImageCaptionReader()
documents = loader.load_data(file=Path('image.png'))
for document in documents:
    # Read the caption straight from the document's `text` attribute; a
    # document object cannot be unpacked into a single (key, value) pair.
    print("Text:", document.text)