from dataclasses import dataclass

import albumentations as A
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    Trainer,
    TrainingArguments,
)


@dataclass
class ModelOutput:
    logits: torch.Tensor
    pred_boxes: torch.Tensor


class MAPEvaluator:
    """Computes mAP metrics from Trainer evaluation outputs via torchmetrics."""

    def __init__(self, image_processor, threshold=0.0, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label

    def collect_image_sizes(self, targets):
        """Collect image sizes across the dataset as a list of tensors with shape [batch_size, 2]."""
        image_sizes = []
        for batch in targets:
            batch_image_sizes = torch.tensor(np.array([x["size"] for x in batch]))
            image_sizes.append(batch_image_sizes)
        return image_sizes

    def collect_targets(self, targets, image_sizes):
        post_processed_targets = []
        for target_batch, image_size_batch in zip(targets, image_sizes):
            for target, (height, width) in zip(target_batch, image_size_batch):
                # Targets are normalized (cx, cy, w, h); convert to absolute (x1, y1, x2, y2).
                boxes = target["boxes"]
                boxes = sv.xcycwh_to_xyxy(boxes)
                boxes = boxes * np.array([width, height, width, height])
                boxes = torch.tensor(boxes)
                labels = torch.tensor(target["class_labels"])
                post_processed_targets.append({"boxes": boxes, "labels": labels})
        return post_processed_targets

    def collect_predictions(self, predictions, image_sizes):
        post_processed_predictions = []
        for batch, target_sizes in zip(predictions, image_sizes):
            # batch[1] holds the logits, batch[2] the predicted boxes.
            batch_logits, batch_boxes = batch[1], batch[2]
            output = ModelOutput(
                logits=torch.tensor(batch_logits),
                pred_boxes=torch.tensor(batch_boxes),
            )
            post_processed_output = self.image_processor.post_process_object_detection(
                output, threshold=self.threshold, target_sizes=target_sizes
            )
            post_processed_predictions.extend(post_processed_output)
        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results):
        predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

        image_sizes = self.collect_image_sizes(targets)
        post_processed_targets = self.collect_targets(targets, image_sizes)
        post_processed_predictions = self.collect_predictions(predictions, image_sizes)

        evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        evaluator.warn_on_many_detections = False
        evaluator.update(post_processed_predictions, post_processed_targets)

        metrics = evaluator.compute()

        # Replace the list of per-class metrics with a separate metric for each class.
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        return metrics


class PyTorchDetectionDataset(Dataset):
    def __init__(self, dataset: sv.DetectionDataset, processor, transform: A.Compose = None):
        self.dataset = dataset
        self.processor = processor
        self.transform = transform

    @staticmethod
    def annotations_as_coco(image_id, categories, boxes):
        """Convert (x1, y1, x2, y2) boxes into COCO-format annotations."""
        annotations = []
        for category, bbox in zip(categories, boxes):
            x1, y1, x2, y2 = bbox
            formatted_annotation = {
                "image_id": image_id,
                "category_id": category,
                "bbox": [x1, y1, x2 - x1, y2 - y1],
                "iscrowd": 0,
                "area": (x2 - x1) * (y2 - y1),
            }
            annotations.append(formatted_annotation)

        return {
            "image_id": image_id,
            "annotations": annotations,
        }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        _, image, annotations = self.dataset[idx]

        # supervision loads images with OpenCV (BGR); flip channels to RGB.
        image = image[:, :, ::-1]
        boxes = annotations.xyxy
        categories = annotations.class_id

        if self.transform:
            transformed = self.transform(
                image=image,
                bboxes=boxes,
                category=categories,
            )
            image = transformed["image"]
            boxes = transformed["bboxes"]
            categories = transformed["category"]

        formatted_annotations = self.annotations_as_coco(
            image_id=idx, categories=categories, boxes=boxes)
        result = self.processor(
            images=image, annotations=formatted_annotations, return_tensors="pt")

        # The image processor adds a batch dimension; squeeze it back out.
        result = {k: v[0] for k, v in result.items()}

        return result


def annotate(image, annotations, classes):
    labels = [
        classes[class_id]
        for class_id
        in annotations.class_id
    ]

    bounding_box_annotator = sv.BoundingBoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_scale=1, text_thickness=2)

    annotated_image = image.copy()
    annotated_image = bounding_box_annotator.annotate(annotated_image, annotations)
    annotated_image = label_annotator.annotate(annotated_image, annotations, labels=labels)
    return annotated_image


def collate_fn(batch):
    data = {}
    data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
    data["labels"] = [x["labels"] for x in batch]
    return data


if __name__ == "__main__":
    CHECKPOINT = "PekingU/rtdetr_r50vd_coco_o365"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    DATASET_ROOT = "Datasets/Custom_dataset"

    # The model itself is loaded further below, once the label mapping is known.
    image_processor = AutoImageProcessor.from_pretrained(CHECKPOINT)

    ds_train = sv.DetectionDataset.from_coco(
        images_directory_path=f"{DATASET_ROOT}/images/Train",
        annotations_path=f"{DATASET_ROOT}/images/Train/annotations_coco.json",
    )
    ds_valid = sv.DetectionDataset.from_coco(
        images_directory_path=f"{DATASET_ROOT}/images/Validate",
        annotations_path=f"{DATASET_ROOT}/images/Validate/annotations_coco.json",
    )
    ds_test = sv.DetectionDataset.from_coco(
        images_directory_path=f"{DATASET_ROOT}/images/Test",
        annotations_path=f"{DATASET_ROOT}/images/Test/annotations_coco.json",
    )

    print(f"Number of training images: {len(ds_train)}")
    print(f"Number of validation images: {len(ds_valid)}")
    print(f"Number of test images: {len(ds_test)}")

    augmentation_train = A.Compose(
        [
            A.Perspective(p=0.1),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
        bbox_params=A.BboxParams(
            format="pascal_voc",
            label_fields=["category"],
            clip=True,
            min_area=25,
        ),
    )

    augmentation_valid = A.Compose(
        [A.NoOp()],
        bbox_params=A.BboxParams(
            format="pascal_voc",
            label_fields=["category"],
            clip=True,
            min_area=1,
        ),
    )

    # Wire the augmentation pipelines into the datasets; they were defined
    # above but previously left unused (transform=None).
    pytorch_dataset_train = PyTorchDetectionDataset(ds_train, image_processor, transform=augmentation_train)
    pytorch_dataset_valid = PyTorchDetectionDataset(ds_valid, image_processor, transform=augmentation_valid)
    pytorch_dataset_test = PyTorchDetectionDataset(ds_test, image_processor, transform=augmentation_valid)

    id2label = {id: label for id, label in enumerate(ds_train.classes)}
    label2id = {label: id for id, label in enumerate(ds_train.classes)}

    eval_compute_metrics_fn = MAPEvaluator(image_processor=image_processor, threshold=0.01, id2label=id2label)

    model = AutoModelForObjectDetection.from_pretrained(
        CHECKPOINT,
        id2label=id2label,
        label2id=label2id,
        anchor_image_size=None,
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir="rtdetr-finetune3-100cd",
        num_train_epochs=100,
        max_grad_norm=0.1,
        learning_rate=1e-4,
        warmup_steps=200,
        per_device_train_batch_size=8,
        dataloader_num_workers=1,
        metric_for_best_model="eval_map",
        greater_is_better=True,
        load_best_model_at_end=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=4,
        remove_unused_columns=False,
        eval_do_concat_batches=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=pytorch_dataset_train,
        eval_dataset=pytorch_dataset_valid,
        processing_class=image_processor,
        data_collator=collate_fn,
        compute_metrics=eval_compute_metrics_fn,
    )

    trainer.train()

    # Evaluate the fine-tuned model on the held-out test split.
    model.eval()
    targets = []
    predictions = []

    for i in range(len(ds_test)):
        path, source_image, annotations = ds_test[i]

        image = Image.open(path)
        inputs = image_processor(image, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            outputs = model(**inputs)

        w, h = image.size
        results = image_processor.post_process_object_detection(
            outputs, target_sizes=[(h, w)], threshold=0.3)

        detections = sv.Detections.from_transformers(results[0])

        targets.append(annotations)
        predictions.append(detections)

    # Calculate mAP
    mean_average_precision = sv.MeanAveragePrecision.from_detections(
        predictions=predictions,
        targets=targets,
    )

    print(f"map50_95: {mean_average_precision.map50_95:.2f}")
    print(f"map50: {mean_average_precision.map50:.2f}")
    print(f"map75: {mean_average_precision.map75:.2f}")
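
    # --- Optional: visualize a prediction with the annotate() helper above ---
    # A minimal usage sketch: annotate() is defined earlier but never called.
    # It assumes the script has run to this point (model, image_processor,
    # DEVICE, ds_test, and ds_train are in scope); the output filename is
    # illustrative. source_image comes back from the supervision dataset in
    # BGR order, which the supervision annotators draw on directly, so it can
    # be written out with cv2 without a channel swap.
    import cv2  # bundled with supervision's opencv-python dependency

    path, source_image, _ = ds_test[0]
    image = Image.open(path)
    inputs = image_processor(image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    w, h = image.size
    results = image_processor.post_process_object_detection(
        outputs, target_sizes=[(h, w)], threshold=0.3)
    detections = sv.Detections.from_transformers(results[0])

    annotated_image = annotate(source_image, detections, ds_train.classes)
    cv2.imwrite("annotated_test_0.jpg", annotated_image)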