Use the keypoint-labeled JSON files (train) to train a model on the corresponding images, then find the keypoints in the remaining images (test); an inference sketch at the end of this section covers that last step.
# Keypoint detection practice with the keypointrcnn_resnet50_fpn model
Download the object detection reference training scripts provided by PyTorch.
- Download site: https://github.com/pytorch/vision/tree/main/references/detection
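The engine.py and utils.py modules imported by the scripts below come from this repository. For orientation, the collate_fn used with the data loaders is essentially the following (from utils.py in the reference scripts):

def collate_fn(batch):
    # Regroup [(img1, tgt1), (img2, tgt2), ...] into ((img1, img2, ...), (tgt1, tgt2, ...)),
    # since detection images can have different sizes and cannot be stacked into one tensor.
    return tuple(zip(*batch))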
## Customdataset.py
import torch
import json
import cv2
import numpy as np
import os
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
class KeypointDataset(Dataset):
    def __init__(self, root, transform=None, demo=False):
        self.demo = demo
        self.root = root
        self.imgs_files = sorted(os.listdir(os.path.join(self.root, "images")))
        self.annotations_files = sorted(os.listdir(os.path.join(self.root, "annotations")))
        self.transform = transform

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "images", self.imgs_files[idx])
        annotations_path = os.path.join(self.root, "annotations", self.annotations_files[idx])
        img_original = cv2.imread(img_path)
        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        with open(annotations_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            bboxes_original = data["bboxes"]
            keypoints_original = data["keypoints"]
            bboxes_labels_original = ['Glue tube' for _ in bboxes_original]

        if self.transform:
            # Albumentations expects a flat list of (x, y) points, so drop the visibility flag.
            keypoints_original_flattened = [el[0:2] for kp in keypoints_original for el in kp]
            transformed = self.transform(image=img_original, bboxes=bboxes_original,
                                         bboxes_labels=bboxes_labels_original,
                                         keypoints=keypoints_original_flattened)
            img = transformed["image"]
            bboxes = transformed["bboxes"]
            # Reshape back to (num_objects, 2 keypoints, 2 coords); this assumes exactly
            # two keypoints per object (Head and Tail).
            keypoints_transformed_unflattened = np.reshape(np.array(transformed["keypoints"]), (-1, 2, 2)).tolist()
            # Re-attach each keypoint's original visibility flag.
            keypoints = []
            for o_idx, obj in enumerate(keypoints_transformed_unflattened):
                obj_keypoints = []
                for k_idx, kp in enumerate(obj):
                    obj_keypoints.append(kp + [keypoints_original[o_idx][k_idx][2]])
                keypoints.append(obj_keypoints)
        else:
            img, bboxes, keypoints = img_original, bboxes_original, keypoints_original

        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
        target = {}
        target["boxes"] = bboxes
        target["labels"] = torch.as_tensor([1 for _ in bboxes], dtype=torch.int64)  # single object class
        target["image_id"] = torch.tensor([idx])
        target["area"] = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
        target["iscrowd"] = torch.zeros(len(bboxes), dtype=torch.int64)
        target["keypoints"] = torch.as_tensor(keypoints, dtype=torch.float32)
        img = F.to_tensor(img)

        bboxes_original = torch.as_tensor(bboxes_original, dtype=torch.float32)
        target_original = {}
        target_original["boxes"] = bboxes_original
        target_original["labels"] = torch.as_tensor([1 for _ in bboxes_original], dtype=torch.int64)
        target_original["image_id"] = torch.tensor([idx])
        target_original["area"] = (bboxes_original[:, 3] - bboxes_original[:, 1]) * (
                bboxes_original[:, 2] - bboxes_original[:, 0])
        target_original["iscrowd"] = torch.zeros(len(bboxes_original), dtype=torch.int64)
        target_original["keypoints"] = torch.as_tensor(keypoints_original, dtype=torch.float32)
        img_original = F.to_tensor(img_original)

        # In demo mode also return the untransformed image/target for side-by-side visualization.
        if self.demo:
            return img, target, img_original, target_original
        else:
            return img, target

    def __len__(self):
        return len(self.imgs_files)
if __name__ == "__main__":
    root_path = "./keypoint_dataset"
    train_dataset = KeypointDataset(f"{root_path}/train")
    for item in train_dataset:
        print(item)
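For reference, each annotation file read by __getitem__ is expected to look like the hypothetical example below (the coordinate values are made up): "bboxes" holds pascal_voc boxes, and "keypoints" holds, per box, two keypoints each stored as [x, y, visibility].

{
    "bboxes": [[279, 129, 400, 349]],
    "keypoints": [[[302, 345, 1], [386, 134, 1]]]
}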
## visualize.py
import cv2
import albumentations as A
from Customdataset import KeypointDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from utils import collate_fn
train_transform = A.Compose([
    A.Sequential([
        A.RandomRotate90(p=1),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, brightness_by_max=True,
                                   always_apply=False, p=1)
    ], p=1)
], keypoint_params=A.KeypointParams(format='xy'),
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=['bboxes_labels'])
)
root_path = "./keypoint_dataset/"
dataset = KeypointDataset(f"{root_path}/train/", transform=train_transform, demo=True)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
iterator = iter(data_loader)
batch = next(iterator)
keypoints_classes_ids2names = {0: "Head", 1: "Tail"}
def visualize(image, bboxes, keypoints, image_original=None, bboxes_original=None, keypoints_original=None):
    fontsize = 18
    # Draw boxes and named keypoints on the (possibly transformed) image.
    for bbox in bboxes:
        start_point = (bbox[0], bbox[1])
        end_point = (bbox[2], bbox[3])
        image = cv2.rectangle(image.copy(), start_point, end_point, (0, 255, 0), 2)
    for kpts in keypoints:
        for idx, kp in enumerate(kpts):
            image = cv2.circle(image.copy(), tuple(kp), 5, (255, 0, 0), 10)
            image = cv2.putText(image.copy(), f" {keypoints_classes_ids2names[idx]}", tuple(kp),
                                cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 3, cv2.LINE_AA)
    if image_original is None and keypoints_original is None:
        plt.figure(figsize=(40, 40))
        plt.imshow(image)
    else:
        # Draw the original annotations and show original and transformed side by side.
        for bbox in bboxes_original:
            start_point = (bbox[0], bbox[1])
            end_point = (bbox[2], bbox[3])
            image_original = cv2.rectangle(image_original.copy(), start_point, end_point, (0, 255, 0), 2)
        for kpts in keypoints_original:
            for idx, kp in enumerate(kpts):
                image_original = cv2.circle(image_original.copy(), tuple(kp), 5, (255, 0, 0), 10)
                image_original = cv2.putText(image_original.copy(), f" {keypoints_classes_ids2names[idx]}", tuple(kp),
                                             cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 3, cv2.LINE_AA)
        f, ax = plt.subplots(1, 2, figsize=(40, 20))
        ax[0].imshow(image_original)
        ax[0].set_title("Original Image", fontsize=fontsize)
        ax[1].imshow(image)
        ax[1].set_title("Transformed Image", fontsize=fontsize)
    plt.show()
if __name__=="__main__":
visualize_image_show = True
visualize_targets_show = True
image = (batch[0][0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
bboxes = batch[1][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()
keypoints = []
for kpts in batch[1][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist():
keypoints.append([kp[:2] for kp in kpts])
image_original = (batch[2][0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
bboxes_original = batch[3][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()
keypoints_original = []
for kpts in batch[3][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist():
keypoints_original.append([kp[:2] for kp in kpts])
if visualize_image_show:
visualize(image, bboxes, keypoints, image_original, bboxes_original, keypoints_original)
if visualize_targets_show and visualize_image_show == False:
print("Original targets: \n", batch[3], "\n\n")
print("Transformed targets: \n", batch[1])
## main.py
import torch
import torchvision
import albumentations as A
from engine import train_one_epoch, evaluate
from utils import collate_fn
from torch.utils.data import DataLoader
from Customdataset import KeypointDataset
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator
def get_model(num_keypoints, weights_path=None):
    # Custom anchor generator with a wider range of aspect ratios than the default.
    anchor_generator = AnchorGenerator(sizes=(32, 64, 128, 256, 512),
                                       aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0))
    model = keypointrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=True,
        num_classes=2,  # background + object
        num_keypoints=num_keypoints,
        rpn_anchor_generator=anchor_generator)
    if weights_path:
        state_dict = torch.load(weights_path)
        model.load_state_dict(state_dict)
    return model
train_transform = A.Compose([
    A.Sequential([
        A.RandomRotate90(p=1),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, brightness_by_max=True,
                                   always_apply=False, p=1)
    ], p=1)
], keypoint_params=A.KeypointParams(format='xy'),
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=['bboxes_labels'])
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
KEYPOINTS_FOLDER_TRAIN = "./keypoint_dataset/train/"
train_dataset = KeypointDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True, collate_fn=collate_fn)
model = get_model(num_keypoints=2)
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)
num_epochs = 20
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, train_dataloader, device, epoch, print_freq=1000)
    lr_scheduler.step()
    if epoch % 10 == 0:
        torch.save(model.state_dict(), f"./keypointsrcnn_weights_{epoch}.pth")
torch.save(model.state_dict(), "./keypointsrcnn_weights_last.pth")