# paste1.py — fine-tune RT-DETR (PekingU/rtdetr_r50vd_coco_o365) on a custom
# COCO-format dataset with the Hugging Face Trainer, then evaluate mAP on a test split.

import albumentations as A
import numpy as np
import supervision as sv
import torch
from dataclasses import dataclass
from PIL import Image
from torch.utils.data import Dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    Trainer,
    TrainingArguments,
)
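
# Assumption: fairly recent library releases are required — sv.xcycwh_to_xyxy
# (supervision) and Trainer's `processing_class` argument (transformers >= 4.46)
# do not exist in older versions.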


# @title Evaluation utilities

@dataclass
class ModelOutput:
    """Lightweight container with the fields post_process_object_detection expects."""
    logits: torch.Tensor
    pred_boxes: torch.Tensor


class MAPEvaluator:
    """compute_metrics callback: mAP via torchmetrics over post-processed outputs."""

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label

    def collect_image_sizes(self, targets):
        """Collect image sizes across the dataset as list of tensors with shape [batch_size, 2]."""
        image_sizes = []
        for batch in targets:
            batch_image_sizes = torch.tensor(np.array([x["size"] for x in batch]))
            image_sizes.append(batch_image_sizes)
        return image_sizes

    def collect_targets(self, targets, image_sizes):
        post_processed_targets = []
        for target_batch, image_size_batch in zip(targets, image_sizes):
            for target, (height, width) in zip(target_batch, image_size_batch):
                # Targets store normalized cx, cy, w, h; convert to absolute xyxy.
                boxes = target["boxes"]
                boxes = sv.xcycwh_to_xyxy(boxes)
                boxes = boxes * np.array([width, height, width, height])
                boxes = torch.tensor(boxes)
                labels = torch.tensor(target["class_labels"])
                post_processed_targets.append({"boxes": boxes, "labels": labels})
        return post_processed_targets

    def collect_predictions(self, predictions, image_sizes):
        post_processed_predictions = []
        for batch, target_sizes in zip(predictions, image_sizes):
            # batch[0] holds the loss; logits and boxes follow.
            batch_logits, batch_boxes = batch[1], batch[2]
            output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
            post_processed_output = self.image_processor.post_process_object_detection(
                output, threshold=self.threshold, target_sizes=target_sizes
            )
            post_processed_predictions.extend(post_processed_output)
        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results):
        predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

        image_sizes = self.collect_image_sizes(targets)
        post_processed_targets = self.collect_targets(targets, image_sizes)
        post_processed_predictions = self.collect_predictions(predictions, image_sizes)

        evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        evaluator.warn_on_many_detections = False
        evaluator.update(post_processed_predictions, post_processed_targets)

        metrics = evaluator.compute()

        # Replace the list of per-class metrics with a separate metric for each class
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        return metrics


class PyTorchDetectionDataset(Dataset):
    def __init__(self, dataset: sv.DetectionDataset, processor, transform: A.Compose = None):
        self.dataset = dataset
        self.processor = processor
        self.transform = transform

    @staticmethod
    def annotations_as_coco(image_id, categories, boxes):
        """Convert absolute xyxy boxes into the COCO-style dict the processor expects."""
        annotations = []
        for category, bbox in zip(categories, boxes):
            x1, y1, x2, y2 = bbox
            formatted_annotation = {
                "image_id": image_id,
                "category_id": category,
                "bbox": [x1, y1, x2 - x1, y2 - y1],  # COCO format: [x, y, width, height]
                "iscrowd": 0,
                "area": (x2 - x1) * (y2 - y1),
            }
            annotations.append(formatted_annotation)

        return {
            "image_id": image_id,
            "annotations": annotations,
        }

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        _, image, annotations = self.dataset[idx]

        # supervision loads images as BGR; flip channels to RGB for the processor.
        image = image[:, :, ::-1]
        boxes = annotations.xyxy
        categories = annotations.class_id

        if self.transform:
            transformed = self.transform(
                image=image,
                bboxes=boxes,
                category=categories
            )
            image = transformed["image"]
            boxes = transformed["bboxes"]
            categories = transformed["category"]

        formatted_annotations = self.annotations_as_coco(
            image_id=idx, categories=categories, boxes=boxes)
        result = self.processor(
            images=image, annotations=formatted_annotations, return_tensors="pt")

        # The image processor adds a batch dimension; squeeze it back out.
        result = {k: v[0] for k, v in result.items()}

        return result
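

# Note: indexing PyTorchDetectionDataset yields a dict with "pixel_values" and
# "labels" for a single sample (batch dim already squeezed) — the structure
# collate_fn below expects.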


def annotate(image, annotations, classes):
    labels = [classes[class_id] for class_id in annotations.class_id]

    bounding_box_annotator = sv.BoundingBoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_scale=1, text_thickness=2)

    annotated_image = image.copy()
    annotated_image = bounding_box_annotator.annotate(annotated_image, annotations)
    annotated_image = label_annotator.annotate(annotated_image, annotations, labels=labels)
    return annotated_image


def collate_fn(batch):
    # Preprocessed images share a fixed size, so pixel tensors stack into one
    # batch tensor; labels remain a per-sample list of dicts.
    data = {}
    data["pixel_values"] = torch.stack([x["pixel_values"] for x in batch])
    data["labels"] = [x["labels"] for x in batch]
    return data


if __name__ == "__main__":
    CHECKPOINT = "PekingU/rtdetr_r50vd_coco_o365"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    DATASET_ROOT = "Datasets/Custom_dataset"

    # The model itself is loaded further down, once id2label/label2id are known.
    image_processor = AutoImageProcessor.from_pretrained(CHECKPOINT)

    ds_train = sv.DetectionDataset.from_coco(
        images_directory_path=f"{DATASET_ROOT}/images/Train",
        annotations_path=f"{DATASET_ROOT}/images/Train/annotations_coco.json",
    )
    ds_valid = sv.DetectionDataset.from_coco(
        images_directory_path=f"{DATASET_ROOT}/images/Validate",
        annotations_path=f"{DATASET_ROOT}/images/Validate/annotations_coco.json",
    )
    ds_test = sv.DetectionDataset.from_coco(
        images_directory_path=f"{DATASET_ROOT}/images/Test",
        annotations_path=f"{DATASET_ROOT}/images/Test/annotations_coco.json",
    )

    print(f"Number of training images: {len(ds_train)}")
    print(f"Number of validation images: {len(ds_valid)}")
    print(f"Number of test images: {len(ds_test)}")
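
    # Optional sanity check (sketch): draw ground truth for one training sample
    # with the annotate() helper above; sv.plot_image is assumed to render the
    # BGR array via matplotlib. Index 0 is an arbitrary choice.
    _, sample_image, sample_annotations = ds_train[0]
    sv.plot_image(annotate(sample_image, sample_annotations, ds_train.classes))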

    # Training-time augmentation; boxes are absolute xyxy ("pascal_voc"),
    # matching annotations.xyxy produced by supervision.
    augmentation_train = A.Compose(
        [
            A.Perspective(p=0.1),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
        bbox_params=A.BboxParams(
            format="pascal_voc",
            label_fields=["category"],
            clip=True,
            min_area=25
        ),
    )

    # Validation applies no-ops but still clips boxes to image bounds.
    augmentation_valid = A.Compose(
        [A.NoOp()],
        bbox_params=A.BboxParams(
            format="pascal_voc",
            label_fields=["category"],
            clip=True,
            min_area=1
        ),
    )

    # Wire in the transforms defined above: augmentation for training,
    # clipping-only for validation and test.
    pytorch_dataset_train = PyTorchDetectionDataset(ds_train, image_processor, transform=augmentation_train)
    pytorch_dataset_valid = PyTorchDetectionDataset(ds_valid, image_processor, transform=augmentation_valid)
    pytorch_dataset_test = PyTorchDetectionDataset(ds_test, image_processor, transform=augmentation_valid)

    id2label = {class_id: label for class_id, label in enumerate(ds_train.classes)}
    label2id = {label: class_id for class_id, label in enumerate(ds_train.classes)}

    eval_compute_metrics_fn = MAPEvaluator(image_processor=image_processor, threshold=0.01, id2label=id2label)

    # ignore_mismatched_sizes swaps in a freshly initialized classification head
    # sized for the custom classes.
    model = AutoModelForObjectDetection.from_pretrained(
        CHECKPOINT,
        id2label=id2label,
        label2id=label2id,
        anchor_image_size=None,
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir="rtdetr-finetune3-100cd",
        num_train_epochs=100,
        max_grad_norm=0.1,
        learning_rate=1e-4,
        warmup_steps=200,
        per_device_train_batch_size=8,
        dataloader_num_workers=1,
        metric_for_best_model="eval_map",
        greater_is_better=True,
        load_best_model_at_end=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=4,
        remove_unused_columns=False,
        eval_do_concat_batches=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=pytorch_dataset_train,
        eval_dataset=pytorch_dataset_valid,
        processing_class=image_processor,
        data_collator=collate_fn,
        compute_metrics=eval_compute_metrics_fn,
    )

    trainer.train()
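
    # Optional (sketch): persist the best checkpoint (restored in memory by
    # load_best_model_at_end=True); the "final" subdirectory is an arbitrary choice.
    trainer.save_model("rtdetr-finetune3-100cd/final")
    image_processor.save_pretrained("rtdetr-finetune3-100cd/final")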

    # Test-set inference: pair each prediction with its ground-truth annotations.
    model.eval()
    targets = []
    predictions = []

    for i in range(len(ds_test)):
        path, source_image, annotations = ds_test[i]

        image = Image.open(path)
        inputs = image_processor(image, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            outputs = model(**inputs)

        w, h = image.size
        results = image_processor.post_process_object_detection(
            outputs, target_sizes=[(h, w)], threshold=0.3)

        detections = sv.Detections.from_transformers(results[0])

        targets.append(annotations)
        predictions.append(detections)

    # @title Calculate mAP
    mean_average_precision = sv.MeanAveragePrecision.from_detections(
        predictions=predictions,
        targets=targets,
    )

    print(f"map50_95: {mean_average_precision.map50_95:.2f}")
    print(f"map50: {mean_average_precision.map50:.2f}")
    print(f"map75: {mean_average_precision.map75:.2f}")