Spaces:
Running
on
T4
Running
on
T4
Upload 161 files
#5
by
zxairdeep
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- ultralytics/.pre-commit-config.yaml +73 -0
- ultralytics/__init__.py +12 -0
- ultralytics/assets/bus.jpg +3 -0
- ultralytics/assets/zidane.jpg +3 -0
- ultralytics/datasets/Argoverse.yaml +73 -0
- ultralytics/datasets/GlobalWheat2020.yaml +54 -0
- ultralytics/datasets/ImageNet.yaml +2025 -0
- ultralytics/datasets/Objects365.yaml +443 -0
- ultralytics/datasets/SKU-110K.yaml +58 -0
- ultralytics/datasets/VOC.yaml +100 -0
- ultralytics/datasets/VisDrone.yaml +73 -0
- ultralytics/datasets/coco-pose.yaml +38 -0
- ultralytics/datasets/coco.yaml +115 -0
- ultralytics/datasets/coco128-seg.yaml +101 -0
- ultralytics/datasets/coco128.yaml +101 -0
- ultralytics/datasets/coco8-pose.yaml +25 -0
- ultralytics/datasets/coco8-seg.yaml +101 -0
- ultralytics/datasets/coco8.yaml +101 -0
- ultralytics/datasets/xView.yaml +153 -0
- ultralytics/hub/__init__.py +117 -0
- ultralytics/hub/auth.py +139 -0
- ultralytics/hub/session.py +189 -0
- ultralytics/hub/utils.py +217 -0
- ultralytics/models/README.md +45 -0
- ultralytics/models/rt-detr/rtdetr-l.yaml +50 -0
- ultralytics/models/rt-detr/rtdetr-x.yaml +54 -0
- ultralytics/models/v3/yolov3-spp.yaml +48 -0
- ultralytics/models/v3/yolov3-tiny.yaml +39 -0
- ultralytics/models/v3/yolov3.yaml +48 -0
- ultralytics/models/v5/yolov5-p6.yaml +61 -0
- ultralytics/models/v5/yolov5.yaml +50 -0
- ultralytics/models/v6/yolov6.yaml +53 -0
- ultralytics/models/v8/yolov8-cls.yaml +29 -0
- ultralytics/models/v8/yolov8-p2.yaml +54 -0
- ultralytics/models/v8/yolov8-p6.yaml +56 -0
- ultralytics/models/v8/yolov8-pose-p6.yaml +57 -0
- ultralytics/models/v8/yolov8-pose.yaml +47 -0
- ultralytics/models/v8/yolov8-rtdetr.yaml +46 -0
- ultralytics/models/v8/yolov8-seg.yaml +46 -0
- ultralytics/models/v8/yolov8.yaml +46 -0
- ultralytics/nn/__init__.py +9 -0
- ultralytics/nn/autobackend.py +455 -0
- ultralytics/nn/autoshape.py +244 -0
- ultralytics/nn/modules/__init__.py +29 -0
- ultralytics/nn/modules/block.py +304 -0
- ultralytics/nn/modules/conv.py +297 -0
- ultralytics/nn/modules/head.py +349 -0
- ultralytics/nn/modules/transformer.py +378 -0
- ultralytics/nn/modules/utils.py +78 -0
- ultralytics/nn/tasks.py +780 -0
ultralytics/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
|
3 |
+
|
4 |
+
exclude: 'docs/'
|
5 |
+
# Define bot property if installed via https://github.com/marketplace/pre-commit-ci
|
6 |
+
ci:
|
7 |
+
autofix_prs: true
|
8 |
+
autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
|
9 |
+
autoupdate_schedule: monthly
|
10 |
+
# submodules: true
|
11 |
+
|
12 |
+
repos:
|
13 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
14 |
+
rev: v4.4.0
|
15 |
+
hooks:
|
16 |
+
- id: end-of-file-fixer
|
17 |
+
- id: trailing-whitespace
|
18 |
+
- id: check-case-conflict
|
19 |
+
# - id: check-yaml
|
20 |
+
- id: check-docstring-first
|
21 |
+
- id: double-quote-string-fixer
|
22 |
+
- id: detect-private-key
|
23 |
+
|
24 |
+
- repo: https://github.com/asottile/pyupgrade
|
25 |
+
rev: v3.4.0
|
26 |
+
hooks:
|
27 |
+
- id: pyupgrade
|
28 |
+
name: Upgrade code
|
29 |
+
|
30 |
+
- repo: https://github.com/PyCQA/isort
|
31 |
+
rev: 5.12.0
|
32 |
+
hooks:
|
33 |
+
- id: isort
|
34 |
+
name: Sort imports
|
35 |
+
|
36 |
+
- repo: https://github.com/google/yapf
|
37 |
+
rev: v0.33.0
|
38 |
+
hooks:
|
39 |
+
- id: yapf
|
40 |
+
name: YAPF formatting
|
41 |
+
|
42 |
+
- repo: https://github.com/executablebooks/mdformat
|
43 |
+
rev: 0.7.16
|
44 |
+
hooks:
|
45 |
+
- id: mdformat
|
46 |
+
name: MD formatting
|
47 |
+
additional_dependencies:
|
48 |
+
- mdformat-gfm
|
49 |
+
- mdformat-black
|
50 |
+
# exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
|
51 |
+
|
52 |
+
- repo: https://github.com/PyCQA/flake8
|
53 |
+
rev: 6.0.0
|
54 |
+
hooks:
|
55 |
+
- id: flake8
|
56 |
+
name: PEP8
|
57 |
+
|
58 |
+
- repo: https://github.com/codespell-project/codespell
|
59 |
+
rev: v2.2.4
|
60 |
+
hooks:
|
61 |
+
- id: codespell
|
62 |
+
args:
|
63 |
+
- --ignore-words-list=crate,nd,strack,dota
|
64 |
+
|
65 |
+
# - repo: https://github.com/asottile/yesqa
|
66 |
+
# rev: v1.4.0
|
67 |
+
# hooks:
|
68 |
+
# - id: yesqa
|
69 |
+
|
70 |
+
# - repo: https://github.com/asottile/dead
|
71 |
+
# rev: v1.5.0
|
72 |
+
# hooks:
|
73 |
+
# - id: dead
|
ultralytics/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
__version__ = '8.0.120'
|
4 |
+
|
5 |
+
from ultralytics.hub import start
|
6 |
+
from ultralytics.vit.rtdetr import RTDETR
|
7 |
+
from ultralytics.vit.sam import SAM
|
8 |
+
from ultralytics.yolo.engine.model import YOLO
|
9 |
+
from ultralytics.yolo.nas import NAS
|
10 |
+
from ultralytics.yolo.utils.checks import check_yolo as checks
|
11 |
+
|
12 |
+
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start' # allow simpler import
|
ultralytics/assets/bus.jpg
ADDED
Git LFS Details
|
ultralytics/assets/zidane.jpg
ADDED
Git LFS Details
|
ultralytics/datasets/Argoverse.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI
|
3 |
+
# Example usage: yolo train data=Argoverse.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── Argoverse ← downloads here (31.3 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/Argoverse # dataset root dir
|
12 |
+
train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images
|
13 |
+
val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images
|
14 |
+
test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: bus
|
23 |
+
5: truck
|
24 |
+
6: traffic_light
|
25 |
+
7: stop_sign
|
26 |
+
|
27 |
+
|
28 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
29 |
+
download: |
|
30 |
+
import json
|
31 |
+
from tqdm import tqdm
|
32 |
+
from ultralytics.yolo.utils.downloads import download
|
33 |
+
from pathlib import Path
|
34 |
+
|
35 |
+
def argoverse2yolo(set):
|
36 |
+
labels = {}
|
37 |
+
a = json.load(open(set, "rb"))
|
38 |
+
for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."):
|
39 |
+
img_id = annot['image_id']
|
40 |
+
img_name = a['images'][img_id]['name']
|
41 |
+
img_label_name = f'{img_name[:-3]}txt'
|
42 |
+
|
43 |
+
cls = annot['category_id'] # instance class id
|
44 |
+
x_center, y_center, width, height = annot['bbox']
|
45 |
+
x_center = (x_center + width / 2) / 1920.0 # offset and scale
|
46 |
+
y_center = (y_center + height / 2) / 1200.0 # offset and scale
|
47 |
+
width /= 1920.0 # scale
|
48 |
+
height /= 1200.0 # scale
|
49 |
+
|
50 |
+
img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']]
|
51 |
+
if not img_dir.exists():
|
52 |
+
img_dir.mkdir(parents=True, exist_ok=True)
|
53 |
+
|
54 |
+
k = str(img_dir / img_label_name)
|
55 |
+
if k not in labels:
|
56 |
+
labels[k] = []
|
57 |
+
labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n")
|
58 |
+
|
59 |
+
for k in labels:
|
60 |
+
with open(k, "w") as f:
|
61 |
+
f.writelines(labels[k])
|
62 |
+
|
63 |
+
|
64 |
+
# Download
|
65 |
+
dir = Path(yaml['path']) # dataset root dir
|
66 |
+
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
|
67 |
+
download(urls, dir=dir)
|
68 |
+
|
69 |
+
# Convert
|
70 |
+
annotations_dir = 'Argoverse-HD/annotations/'
|
71 |
+
(dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images'
|
72 |
+
for d in "train.json", "val.json":
|
73 |
+
argoverse2yolo(dir / annotations_dir / d) # convert VisDrone annotations to YOLO labels
|
ultralytics/datasets/GlobalWheat2020.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan
|
3 |
+
# Example usage: yolo train data=GlobalWheat2020.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── GlobalWheat2020 ← downloads here (7.0 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/GlobalWheat2020 # dataset root dir
|
12 |
+
train: # train images (relative to 'path') 3422 images
|
13 |
+
- images/arvalis_1
|
14 |
+
- images/arvalis_2
|
15 |
+
- images/arvalis_3
|
16 |
+
- images/ethz_1
|
17 |
+
- images/rres_1
|
18 |
+
- images/inrae_1
|
19 |
+
- images/usask_1
|
20 |
+
val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1)
|
21 |
+
- images/ethz_1
|
22 |
+
test: # test images (optional) 1276 images
|
23 |
+
- images/utokyo_1
|
24 |
+
- images/utokyo_2
|
25 |
+
- images/nau_1
|
26 |
+
- images/uq_1
|
27 |
+
|
28 |
+
# Classes
|
29 |
+
names:
|
30 |
+
0: wheat_head
|
31 |
+
|
32 |
+
|
33 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
34 |
+
download: |
|
35 |
+
from ultralytics.yolo.utils.downloads import download
|
36 |
+
from pathlib import Path
|
37 |
+
|
38 |
+
# Download
|
39 |
+
dir = Path(yaml['path']) # dataset root dir
|
40 |
+
urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip',
|
41 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip']
|
42 |
+
download(urls, dir=dir)
|
43 |
+
|
44 |
+
# Make Directories
|
45 |
+
for p in 'annotations', 'images', 'labels':
|
46 |
+
(dir / p).mkdir(parents=True, exist_ok=True)
|
47 |
+
|
48 |
+
# Move
|
49 |
+
for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \
|
50 |
+
'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1':
|
51 |
+
(dir / 'global-wheat-codalab-official' / p).rename(dir / 'images' / p) # move to /images
|
52 |
+
f = (dir / 'global-wheat-codalab-official' / p).with_suffix('.json') # json file
|
53 |
+
if f.exists():
|
54 |
+
f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations
|
ultralytics/datasets/ImageNet.yaml
ADDED
@@ -0,0 +1,2025 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University
|
3 |
+
# Simplified class names from https://github.com/anishathalye/imagenet-simple-labels
|
4 |
+
# Example usage: yolo train task=classify data=imagenet
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── imagenet ← downloads here (144 GB)
|
9 |
+
|
10 |
+
|
11 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
12 |
+
path: ../datasets/imagenet # dataset root dir
|
13 |
+
train: train # train images (relative to 'path') 1281167 images
|
14 |
+
val: val # val images (relative to 'path') 50000 images
|
15 |
+
test: # test images (optional)
|
16 |
+
|
17 |
+
# Classes
|
18 |
+
names:
|
19 |
+
0: tench
|
20 |
+
1: goldfish
|
21 |
+
2: great white shark
|
22 |
+
3: tiger shark
|
23 |
+
4: hammerhead shark
|
24 |
+
5: electric ray
|
25 |
+
6: stingray
|
26 |
+
7: cock
|
27 |
+
8: hen
|
28 |
+
9: ostrich
|
29 |
+
10: brambling
|
30 |
+
11: goldfinch
|
31 |
+
12: house finch
|
32 |
+
13: junco
|
33 |
+
14: indigo bunting
|
34 |
+
15: American robin
|
35 |
+
16: bulbul
|
36 |
+
17: jay
|
37 |
+
18: magpie
|
38 |
+
19: chickadee
|
39 |
+
20: American dipper
|
40 |
+
21: kite
|
41 |
+
22: bald eagle
|
42 |
+
23: vulture
|
43 |
+
24: great grey owl
|
44 |
+
25: fire salamander
|
45 |
+
26: smooth newt
|
46 |
+
27: newt
|
47 |
+
28: spotted salamander
|
48 |
+
29: axolotl
|
49 |
+
30: American bullfrog
|
50 |
+
31: tree frog
|
51 |
+
32: tailed frog
|
52 |
+
33: loggerhead sea turtle
|
53 |
+
34: leatherback sea turtle
|
54 |
+
35: mud turtle
|
55 |
+
36: terrapin
|
56 |
+
37: box turtle
|
57 |
+
38: banded gecko
|
58 |
+
39: green iguana
|
59 |
+
40: Carolina anole
|
60 |
+
41: desert grassland whiptail lizard
|
61 |
+
42: agama
|
62 |
+
43: frilled-necked lizard
|
63 |
+
44: alligator lizard
|
64 |
+
45: Gila monster
|
65 |
+
46: European green lizard
|
66 |
+
47: chameleon
|
67 |
+
48: Komodo dragon
|
68 |
+
49: Nile crocodile
|
69 |
+
50: American alligator
|
70 |
+
51: triceratops
|
71 |
+
52: worm snake
|
72 |
+
53: ring-necked snake
|
73 |
+
54: eastern hog-nosed snake
|
74 |
+
55: smooth green snake
|
75 |
+
56: kingsnake
|
76 |
+
57: garter snake
|
77 |
+
58: water snake
|
78 |
+
59: vine snake
|
79 |
+
60: night snake
|
80 |
+
61: boa constrictor
|
81 |
+
62: African rock python
|
82 |
+
63: Indian cobra
|
83 |
+
64: green mamba
|
84 |
+
65: sea snake
|
85 |
+
66: Saharan horned viper
|
86 |
+
67: eastern diamondback rattlesnake
|
87 |
+
68: sidewinder
|
88 |
+
69: trilobite
|
89 |
+
70: harvestman
|
90 |
+
71: scorpion
|
91 |
+
72: yellow garden spider
|
92 |
+
73: barn spider
|
93 |
+
74: European garden spider
|
94 |
+
75: southern black widow
|
95 |
+
76: tarantula
|
96 |
+
77: wolf spider
|
97 |
+
78: tick
|
98 |
+
79: centipede
|
99 |
+
80: black grouse
|
100 |
+
81: ptarmigan
|
101 |
+
82: ruffed grouse
|
102 |
+
83: prairie grouse
|
103 |
+
84: peacock
|
104 |
+
85: quail
|
105 |
+
86: partridge
|
106 |
+
87: grey parrot
|
107 |
+
88: macaw
|
108 |
+
89: sulphur-crested cockatoo
|
109 |
+
90: lorikeet
|
110 |
+
91: coucal
|
111 |
+
92: bee eater
|
112 |
+
93: hornbill
|
113 |
+
94: hummingbird
|
114 |
+
95: jacamar
|
115 |
+
96: toucan
|
116 |
+
97: duck
|
117 |
+
98: red-breasted merganser
|
118 |
+
99: goose
|
119 |
+
100: black swan
|
120 |
+
101: tusker
|
121 |
+
102: echidna
|
122 |
+
103: platypus
|
123 |
+
104: wallaby
|
124 |
+
105: koala
|
125 |
+
106: wombat
|
126 |
+
107: jellyfish
|
127 |
+
108: sea anemone
|
128 |
+
109: brain coral
|
129 |
+
110: flatworm
|
130 |
+
111: nematode
|
131 |
+
112: conch
|
132 |
+
113: snail
|
133 |
+
114: slug
|
134 |
+
115: sea slug
|
135 |
+
116: chiton
|
136 |
+
117: chambered nautilus
|
137 |
+
118: Dungeness crab
|
138 |
+
119: rock crab
|
139 |
+
120: fiddler crab
|
140 |
+
121: red king crab
|
141 |
+
122: American lobster
|
142 |
+
123: spiny lobster
|
143 |
+
124: crayfish
|
144 |
+
125: hermit crab
|
145 |
+
126: isopod
|
146 |
+
127: white stork
|
147 |
+
128: black stork
|
148 |
+
129: spoonbill
|
149 |
+
130: flamingo
|
150 |
+
131: little blue heron
|
151 |
+
132: great egret
|
152 |
+
133: bittern
|
153 |
+
134: crane (bird)
|
154 |
+
135: limpkin
|
155 |
+
136: common gallinule
|
156 |
+
137: American coot
|
157 |
+
138: bustard
|
158 |
+
139: ruddy turnstone
|
159 |
+
140: dunlin
|
160 |
+
141: common redshank
|
161 |
+
142: dowitcher
|
162 |
+
143: oystercatcher
|
163 |
+
144: pelican
|
164 |
+
145: king penguin
|
165 |
+
146: albatross
|
166 |
+
147: grey whale
|
167 |
+
148: killer whale
|
168 |
+
149: dugong
|
169 |
+
150: sea lion
|
170 |
+
151: Chihuahua
|
171 |
+
152: Japanese Chin
|
172 |
+
153: Maltese
|
173 |
+
154: Pekingese
|
174 |
+
155: Shih Tzu
|
175 |
+
156: King Charles Spaniel
|
176 |
+
157: Papillon
|
177 |
+
158: toy terrier
|
178 |
+
159: Rhodesian Ridgeback
|
179 |
+
160: Afghan Hound
|
180 |
+
161: Basset Hound
|
181 |
+
162: Beagle
|
182 |
+
163: Bloodhound
|
183 |
+
164: Bluetick Coonhound
|
184 |
+
165: Black and Tan Coonhound
|
185 |
+
166: Treeing Walker Coonhound
|
186 |
+
167: English foxhound
|
187 |
+
168: Redbone Coonhound
|
188 |
+
169: borzoi
|
189 |
+
170: Irish Wolfhound
|
190 |
+
171: Italian Greyhound
|
191 |
+
172: Whippet
|
192 |
+
173: Ibizan Hound
|
193 |
+
174: Norwegian Elkhound
|
194 |
+
175: Otterhound
|
195 |
+
176: Saluki
|
196 |
+
177: Scottish Deerhound
|
197 |
+
178: Weimaraner
|
198 |
+
179: Staffordshire Bull Terrier
|
199 |
+
180: American Staffordshire Terrier
|
200 |
+
181: Bedlington Terrier
|
201 |
+
182: Border Terrier
|
202 |
+
183: Kerry Blue Terrier
|
203 |
+
184: Irish Terrier
|
204 |
+
185: Norfolk Terrier
|
205 |
+
186: Norwich Terrier
|
206 |
+
187: Yorkshire Terrier
|
207 |
+
188: Wire Fox Terrier
|
208 |
+
189: Lakeland Terrier
|
209 |
+
190: Sealyham Terrier
|
210 |
+
191: Airedale Terrier
|
211 |
+
192: Cairn Terrier
|
212 |
+
193: Australian Terrier
|
213 |
+
194: Dandie Dinmont Terrier
|
214 |
+
195: Boston Terrier
|
215 |
+
196: Miniature Schnauzer
|
216 |
+
197: Giant Schnauzer
|
217 |
+
198: Standard Schnauzer
|
218 |
+
199: Scottish Terrier
|
219 |
+
200: Tibetan Terrier
|
220 |
+
201: Australian Silky Terrier
|
221 |
+
202: Soft-coated Wheaten Terrier
|
222 |
+
203: West Highland White Terrier
|
223 |
+
204: Lhasa Apso
|
224 |
+
205: Flat-Coated Retriever
|
225 |
+
206: Curly-coated Retriever
|
226 |
+
207: Golden Retriever
|
227 |
+
208: Labrador Retriever
|
228 |
+
209: Chesapeake Bay Retriever
|
229 |
+
210: German Shorthaired Pointer
|
230 |
+
211: Vizsla
|
231 |
+
212: English Setter
|
232 |
+
213: Irish Setter
|
233 |
+
214: Gordon Setter
|
234 |
+
215: Brittany
|
235 |
+
216: Clumber Spaniel
|
236 |
+
217: English Springer Spaniel
|
237 |
+
218: Welsh Springer Spaniel
|
238 |
+
219: Cocker Spaniels
|
239 |
+
220: Sussex Spaniel
|
240 |
+
221: Irish Water Spaniel
|
241 |
+
222: Kuvasz
|
242 |
+
223: Schipperke
|
243 |
+
224: Groenendael
|
244 |
+
225: Malinois
|
245 |
+
226: Briard
|
246 |
+
227: Australian Kelpie
|
247 |
+
228: Komondor
|
248 |
+
229: Old English Sheepdog
|
249 |
+
230: Shetland Sheepdog
|
250 |
+
231: collie
|
251 |
+
232: Border Collie
|
252 |
+
233: Bouvier des Flandres
|
253 |
+
234: Rottweiler
|
254 |
+
235: German Shepherd Dog
|
255 |
+
236: Dobermann
|
256 |
+
237: Miniature Pinscher
|
257 |
+
238: Greater Swiss Mountain Dog
|
258 |
+
239: Bernese Mountain Dog
|
259 |
+
240: Appenzeller Sennenhund
|
260 |
+
241: Entlebucher Sennenhund
|
261 |
+
242: Boxer
|
262 |
+
243: Bullmastiff
|
263 |
+
244: Tibetan Mastiff
|
264 |
+
245: French Bulldog
|
265 |
+
246: Great Dane
|
266 |
+
247: St. Bernard
|
267 |
+
248: husky
|
268 |
+
249: Alaskan Malamute
|
269 |
+
250: Siberian Husky
|
270 |
+
251: Dalmatian
|
271 |
+
252: Affenpinscher
|
272 |
+
253: Basenji
|
273 |
+
254: pug
|
274 |
+
255: Leonberger
|
275 |
+
256: Newfoundland
|
276 |
+
257: Pyrenean Mountain Dog
|
277 |
+
258: Samoyed
|
278 |
+
259: Pomeranian
|
279 |
+
260: Chow Chow
|
280 |
+
261: Keeshond
|
281 |
+
262: Griffon Bruxellois
|
282 |
+
263: Pembroke Welsh Corgi
|
283 |
+
264: Cardigan Welsh Corgi
|
284 |
+
265: Toy Poodle
|
285 |
+
266: Miniature Poodle
|
286 |
+
267: Standard Poodle
|
287 |
+
268: Mexican hairless dog
|
288 |
+
269: grey wolf
|
289 |
+
270: Alaskan tundra wolf
|
290 |
+
271: red wolf
|
291 |
+
272: coyote
|
292 |
+
273: dingo
|
293 |
+
274: dhole
|
294 |
+
275: African wild dog
|
295 |
+
276: hyena
|
296 |
+
277: red fox
|
297 |
+
278: kit fox
|
298 |
+
279: Arctic fox
|
299 |
+
280: grey fox
|
300 |
+
281: tabby cat
|
301 |
+
282: tiger cat
|
302 |
+
283: Persian cat
|
303 |
+
284: Siamese cat
|
304 |
+
285: Egyptian Mau
|
305 |
+
286: cougar
|
306 |
+
287: lynx
|
307 |
+
288: leopard
|
308 |
+
289: snow leopard
|
309 |
+
290: jaguar
|
310 |
+
291: lion
|
311 |
+
292: tiger
|
312 |
+
293: cheetah
|
313 |
+
294: brown bear
|
314 |
+
295: American black bear
|
315 |
+
296: polar bear
|
316 |
+
297: sloth bear
|
317 |
+
298: mongoose
|
318 |
+
299: meerkat
|
319 |
+
300: tiger beetle
|
320 |
+
301: ladybug
|
321 |
+
302: ground beetle
|
322 |
+
303: longhorn beetle
|
323 |
+
304: leaf beetle
|
324 |
+
305: dung beetle
|
325 |
+
306: rhinoceros beetle
|
326 |
+
307: weevil
|
327 |
+
308: fly
|
328 |
+
309: bee
|
329 |
+
310: ant
|
330 |
+
311: grasshopper
|
331 |
+
312: cricket
|
332 |
+
313: stick insect
|
333 |
+
314: cockroach
|
334 |
+
315: mantis
|
335 |
+
316: cicada
|
336 |
+
317: leafhopper
|
337 |
+
318: lacewing
|
338 |
+
319: dragonfly
|
339 |
+
320: damselfly
|
340 |
+
321: red admiral
|
341 |
+
322: ringlet
|
342 |
+
323: monarch butterfly
|
343 |
+
324: small white
|
344 |
+
325: sulphur butterfly
|
345 |
+
326: gossamer-winged butterfly
|
346 |
+
327: starfish
|
347 |
+
328: sea urchin
|
348 |
+
329: sea cucumber
|
349 |
+
330: cottontail rabbit
|
350 |
+
331: hare
|
351 |
+
332: Angora rabbit
|
352 |
+
333: hamster
|
353 |
+
334: porcupine
|
354 |
+
335: fox squirrel
|
355 |
+
336: marmot
|
356 |
+
337: beaver
|
357 |
+
338: guinea pig
|
358 |
+
339: common sorrel
|
359 |
+
340: zebra
|
360 |
+
341: pig
|
361 |
+
342: wild boar
|
362 |
+
343: warthog
|
363 |
+
344: hippopotamus
|
364 |
+
345: ox
|
365 |
+
346: water buffalo
|
366 |
+
347: bison
|
367 |
+
348: ram
|
368 |
+
349: bighorn sheep
|
369 |
+
350: Alpine ibex
|
370 |
+
351: hartebeest
|
371 |
+
352: impala
|
372 |
+
353: gazelle
|
373 |
+
354: dromedary
|
374 |
+
355: llama
|
375 |
+
356: weasel
|
376 |
+
357: mink
|
377 |
+
358: European polecat
|
378 |
+
359: black-footed ferret
|
379 |
+
360: otter
|
380 |
+
361: skunk
|
381 |
+
362: badger
|
382 |
+
363: armadillo
|
383 |
+
364: three-toed sloth
|
384 |
+
365: orangutan
|
385 |
+
366: gorilla
|
386 |
+
367: chimpanzee
|
387 |
+
368: gibbon
|
388 |
+
369: siamang
|
389 |
+
370: guenon
|
390 |
+
371: patas monkey
|
391 |
+
372: baboon
|
392 |
+
373: macaque
|
393 |
+
374: langur
|
394 |
+
375: black-and-white colobus
|
395 |
+
376: proboscis monkey
|
396 |
+
377: marmoset
|
397 |
+
378: white-headed capuchin
|
398 |
+
379: howler monkey
|
399 |
+
380: titi
|
400 |
+
381: Geoffroy's spider monkey
|
401 |
+
382: common squirrel monkey
|
402 |
+
383: ring-tailed lemur
|
403 |
+
384: indri
|
404 |
+
385: Asian elephant
|
405 |
+
386: African bush elephant
|
406 |
+
387: red panda
|
407 |
+
388: giant panda
|
408 |
+
389: snoek
|
409 |
+
390: eel
|
410 |
+
391: coho salmon
|
411 |
+
392: rock beauty
|
412 |
+
393: clownfish
|
413 |
+
394: sturgeon
|
414 |
+
395: garfish
|
415 |
+
396: lionfish
|
416 |
+
397: pufferfish
|
417 |
+
398: abacus
|
418 |
+
399: abaya
|
419 |
+
400: academic gown
|
420 |
+
401: accordion
|
421 |
+
402: acoustic guitar
|
422 |
+
403: aircraft carrier
|
423 |
+
404: airliner
|
424 |
+
405: airship
|
425 |
+
406: altar
|
426 |
+
407: ambulance
|
427 |
+
408: amphibious vehicle
|
428 |
+
409: analog clock
|
429 |
+
410: apiary
|
430 |
+
411: apron
|
431 |
+
412: waste container
|
432 |
+
413: assault rifle
|
433 |
+
414: backpack
|
434 |
+
415: bakery
|
435 |
+
416: balance beam
|
436 |
+
417: balloon
|
437 |
+
418: ballpoint pen
|
438 |
+
419: Band-Aid
|
439 |
+
420: banjo
|
440 |
+
421: baluster
|
441 |
+
422: barbell
|
442 |
+
423: barber chair
|
443 |
+
424: barbershop
|
444 |
+
425: barn
|
445 |
+
426: barometer
|
446 |
+
427: barrel
|
447 |
+
428: wheelbarrow
|
448 |
+
429: baseball
|
449 |
+
430: basketball
|
450 |
+
431: bassinet
|
451 |
+
432: bassoon
|
452 |
+
433: swimming cap
|
453 |
+
434: bath towel
|
454 |
+
435: bathtub
|
455 |
+
436: station wagon
|
456 |
+
437: lighthouse
|
457 |
+
438: beaker
|
458 |
+
439: military cap
|
459 |
+
440: beer bottle
|
460 |
+
441: beer glass
|
461 |
+
442: bell-cot
|
462 |
+
443: bib
|
463 |
+
444: tandem bicycle
|
464 |
+
445: bikini
|
465 |
+
446: ring binder
|
466 |
+
447: binoculars
|
467 |
+
448: birdhouse
|
468 |
+
449: boathouse
|
469 |
+
450: bobsleigh
|
470 |
+
451: bolo tie
|
471 |
+
452: poke bonnet
|
472 |
+
453: bookcase
|
473 |
+
454: bookstore
|
474 |
+
455: bottle cap
|
475 |
+
456: bow
|
476 |
+
457: bow tie
|
477 |
+
458: brass
|
478 |
+
459: bra
|
479 |
+
460: breakwater
|
480 |
+
461: breastplate
|
481 |
+
462: broom
|
482 |
+
463: bucket
|
483 |
+
464: buckle
|
484 |
+
465: bulletproof vest
|
485 |
+
466: high-speed train
|
486 |
+
467: butcher shop
|
487 |
+
468: taxicab
|
488 |
+
469: cauldron
|
489 |
+
470: candle
|
490 |
+
471: cannon
|
491 |
+
472: canoe
|
492 |
+
473: can opener
|
493 |
+
474: cardigan
|
494 |
+
475: car mirror
|
495 |
+
476: carousel
|
496 |
+
477: tool kit
|
497 |
+
478: carton
|
498 |
+
479: car wheel
|
499 |
+
480: automated teller machine
|
500 |
+
481: cassette
|
501 |
+
482: cassette player
|
502 |
+
483: castle
|
503 |
+
484: catamaran
|
504 |
+
485: CD player
|
505 |
+
486: cello
|
506 |
+
487: mobile phone
|
507 |
+
488: chain
|
508 |
+
489: chain-link fence
|
509 |
+
490: chain mail
|
510 |
+
491: chainsaw
|
511 |
+
492: chest
|
512 |
+
493: chiffonier
|
513 |
+
494: chime
|
514 |
+
495: china cabinet
|
515 |
+
496: Christmas stocking
|
516 |
+
497: church
|
517 |
+
498: movie theater
|
518 |
+
499: cleaver
|
519 |
+
500: cliff dwelling
|
520 |
+
501: cloak
|
521 |
+
502: clogs
|
522 |
+
503: cocktail shaker
|
523 |
+
504: coffee mug
|
524 |
+
505: coffeemaker
|
525 |
+
506: coil
|
526 |
+
507: combination lock
|
527 |
+
508: computer keyboard
|
528 |
+
509: confectionery store
|
529 |
+
510: container ship
|
530 |
+
511: convertible
|
531 |
+
512: corkscrew
|
532 |
+
513: cornet
|
533 |
+
514: cowboy boot
|
534 |
+
515: cowboy hat
|
535 |
+
516: cradle
|
536 |
+
517: crane (machine)
|
537 |
+
518: crash helmet
|
538 |
+
519: crate
|
539 |
+
520: infant bed
|
540 |
+
521: Crock Pot
|
541 |
+
522: croquet ball
|
542 |
+
523: crutch
|
543 |
+
524: cuirass
|
544 |
+
525: dam
|
545 |
+
526: desk
|
546 |
+
527: desktop computer
|
547 |
+
528: rotary dial telephone
|
548 |
+
529: diaper
|
549 |
+
530: digital clock
|
550 |
+
531: digital watch
|
551 |
+
532: dining table
|
552 |
+
533: dishcloth
|
553 |
+
534: dishwasher
|
554 |
+
535: disc brake
|
555 |
+
536: dock
|
556 |
+
537: dog sled
|
557 |
+
538: dome
|
558 |
+
539: doormat
|
559 |
+
540: drilling rig
|
560 |
+
541: drum
|
561 |
+
542: drumstick
|
562 |
+
543: dumbbell
|
563 |
+
544: Dutch oven
|
564 |
+
545: electric fan
|
565 |
+
546: electric guitar
|
566 |
+
547: electric locomotive
|
567 |
+
548: entertainment center
|
568 |
+
549: envelope
|
569 |
+
550: espresso machine
|
570 |
+
551: face powder
|
571 |
+
552: feather boa
|
572 |
+
553: filing cabinet
|
573 |
+
554: fireboat
|
574 |
+
555: fire engine
|
575 |
+
556: fire screen sheet
|
576 |
+
557: flagpole
|
577 |
+
558: flute
|
578 |
+
559: folding chair
|
579 |
+
560: football helmet
|
580 |
+
561: forklift
|
581 |
+
562: fountain
|
582 |
+
563: fountain pen
|
583 |
+
564: four-poster bed
|
584 |
+
565: freight car
|
585 |
+
566: French horn
|
586 |
+
567: frying pan
|
587 |
+
568: fur coat
|
588 |
+
569: garbage truck
|
589 |
+
570: gas mask
|
590 |
+
571: gas pump
|
591 |
+
572: goblet
|
592 |
+
573: go-kart
|
593 |
+
574: golf ball
|
594 |
+
575: golf cart
|
595 |
+
576: gondola
|
596 |
+
577: gong
|
597 |
+
578: gown
|
598 |
+
579: grand piano
|
599 |
+
580: greenhouse
|
600 |
+
581: grille
|
601 |
+
582: grocery store
|
602 |
+
583: guillotine
|
603 |
+
584: barrette
|
604 |
+
585: hair spray
|
605 |
+
586: half-track
|
606 |
+
587: hammer
|
607 |
+
588: hamper
|
608 |
+
589: hair dryer
|
609 |
+
590: hand-held computer
|
610 |
+
591: handkerchief
|
611 |
+
592: hard disk drive
|
612 |
+
593: harmonica
|
613 |
+
594: harp
|
614 |
+
595: harvester
|
615 |
+
596: hatchet
|
616 |
+
597: holster
|
617 |
+
598: home theater
|
618 |
+
599: honeycomb
|
619 |
+
600: hook
|
620 |
+
601: hoop skirt
|
621 |
+
602: horizontal bar
|
622 |
+
603: horse-drawn vehicle
|
623 |
+
604: hourglass
|
624 |
+
605: iPod
|
625 |
+
606: clothes iron
|
626 |
+
607: jack-o'-lantern
|
627 |
+
608: jeans
|
628 |
+
609: jeep
|
629 |
+
610: T-shirt
|
630 |
+
611: jigsaw puzzle
|
631 |
+
612: pulled rickshaw
|
632 |
+
613: joystick
|
633 |
+
614: kimono
|
634 |
+
615: knee pad
|
635 |
+
616: knot
|
636 |
+
617: lab coat
|
637 |
+
618: ladle
|
638 |
+
619: lampshade
|
639 |
+
620: laptop computer
|
640 |
+
621: lawn mower
|
641 |
+
622: lens cap
|
642 |
+
623: paper knife
|
643 |
+
624: library
|
644 |
+
625: lifeboat
|
645 |
+
626: lighter
|
646 |
+
627: limousine
|
647 |
+
628: ocean liner
|
648 |
+
629: lipstick
|
649 |
+
630: slip-on shoe
|
650 |
+
631: lotion
|
651 |
+
632: speaker
|
652 |
+
633: loupe
|
653 |
+
634: sawmill
|
654 |
+
635: magnetic compass
|
655 |
+
636: mail bag
|
656 |
+
637: mailbox
|
657 |
+
638: tights
|
658 |
+
639: tank suit
|
659 |
+
640: manhole cover
|
660 |
+
641: maraca
|
661 |
+
642: marimba
|
662 |
+
643: mask
|
663 |
+
644: match
|
664 |
+
645: maypole
|
665 |
+
646: maze
|
666 |
+
647: measuring cup
|
667 |
+
648: medicine chest
|
668 |
+
649: megalith
|
669 |
+
650: microphone
|
670 |
+
651: microwave oven
|
671 |
+
652: military uniform
|
672 |
+
653: milk can
|
673 |
+
654: minibus
|
674 |
+
655: miniskirt
|
675 |
+
656: minivan
|
676 |
+
657: missile
|
677 |
+
658: mitten
|
678 |
+
659: mixing bowl
|
679 |
+
660: mobile home
|
680 |
+
661: Model T
|
681 |
+
662: modem
|
682 |
+
663: monastery
|
683 |
+
664: monitor
|
684 |
+
665: moped
|
685 |
+
666: mortar
|
686 |
+
667: square academic cap
|
687 |
+
668: mosque
|
688 |
+
669: mosquito net
|
689 |
+
670: scooter
|
690 |
+
671: mountain bike
|
691 |
+
672: tent
|
692 |
+
673: computer mouse
|
693 |
+
674: mousetrap
|
694 |
+
675: moving van
|
695 |
+
676: muzzle
|
696 |
+
677: nail
|
697 |
+
678: neck brace
|
698 |
+
679: necklace
|
699 |
+
680: nipple
|
700 |
+
681: notebook computer
|
701 |
+
682: obelisk
|
702 |
+
683: oboe
|
703 |
+
684: ocarina
|
704 |
+
685: odometer
|
705 |
+
686: oil filter
|
706 |
+
687: organ
|
707 |
+
688: oscilloscope
|
708 |
+
689: overskirt
|
709 |
+
690: bullock cart
|
710 |
+
691: oxygen mask
|
711 |
+
692: packet
|
712 |
+
693: paddle
|
713 |
+
694: paddle wheel
|
714 |
+
695: padlock
|
715 |
+
696: paintbrush
|
716 |
+
697: pajamas
|
717 |
+
698: palace
|
718 |
+
699: pan flute
|
719 |
+
700: paper towel
|
720 |
+
701: parachute
|
721 |
+
702: parallel bars
|
722 |
+
703: park bench
|
723 |
+
704: parking meter
|
724 |
+
705: passenger car
|
725 |
+
706: patio
|
726 |
+
707: payphone
|
727 |
+
708: pedestal
|
728 |
+
709: pencil case
|
729 |
+
710: pencil sharpener
|
730 |
+
711: perfume
|
731 |
+
712: Petri dish
|
732 |
+
713: photocopier
|
733 |
+
714: plectrum
|
734 |
+
715: Pickelhaube
|
735 |
+
716: picket fence
|
736 |
+
717: pickup truck
|
737 |
+
718: pier
|
738 |
+
719: piggy bank
|
739 |
+
720: pill bottle
|
740 |
+
721: pillow
|
741 |
+
722: ping-pong ball
|
742 |
+
723: pinwheel
|
743 |
+
724: pirate ship
|
744 |
+
725: pitcher
|
745 |
+
726: hand plane
|
746 |
+
727: planetarium
|
747 |
+
728: plastic bag
|
748 |
+
729: plate rack
|
749 |
+
730: plow
|
750 |
+
731: plunger
|
751 |
+
732: Polaroid camera
|
752 |
+
733: pole
|
753 |
+
734: police van
|
754 |
+
735: poncho
|
755 |
+
736: billiard table
|
756 |
+
737: soda bottle
|
757 |
+
738: pot
|
758 |
+
739: potter's wheel
|
759 |
+
740: power drill
|
760 |
+
741: prayer rug
|
761 |
+
742: printer
|
762 |
+
743: prison
|
763 |
+
744: projectile
|
764 |
+
745: projector
|
765 |
+
746: hockey puck
|
766 |
+
747: punching bag
|
767 |
+
748: purse
|
768 |
+
749: quill
|
769 |
+
750: quilt
|
770 |
+
751: race car
|
771 |
+
752: racket
|
772 |
+
753: radiator
|
773 |
+
754: radio
|
774 |
+
755: radio telescope
|
775 |
+
756: rain barrel
|
776 |
+
757: recreational vehicle
|
777 |
+
758: reel
|
778 |
+
759: reflex camera
|
779 |
+
760: refrigerator
|
780 |
+
761: remote control
|
781 |
+
762: restaurant
|
782 |
+
763: revolver
|
783 |
+
764: rifle
|
784 |
+
765: rocking chair
|
785 |
+
766: rotisserie
|
786 |
+
767: eraser
|
787 |
+
768: rugby ball
|
788 |
+
769: ruler
|
789 |
+
770: running shoe
|
790 |
+
771: safe
|
791 |
+
772: safety pin
|
792 |
+
773: salt shaker
|
793 |
+
774: sandal
|
794 |
+
775: sarong
|
795 |
+
776: saxophone
|
796 |
+
777: scabbard
|
797 |
+
778: weighing scale
|
798 |
+
779: school bus
|
799 |
+
780: schooner
|
800 |
+
781: scoreboard
|
801 |
+
782: CRT screen
|
802 |
+
783: screw
|
803 |
+
784: screwdriver
|
804 |
+
785: seat belt
|
805 |
+
786: sewing machine
|
806 |
+
787: shield
|
807 |
+
788: shoe store
|
808 |
+
789: shoji
|
809 |
+
790: shopping basket
|
810 |
+
791: shopping cart
|
811 |
+
792: shovel
|
812 |
+
793: shower cap
|
813 |
+
794: shower curtain
|
814 |
+
795: ski
|
815 |
+
796: ski mask
|
816 |
+
797: sleeping bag
|
817 |
+
798: slide rule
|
818 |
+
799: sliding door
|
819 |
+
800: slot machine
|
820 |
+
801: snorkel
|
821 |
+
802: snowmobile
|
822 |
+
803: snowplow
|
823 |
+
804: soap dispenser
|
824 |
+
805: soccer ball
|
825 |
+
806: sock
|
826 |
+
807: solar thermal collector
|
827 |
+
808: sombrero
|
828 |
+
809: soup bowl
|
829 |
+
810: space bar
|
830 |
+
811: space heater
|
831 |
+
812: space shuttle
|
832 |
+
813: spatula
|
833 |
+
814: motorboat
|
834 |
+
815: spider web
|
835 |
+
816: spindle
|
836 |
+
817: sports car
|
837 |
+
818: spotlight
|
838 |
+
819: stage
|
839 |
+
820: steam locomotive
|
840 |
+
821: through arch bridge
|
841 |
+
822: steel drum
|
842 |
+
823: stethoscope
|
843 |
+
824: scarf
|
844 |
+
825: stone wall
|
845 |
+
826: stopwatch
|
846 |
+
827: stove
|
847 |
+
828: strainer
|
848 |
+
829: tram
|
849 |
+
830: stretcher
|
850 |
+
831: couch
|
851 |
+
832: stupa
|
852 |
+
833: submarine
|
853 |
+
834: suit
|
854 |
+
835: sundial
|
855 |
+
836: sunglass
|
856 |
+
837: sunglasses
|
857 |
+
838: sunscreen
|
858 |
+
839: suspension bridge
|
859 |
+
840: mop
|
860 |
+
841: sweatshirt
|
861 |
+
842: swimsuit
|
862 |
+
843: swing
|
863 |
+
844: switch
|
864 |
+
845: syringe
|
865 |
+
846: table lamp
|
866 |
+
847: tank
|
867 |
+
848: tape player
|
868 |
+
849: teapot
|
869 |
+
850: teddy bear
|
870 |
+
851: television
|
871 |
+
852: tennis ball
|
872 |
+
853: thatched roof
|
873 |
+
854: front curtain
|
874 |
+
855: thimble
|
875 |
+
856: threshing machine
|
876 |
+
857: throne
|
877 |
+
858: tile roof
|
878 |
+
859: toaster
|
879 |
+
860: tobacco shop
|
880 |
+
861: toilet seat
|
881 |
+
862: torch
|
882 |
+
863: totem pole
|
883 |
+
864: tow truck
|
884 |
+
865: toy store
|
885 |
+
866: tractor
|
886 |
+
867: semi-trailer truck
|
887 |
+
868: tray
|
888 |
+
869: trench coat
|
889 |
+
870: tricycle
|
890 |
+
871: trimaran
|
891 |
+
872: tripod
|
892 |
+
873: triumphal arch
|
893 |
+
874: trolleybus
|
894 |
+
875: trombone
|
895 |
+
876: tub
|
896 |
+
877: turnstile
|
897 |
+
878: typewriter keyboard
|
898 |
+
879: umbrella
|
899 |
+
880: unicycle
|
900 |
+
881: upright piano
|
901 |
+
882: vacuum cleaner
|
902 |
+
883: vase
|
903 |
+
884: vault
|
904 |
+
885: velvet
|
905 |
+
886: vending machine
|
906 |
+
887: vestment
|
907 |
+
888: viaduct
|
908 |
+
889: violin
|
909 |
+
890: volleyball
|
910 |
+
891: waffle iron
|
911 |
+
892: wall clock
|
912 |
+
893: wallet
|
913 |
+
894: wardrobe
|
914 |
+
895: military aircraft
|
915 |
+
896: sink
|
916 |
+
897: washing machine
|
917 |
+
898: water bottle
|
918 |
+
899: water jug
|
919 |
+
900: water tower
|
920 |
+
901: whiskey jug
|
921 |
+
902: whistle
|
922 |
+
903: wig
|
923 |
+
904: window screen
|
924 |
+
905: window shade
|
925 |
+
906: Windsor tie
|
926 |
+
907: wine bottle
|
927 |
+
908: wing
|
928 |
+
909: wok
|
929 |
+
910: wooden spoon
|
930 |
+
911: wool
|
931 |
+
912: split-rail fence
|
932 |
+
913: shipwreck
|
933 |
+
914: yawl
|
934 |
+
915: yurt
|
935 |
+
916: website
|
936 |
+
917: comic book
|
937 |
+
918: crossword
|
938 |
+
919: traffic sign
|
939 |
+
920: traffic light
|
940 |
+
921: dust jacket
|
941 |
+
922: menu
|
942 |
+
923: plate
|
943 |
+
924: guacamole
|
944 |
+
925: consomme
|
945 |
+
926: hot pot
|
946 |
+
927: trifle
|
947 |
+
928: ice cream
|
948 |
+
929: ice pop
|
949 |
+
930: baguette
|
950 |
+
931: bagel
|
951 |
+
932: pretzel
|
952 |
+
933: cheeseburger
|
953 |
+
934: hot dog
|
954 |
+
935: mashed potato
|
955 |
+
936: cabbage
|
956 |
+
937: broccoli
|
957 |
+
938: cauliflower
|
958 |
+
939: zucchini
|
959 |
+
940: spaghetti squash
|
960 |
+
941: acorn squash
|
961 |
+
942: butternut squash
|
962 |
+
943: cucumber
|
963 |
+
944: artichoke
|
964 |
+
945: bell pepper
|
965 |
+
946: cardoon
|
966 |
+
947: mushroom
|
967 |
+
948: Granny Smith
|
968 |
+
949: strawberry
|
969 |
+
950: orange
|
970 |
+
951: lemon
|
971 |
+
952: fig
|
972 |
+
953: pineapple
|
973 |
+
954: banana
|
974 |
+
955: jackfruit
|
975 |
+
956: custard apple
|
976 |
+
957: pomegranate
|
977 |
+
958: hay
|
978 |
+
959: carbonara
|
979 |
+
960: chocolate syrup
|
980 |
+
961: dough
|
981 |
+
962: meatloaf
|
982 |
+
963: pizza
|
983 |
+
964: pot pie
|
984 |
+
965: burrito
|
985 |
+
966: red wine
|
986 |
+
967: espresso
|
987 |
+
968: cup
|
988 |
+
969: eggnog
|
989 |
+
970: alp
|
990 |
+
971: bubble
|
991 |
+
972: cliff
|
992 |
+
973: coral reef
|
993 |
+
974: geyser
|
994 |
+
975: lakeshore
|
995 |
+
976: promontory
|
996 |
+
977: shoal
|
997 |
+
978: seashore
|
998 |
+
979: valley
|
999 |
+
980: volcano
|
1000 |
+
981: baseball player
|
1001 |
+
982: bridegroom
|
1002 |
+
983: scuba diver
|
1003 |
+
984: rapeseed
|
1004 |
+
985: daisy
|
1005 |
+
986: yellow lady's slipper
|
1006 |
+
987: corn
|
1007 |
+
988: acorn
|
1008 |
+
989: rose hip
|
1009 |
+
990: horse chestnut seed
|
1010 |
+
991: coral fungus
|
1011 |
+
992: agaric
|
1012 |
+
993: gyromitra
|
1013 |
+
994: stinkhorn mushroom
|
1014 |
+
995: earth star
|
1015 |
+
996: hen-of-the-woods
|
1016 |
+
997: bolete
|
1017 |
+
998: ear
|
1018 |
+
999: toilet paper
|
1019 |
+
|
1020 |
+
# Imagenet class codes to human-readable names
|
1021 |
+
map:
|
1022 |
+
n01440764: tench
|
1023 |
+
n01443537: goldfish
|
1024 |
+
n01484850: great_white_shark
|
1025 |
+
n01491361: tiger_shark
|
1026 |
+
n01494475: hammerhead
|
1027 |
+
n01496331: electric_ray
|
1028 |
+
n01498041: stingray
|
1029 |
+
n01514668: cock
|
1030 |
+
n01514859: hen
|
1031 |
+
n01518878: ostrich
|
1032 |
+
n01530575: brambling
|
1033 |
+
n01531178: goldfinch
|
1034 |
+
n01532829: house_finch
|
1035 |
+
n01534433: junco
|
1036 |
+
n01537544: indigo_bunting
|
1037 |
+
n01558993: robin
|
1038 |
+
n01560419: bulbul
|
1039 |
+
n01580077: jay
|
1040 |
+
n01582220: magpie
|
1041 |
+
n01592084: chickadee
|
1042 |
+
n01601694: water_ouzel
|
1043 |
+
n01608432: kite
|
1044 |
+
n01614925: bald_eagle
|
1045 |
+
n01616318: vulture
|
1046 |
+
n01622779: great_grey_owl
|
1047 |
+
n01629819: European_fire_salamander
|
1048 |
+
n01630670: common_newt
|
1049 |
+
n01631663: eft
|
1050 |
+
n01632458: spotted_salamander
|
1051 |
+
n01632777: axolotl
|
1052 |
+
n01641577: bullfrog
|
1053 |
+
n01644373: tree_frog
|
1054 |
+
n01644900: tailed_frog
|
1055 |
+
n01664065: loggerhead
|
1056 |
+
n01665541: leatherback_turtle
|
1057 |
+
n01667114: mud_turtle
|
1058 |
+
n01667778: terrapin
|
1059 |
+
n01669191: box_turtle
|
1060 |
+
n01675722: banded_gecko
|
1061 |
+
n01677366: common_iguana
|
1062 |
+
n01682714: American_chameleon
|
1063 |
+
n01685808: whiptail
|
1064 |
+
n01687978: agama
|
1065 |
+
n01688243: frilled_lizard
|
1066 |
+
n01689811: alligator_lizard
|
1067 |
+
n01692333: Gila_monster
|
1068 |
+
n01693334: green_lizard
|
1069 |
+
n01694178: African_chameleon
|
1070 |
+
n01695060: Komodo_dragon
|
1071 |
+
n01697457: African_crocodile
|
1072 |
+
n01698640: American_alligator
|
1073 |
+
n01704323: triceratops
|
1074 |
+
n01728572: thunder_snake
|
1075 |
+
n01728920: ringneck_snake
|
1076 |
+
n01729322: hognose_snake
|
1077 |
+
n01729977: green_snake
|
1078 |
+
n01734418: king_snake
|
1079 |
+
n01735189: garter_snake
|
1080 |
+
n01737021: water_snake
|
1081 |
+
n01739381: vine_snake
|
1082 |
+
n01740131: night_snake
|
1083 |
+
n01742172: boa_constrictor
|
1084 |
+
n01744401: rock_python
|
1085 |
+
n01748264: Indian_cobra
|
1086 |
+
n01749939: green_mamba
|
1087 |
+
n01751748: sea_snake
|
1088 |
+
n01753488: horned_viper
|
1089 |
+
n01755581: diamondback
|
1090 |
+
n01756291: sidewinder
|
1091 |
+
n01768244: trilobite
|
1092 |
+
n01770081: harvestman
|
1093 |
+
n01770393: scorpion
|
1094 |
+
n01773157: black_and_gold_garden_spider
|
1095 |
+
n01773549: barn_spider
|
1096 |
+
n01773797: garden_spider
|
1097 |
+
n01774384: black_widow
|
1098 |
+
n01774750: tarantula
|
1099 |
+
n01775062: wolf_spider
|
1100 |
+
n01776313: tick
|
1101 |
+
n01784675: centipede
|
1102 |
+
n01795545: black_grouse
|
1103 |
+
n01796340: ptarmigan
|
1104 |
+
n01797886: ruffed_grouse
|
1105 |
+
n01798484: prairie_chicken
|
1106 |
+
n01806143: peacock
|
1107 |
+
n01806567: quail
|
1108 |
+
n01807496: partridge
|
1109 |
+
n01817953: African_grey
|
1110 |
+
n01818515: macaw
|
1111 |
+
n01819313: sulphur-crested_cockatoo
|
1112 |
+
n01820546: lorikeet
|
1113 |
+
n01824575: coucal
|
1114 |
+
n01828970: bee_eater
|
1115 |
+
n01829413: hornbill
|
1116 |
+
n01833805: hummingbird
|
1117 |
+
n01843065: jacamar
|
1118 |
+
n01843383: toucan
|
1119 |
+
n01847000: drake
|
1120 |
+
n01855032: red-breasted_merganser
|
1121 |
+
n01855672: goose
|
1122 |
+
n01860187: black_swan
|
1123 |
+
n01871265: tusker
|
1124 |
+
n01872401: echidna
|
1125 |
+
n01873310: platypus
|
1126 |
+
n01877812: wallaby
|
1127 |
+
n01882714: koala
|
1128 |
+
n01883070: wombat
|
1129 |
+
n01910747: jellyfish
|
1130 |
+
n01914609: sea_anemone
|
1131 |
+
n01917289: brain_coral
|
1132 |
+
n01924916: flatworm
|
1133 |
+
n01930112: nematode
|
1134 |
+
n01943899: conch
|
1135 |
+
n01944390: snail
|
1136 |
+
n01945685: slug
|
1137 |
+
n01950731: sea_slug
|
1138 |
+
n01955084: chiton
|
1139 |
+
n01968897: chambered_nautilus
|
1140 |
+
n01978287: Dungeness_crab
|
1141 |
+
n01978455: rock_crab
|
1142 |
+
n01980166: fiddler_crab
|
1143 |
+
n01981276: king_crab
|
1144 |
+
n01983481: American_lobster
|
1145 |
+
n01984695: spiny_lobster
|
1146 |
+
n01985128: crayfish
|
1147 |
+
n01986214: hermit_crab
|
1148 |
+
n01990800: isopod
|
1149 |
+
n02002556: white_stork
|
1150 |
+
n02002724: black_stork
|
1151 |
+
n02006656: spoonbill
|
1152 |
+
n02007558: flamingo
|
1153 |
+
n02009229: little_blue_heron
|
1154 |
+
n02009912: American_egret
|
1155 |
+
n02011460: bittern
|
1156 |
+
n02012849: crane_(bird)
|
1157 |
+
n02013706: limpkin
|
1158 |
+
n02017213: European_gallinule
|
1159 |
+
n02018207: American_coot
|
1160 |
+
n02018795: bustard
|
1161 |
+
n02025239: ruddy_turnstone
|
1162 |
+
n02027492: red-backed_sandpiper
|
1163 |
+
n02028035: redshank
|
1164 |
+
n02033041: dowitcher
|
1165 |
+
n02037110: oystercatcher
|
1166 |
+
n02051845: pelican
|
1167 |
+
n02056570: king_penguin
|
1168 |
+
n02058221: albatross
|
1169 |
+
n02066245: grey_whale
|
1170 |
+
n02071294: killer_whale
|
1171 |
+
n02074367: dugong
|
1172 |
+
n02077923: sea_lion
|
1173 |
+
n02085620: Chihuahua
|
1174 |
+
n02085782: Japanese_spaniel
|
1175 |
+
n02085936: Maltese_dog
|
1176 |
+
n02086079: Pekinese
|
1177 |
+
n02086240: Shih-Tzu
|
1178 |
+
n02086646: Blenheim_spaniel
|
1179 |
+
n02086910: papillon
|
1180 |
+
n02087046: toy_terrier
|
1181 |
+
n02087394: Rhodesian_ridgeback
|
1182 |
+
n02088094: Afghan_hound
|
1183 |
+
n02088238: basset
|
1184 |
+
n02088364: beagle
|
1185 |
+
n02088466: bloodhound
|
1186 |
+
n02088632: bluetick
|
1187 |
+
n02089078: black-and-tan_coonhound
|
1188 |
+
n02089867: Walker_hound
|
1189 |
+
n02089973: English_foxhound
|
1190 |
+
n02090379: redbone
|
1191 |
+
n02090622: borzoi
|
1192 |
+
n02090721: Irish_wolfhound
|
1193 |
+
n02091032: Italian_greyhound
|
1194 |
+
n02091134: whippet
|
1195 |
+
n02091244: Ibizan_hound
|
1196 |
+
n02091467: Norwegian_elkhound
|
1197 |
+
n02091635: otterhound
|
1198 |
+
n02091831: Saluki
|
1199 |
+
n02092002: Scottish_deerhound
|
1200 |
+
n02092339: Weimaraner
|
1201 |
+
n02093256: Staffordshire_bullterrier
|
1202 |
+
n02093428: American_Staffordshire_terrier
|
1203 |
+
n02093647: Bedlington_terrier
|
1204 |
+
n02093754: Border_terrier
|
1205 |
+
n02093859: Kerry_blue_terrier
|
1206 |
+
n02093991: Irish_terrier
|
1207 |
+
n02094114: Norfolk_terrier
|
1208 |
+
n02094258: Norwich_terrier
|
1209 |
+
n02094433: Yorkshire_terrier
|
1210 |
+
n02095314: wire-haired_fox_terrier
|
1211 |
+
n02095570: Lakeland_terrier
|
1212 |
+
n02095889: Sealyham_terrier
|
1213 |
+
n02096051: Airedale
|
1214 |
+
n02096177: cairn
|
1215 |
+
n02096294: Australian_terrier
|
1216 |
+
n02096437: Dandie_Dinmont
|
1217 |
+
n02096585: Boston_bull
|
1218 |
+
n02097047: miniature_schnauzer
|
1219 |
+
n02097130: giant_schnauzer
|
1220 |
+
n02097209: standard_schnauzer
|
1221 |
+
n02097298: Scotch_terrier
|
1222 |
+
n02097474: Tibetan_terrier
|
1223 |
+
n02097658: silky_terrier
|
1224 |
+
n02098105: soft-coated_wheaten_terrier
|
1225 |
+
n02098286: West_Highland_white_terrier
|
1226 |
+
n02098413: Lhasa
|
1227 |
+
n02099267: flat-coated_retriever
|
1228 |
+
n02099429: curly-coated_retriever
|
1229 |
+
n02099601: golden_retriever
|
1230 |
+
n02099712: Labrador_retriever
|
1231 |
+
n02099849: Chesapeake_Bay_retriever
|
1232 |
+
n02100236: German_short-haired_pointer
|
1233 |
+
n02100583: vizsla
|
1234 |
+
n02100735: English_setter
|
1235 |
+
n02100877: Irish_setter
|
1236 |
+
n02101006: Gordon_setter
|
1237 |
+
n02101388: Brittany_spaniel
|
1238 |
+
n02101556: clumber
|
1239 |
+
n02102040: English_springer
|
1240 |
+
n02102177: Welsh_springer_spaniel
|
1241 |
+
n02102318: cocker_spaniel
|
1242 |
+
n02102480: Sussex_spaniel
|
1243 |
+
n02102973: Irish_water_spaniel
|
1244 |
+
n02104029: kuvasz
|
1245 |
+
n02104365: schipperke
|
1246 |
+
n02105056: groenendael
|
1247 |
+
n02105162: malinois
|
1248 |
+
n02105251: briard
|
1249 |
+
n02105412: kelpie
|
1250 |
+
n02105505: komondor
|
1251 |
+
n02105641: Old_English_sheepdog
|
1252 |
+
n02105855: Shetland_sheepdog
|
1253 |
+
n02106030: collie
|
1254 |
+
n02106166: Border_collie
|
1255 |
+
n02106382: Bouvier_des_Flandres
|
1256 |
+
n02106550: Rottweiler
|
1257 |
+
n02106662: German_shepherd
|
1258 |
+
n02107142: Doberman
|
1259 |
+
n02107312: miniature_pinscher
|
1260 |
+
n02107574: Greater_Swiss_Mountain_dog
|
1261 |
+
n02107683: Bernese_mountain_dog
|
1262 |
+
n02107908: Appenzeller
|
1263 |
+
n02108000: EntleBucher
|
1264 |
+
n02108089: boxer
|
1265 |
+
n02108422: bull_mastiff
|
1266 |
+
n02108551: Tibetan_mastiff
|
1267 |
+
n02108915: French_bulldog
|
1268 |
+
n02109047: Great_Dane
|
1269 |
+
n02109525: Saint_Bernard
|
1270 |
+
n02109961: Eskimo_dog
|
1271 |
+
n02110063: malamute
|
1272 |
+
n02110185: Siberian_husky
|
1273 |
+
n02110341: dalmatian
|
1274 |
+
n02110627: affenpinscher
|
1275 |
+
n02110806: basenji
|
1276 |
+
n02110958: pug
|
1277 |
+
n02111129: Leonberg
|
1278 |
+
n02111277: Newfoundland
|
1279 |
+
n02111500: Great_Pyrenees
|
1280 |
+
n02111889: Samoyed
|
1281 |
+
n02112018: Pomeranian
|
1282 |
+
n02112137: chow
|
1283 |
+
n02112350: keeshond
|
1284 |
+
n02112706: Brabancon_griffon
|
1285 |
+
n02113023: Pembroke
|
1286 |
+
n02113186: Cardigan
|
1287 |
+
n02113624: toy_poodle
|
1288 |
+
n02113712: miniature_poodle
|
1289 |
+
n02113799: standard_poodle
|
1290 |
+
n02113978: Mexican_hairless
|
1291 |
+
n02114367: timber_wolf
|
1292 |
+
n02114548: white_wolf
|
1293 |
+
n02114712: red_wolf
|
1294 |
+
n02114855: coyote
|
1295 |
+
n02115641: dingo
|
1296 |
+
n02115913: dhole
|
1297 |
+
n02116738: African_hunting_dog
|
1298 |
+
n02117135: hyena
|
1299 |
+
n02119022: red_fox
|
1300 |
+
n02119789: kit_fox
|
1301 |
+
n02120079: Arctic_fox
|
1302 |
+
n02120505: grey_fox
|
1303 |
+
n02123045: tabby
|
1304 |
+
n02123159: tiger_cat
|
1305 |
+
n02123394: Persian_cat
|
1306 |
+
n02123597: Siamese_cat
|
1307 |
+
n02124075: Egyptian_cat
|
1308 |
+
n02125311: cougar
|
1309 |
+
n02127052: lynx
|
1310 |
+
n02128385: leopard
|
1311 |
+
n02128757: snow_leopard
|
1312 |
+
n02128925: jaguar
|
1313 |
+
n02129165: lion
|
1314 |
+
n02129604: tiger
|
1315 |
+
n02130308: cheetah
|
1316 |
+
n02132136: brown_bear
|
1317 |
+
n02133161: American_black_bear
|
1318 |
+
n02134084: ice_bear
|
1319 |
+
n02134418: sloth_bear
|
1320 |
+
n02137549: mongoose
|
1321 |
+
n02138441: meerkat
|
1322 |
+
n02165105: tiger_beetle
|
1323 |
+
n02165456: ladybug
|
1324 |
+
n02167151: ground_beetle
|
1325 |
+
n02168699: long-horned_beetle
|
1326 |
+
n02169497: leaf_beetle
|
1327 |
+
n02172182: dung_beetle
|
1328 |
+
n02174001: rhinoceros_beetle
|
1329 |
+
n02177972: weevil
|
1330 |
+
n02190166: fly
|
1331 |
+
n02206856: bee
|
1332 |
+
n02219486: ant
|
1333 |
+
n02226429: grasshopper
|
1334 |
+
n02229544: cricket
|
1335 |
+
n02231487: walking_stick
|
1336 |
+
n02233338: cockroach
|
1337 |
+
n02236044: mantis
|
1338 |
+
n02256656: cicada
|
1339 |
+
n02259212: leafhopper
|
1340 |
+
n02264363: lacewing
|
1341 |
+
n02268443: dragonfly
|
1342 |
+
n02268853: damselfly
|
1343 |
+
n02276258: admiral
|
1344 |
+
n02277742: ringlet
|
1345 |
+
n02279972: monarch
|
1346 |
+
n02280649: cabbage_butterfly
|
1347 |
+
n02281406: sulphur_butterfly
|
1348 |
+
n02281787: lycaenid
|
1349 |
+
n02317335: starfish
|
1350 |
+
n02319095: sea_urchin
|
1351 |
+
n02321529: sea_cucumber
|
1352 |
+
n02325366: wood_rabbit
|
1353 |
+
n02326432: hare
|
1354 |
+
n02328150: Angora
|
1355 |
+
n02342885: hamster
|
1356 |
+
n02346627: porcupine
|
1357 |
+
n02356798: fox_squirrel
|
1358 |
+
n02361337: marmot
|
1359 |
+
n02363005: beaver
|
1360 |
+
n02364673: guinea_pig
|
1361 |
+
n02389026: sorrel
|
1362 |
+
n02391049: zebra
|
1363 |
+
n02395406: hog
|
1364 |
+
n02396427: wild_boar
|
1365 |
+
n02397096: warthog
|
1366 |
+
n02398521: hippopotamus
|
1367 |
+
n02403003: ox
|
1368 |
+
n02408429: water_buffalo
|
1369 |
+
n02410509: bison
|
1370 |
+
n02412080: ram
|
1371 |
+
n02415577: bighorn
|
1372 |
+
n02417914: ibex
|
1373 |
+
n02422106: hartebeest
|
1374 |
+
n02422699: impala
|
1375 |
+
n02423022: gazelle
|
1376 |
+
n02437312: Arabian_camel
|
1377 |
+
n02437616: llama
|
1378 |
+
n02441942: weasel
|
1379 |
+
n02442845: mink
|
1380 |
+
n02443114: polecat
|
1381 |
+
n02443484: black-footed_ferret
|
1382 |
+
n02444819: otter
|
1383 |
+
n02445715: skunk
|
1384 |
+
n02447366: badger
|
1385 |
+
n02454379: armadillo
|
1386 |
+
n02457408: three-toed_sloth
|
1387 |
+
n02480495: orangutan
|
1388 |
+
n02480855: gorilla
|
1389 |
+
n02481823: chimpanzee
|
1390 |
+
n02483362: gibbon
|
1391 |
+
n02483708: siamang
|
1392 |
+
n02484975: guenon
|
1393 |
+
n02486261: patas
|
1394 |
+
n02486410: baboon
|
1395 |
+
n02487347: macaque
|
1396 |
+
n02488291: langur
|
1397 |
+
n02488702: colobus
|
1398 |
+
n02489166: proboscis_monkey
|
1399 |
+
n02490219: marmoset
|
1400 |
+
n02492035: capuchin
|
1401 |
+
n02492660: howler_monkey
|
1402 |
+
n02493509: titi
|
1403 |
+
n02493793: spider_monkey
|
1404 |
+
n02494079: squirrel_monkey
|
1405 |
+
n02497673: Madagascar_cat
|
1406 |
+
n02500267: indri
|
1407 |
+
n02504013: Indian_elephant
|
1408 |
+
n02504458: African_elephant
|
1409 |
+
n02509815: lesser_panda
|
1410 |
+
n02510455: giant_panda
|
1411 |
+
n02514041: barracouta
|
1412 |
+
n02526121: eel
|
1413 |
+
n02536864: coho
|
1414 |
+
n02606052: rock_beauty
|
1415 |
+
n02607072: anemone_fish
|
1416 |
+
n02640242: sturgeon
|
1417 |
+
n02641379: gar
|
1418 |
+
n02643566: lionfish
|
1419 |
+
n02655020: puffer
|
1420 |
+
n02666196: abacus
|
1421 |
+
n02667093: abaya
|
1422 |
+
n02669723: academic_gown
|
1423 |
+
n02672831: accordion
|
1424 |
+
n02676566: acoustic_guitar
|
1425 |
+
n02687172: aircraft_carrier
|
1426 |
+
n02690373: airliner
|
1427 |
+
n02692877: airship
|
1428 |
+
n02699494: altar
|
1429 |
+
n02701002: ambulance
|
1430 |
+
n02704792: amphibian
|
1431 |
+
n02708093: analog_clock
|
1432 |
+
n02727426: apiary
|
1433 |
+
n02730930: apron
|
1434 |
+
n02747177: ashcan
|
1435 |
+
n02749479: assault_rifle
|
1436 |
+
n02769748: backpack
|
1437 |
+
n02776631: bakery
|
1438 |
+
n02777292: balance_beam
|
1439 |
+
n02782093: balloon
|
1440 |
+
n02783161: ballpoint
|
1441 |
+
n02786058: Band_Aid
|
1442 |
+
n02787622: banjo
|
1443 |
+
n02788148: bannister
|
1444 |
+
n02790996: barbell
|
1445 |
+
n02791124: barber_chair
|
1446 |
+
n02791270: barbershop
|
1447 |
+
n02793495: barn
|
1448 |
+
n02794156: barometer
|
1449 |
+
n02795169: barrel
|
1450 |
+
n02797295: barrow
|
1451 |
+
n02799071: baseball
|
1452 |
+
n02802426: basketball
|
1453 |
+
n02804414: bassinet
|
1454 |
+
n02804610: bassoon
|
1455 |
+
n02807133: bathing_cap
|
1456 |
+
n02808304: bath_towel
|
1457 |
+
n02808440: bathtub
|
1458 |
+
n02814533: beach_wagon
|
1459 |
+
n02814860: beacon
|
1460 |
+
n02815834: beaker
|
1461 |
+
n02817516: bearskin
|
1462 |
+
n02823428: beer_bottle
|
1463 |
+
n02823750: beer_glass
|
1464 |
+
n02825657: bell_cote
|
1465 |
+
n02834397: bib
|
1466 |
+
n02835271: bicycle-built-for-two
|
1467 |
+
n02837789: bikini
|
1468 |
+
n02840245: binder
|
1469 |
+
n02841315: binoculars
|
1470 |
+
n02843684: birdhouse
|
1471 |
+
n02859443: boathouse
|
1472 |
+
n02860847: bobsled
|
1473 |
+
n02865351: bolo_tie
|
1474 |
+
n02869837: bonnet
|
1475 |
+
n02870880: bookcase
|
1476 |
+
n02871525: bookshop
|
1477 |
+
n02877765: bottlecap
|
1478 |
+
n02879718: bow
|
1479 |
+
n02883205: bow_tie
|
1480 |
+
n02892201: brass
|
1481 |
+
n02892767: brassiere
|
1482 |
+
n02894605: breakwater
|
1483 |
+
n02895154: breastplate
|
1484 |
+
n02906734: broom
|
1485 |
+
n02909870: bucket
|
1486 |
+
n02910353: buckle
|
1487 |
+
n02916936: bulletproof_vest
|
1488 |
+
n02917067: bullet_train
|
1489 |
+
n02927161: butcher_shop
|
1490 |
+
n02930766: cab
|
1491 |
+
n02939185: caldron
|
1492 |
+
n02948072: candle
|
1493 |
+
n02950826: cannon
|
1494 |
+
n02951358: canoe
|
1495 |
+
n02951585: can_opener
|
1496 |
+
n02963159: cardigan
|
1497 |
+
n02965783: car_mirror
|
1498 |
+
n02966193: carousel
|
1499 |
+
n02966687: carpenter's_kit
|
1500 |
+
n02971356: carton
|
1501 |
+
n02974003: car_wheel
|
1502 |
+
n02977058: cash_machine
|
1503 |
+
n02978881: cassette
|
1504 |
+
n02979186: cassette_player
|
1505 |
+
n02980441: castle
|
1506 |
+
n02981792: catamaran
|
1507 |
+
n02988304: CD_player
|
1508 |
+
n02992211: cello
|
1509 |
+
n02992529: cellular_telephone
|
1510 |
+
n02999410: chain
|
1511 |
+
n03000134: chainlink_fence
|
1512 |
+
n03000247: chain_mail
|
1513 |
+
n03000684: chain_saw
|
1514 |
+
n03014705: chest
|
1515 |
+
n03016953: chiffonier
|
1516 |
+
n03017168: chime
|
1517 |
+
n03018349: china_cabinet
|
1518 |
+
n03026506: Christmas_stocking
|
1519 |
+
n03028079: church
|
1520 |
+
n03032252: cinema
|
1521 |
+
n03041632: cleaver
|
1522 |
+
n03042490: cliff_dwelling
|
1523 |
+
n03045698: cloak
|
1524 |
+
n03047690: clog
|
1525 |
+
n03062245: cocktail_shaker
|
1526 |
+
n03063599: coffee_mug
|
1527 |
+
n03063689: coffeepot
|
1528 |
+
n03065424: coil
|
1529 |
+
n03075370: combination_lock
|
1530 |
+
n03085013: computer_keyboard
|
1531 |
+
n03089624: confectionery
|
1532 |
+
n03095699: container_ship
|
1533 |
+
n03100240: convertible
|
1534 |
+
n03109150: corkscrew
|
1535 |
+
n03110669: cornet
|
1536 |
+
n03124043: cowboy_boot
|
1537 |
+
n03124170: cowboy_hat
|
1538 |
+
n03125729: cradle
|
1539 |
+
n03126707: crane_(machine)
|
1540 |
+
n03127747: crash_helmet
|
1541 |
+
n03127925: crate
|
1542 |
+
n03131574: crib
|
1543 |
+
n03133878: Crock_Pot
|
1544 |
+
n03134739: croquet_ball
|
1545 |
+
n03141823: crutch
|
1546 |
+
n03146219: cuirass
|
1547 |
+
n03160309: dam
|
1548 |
+
n03179701: desk
|
1549 |
+
n03180011: desktop_computer
|
1550 |
+
n03187595: dial_telephone
|
1551 |
+
n03188531: diaper
|
1552 |
+
n03196217: digital_clock
|
1553 |
+
n03197337: digital_watch
|
1554 |
+
n03201208: dining_table
|
1555 |
+
n03207743: dishrag
|
1556 |
+
n03207941: dishwasher
|
1557 |
+
n03208938: disk_brake
|
1558 |
+
n03216828: dock
|
1559 |
+
n03218198: dogsled
|
1560 |
+
n03220513: dome
|
1561 |
+
n03223299: doormat
|
1562 |
+
n03240683: drilling_platform
|
1563 |
+
n03249569: drum
|
1564 |
+
n03250847: drumstick
|
1565 |
+
n03255030: dumbbell
|
1566 |
+
n03259280: Dutch_oven
|
1567 |
+
n03271574: electric_fan
|
1568 |
+
n03272010: electric_guitar
|
1569 |
+
n03272562: electric_locomotive
|
1570 |
+
n03290653: entertainment_center
|
1571 |
+
n03291819: envelope
|
1572 |
+
n03297495: espresso_maker
|
1573 |
+
n03314780: face_powder
|
1574 |
+
n03325584: feather_boa
|
1575 |
+
n03337140: file
|
1576 |
+
n03344393: fireboat
|
1577 |
+
n03345487: fire_engine
|
1578 |
+
n03347037: fire_screen
|
1579 |
+
n03355925: flagpole
|
1580 |
+
n03372029: flute
|
1581 |
+
n03376595: folding_chair
|
1582 |
+
n03379051: football_helmet
|
1583 |
+
n03384352: forklift
|
1584 |
+
n03388043: fountain
|
1585 |
+
n03388183: fountain_pen
|
1586 |
+
n03388549: four-poster
|
1587 |
+
n03393912: freight_car
|
1588 |
+
n03394916: French_horn
|
1589 |
+
n03400231: frying_pan
|
1590 |
+
n03404251: fur_coat
|
1591 |
+
n03417042: garbage_truck
|
1592 |
+
n03424325: gasmask
|
1593 |
+
n03425413: gas_pump
|
1594 |
+
n03443371: goblet
|
1595 |
+
n03444034: go-kart
|
1596 |
+
n03445777: golf_ball
|
1597 |
+
n03445924: golfcart
|
1598 |
+
n03447447: gondola
|
1599 |
+
n03447721: gong
|
1600 |
+
n03450230: gown
|
1601 |
+
n03452741: grand_piano
|
1602 |
+
n03457902: greenhouse
|
1603 |
+
n03459775: grille
|
1604 |
+
n03461385: grocery_store
|
1605 |
+
n03467068: guillotine
|
1606 |
+
n03476684: hair_slide
|
1607 |
+
n03476991: hair_spray
|
1608 |
+
n03478589: half_track
|
1609 |
+
n03481172: hammer
|
1610 |
+
n03482405: hamper
|
1611 |
+
n03483316: hand_blower
|
1612 |
+
n03485407: hand-held_computer
|
1613 |
+
n03485794: handkerchief
|
1614 |
+
n03492542: hard_disc
|
1615 |
+
n03494278: harmonica
|
1616 |
+
n03495258: harp
|
1617 |
+
n03496892: harvester
|
1618 |
+
n03498962: hatchet
|
1619 |
+
n03527444: holster
|
1620 |
+
n03529860: home_theater
|
1621 |
+
n03530642: honeycomb
|
1622 |
+
n03532672: hook
|
1623 |
+
n03534580: hoopskirt
|
1624 |
+
n03535780: horizontal_bar
|
1625 |
+
n03538406: horse_cart
|
1626 |
+
n03544143: hourglass
|
1627 |
+
n03584254: iPod
|
1628 |
+
n03584829: iron
|
1629 |
+
n03590841: jack-o'-lantern
|
1630 |
+
n03594734: jean
|
1631 |
+
n03594945: jeep
|
1632 |
+
n03595614: jersey
|
1633 |
+
n03598930: jigsaw_puzzle
|
1634 |
+
n03599486: jinrikisha
|
1635 |
+
n03602883: joystick
|
1636 |
+
n03617480: kimono
|
1637 |
+
n03623198: knee_pad
|
1638 |
+
n03627232: knot
|
1639 |
+
n03630383: lab_coat
|
1640 |
+
n03633091: ladle
|
1641 |
+
n03637318: lampshade
|
1642 |
+
n03642806: laptop
|
1643 |
+
n03649909: lawn_mower
|
1644 |
+
n03657121: lens_cap
|
1645 |
+
n03658185: letter_opener
|
1646 |
+
n03661043: library
|
1647 |
+
n03662601: lifeboat
|
1648 |
+
n03666591: lighter
|
1649 |
+
n03670208: limousine
|
1650 |
+
n03673027: liner
|
1651 |
+
n03676483: lipstick
|
1652 |
+
n03680355: Loafer
|
1653 |
+
n03690938: lotion
|
1654 |
+
n03691459: loudspeaker
|
1655 |
+
n03692522: loupe
|
1656 |
+
n03697007: lumbermill
|
1657 |
+
n03706229: magnetic_compass
|
1658 |
+
n03709823: mailbag
|
1659 |
+
n03710193: mailbox
|
1660 |
+
n03710637: maillot_(tights)
|
1661 |
+
n03710721: maillot_(tank_suit)
|
1662 |
+
n03717622: manhole_cover
|
1663 |
+
n03720891: maraca
|
1664 |
+
n03721384: marimba
|
1665 |
+
n03724870: mask
|
1666 |
+
n03729826: matchstick
|
1667 |
+
n03733131: maypole
|
1668 |
+
n03733281: maze
|
1669 |
+
n03733805: measuring_cup
|
1670 |
+
n03742115: medicine_chest
|
1671 |
+
n03743016: megalith
|
1672 |
+
n03759954: microphone
|
1673 |
+
n03761084: microwave
|
1674 |
+
n03763968: military_uniform
|
1675 |
+
n03764736: milk_can
|
1676 |
+
n03769881: minibus
|
1677 |
+
n03770439: miniskirt
|
1678 |
+
n03770679: minivan
|
1679 |
+
n03773504: missile
|
1680 |
+
n03775071: mitten
|
1681 |
+
n03775546: mixing_bowl
|
1682 |
+
n03776460: mobile_home
|
1683 |
+
n03777568: Model_T
|
1684 |
+
n03777754: modem
|
1685 |
+
n03781244: monastery
|
1686 |
+
n03782006: monitor
|
1687 |
+
n03785016: moped
|
1688 |
+
n03786901: mortar
|
1689 |
+
n03787032: mortarboard
|
1690 |
+
n03788195: mosque
|
1691 |
+
n03788365: mosquito_net
|
1692 |
+
n03791053: motor_scooter
|
1693 |
+
n03792782: mountain_bike
|
1694 |
+
n03792972: mountain_tent
|
1695 |
+
n03793489: mouse
|
1696 |
+
n03794056: mousetrap
|
1697 |
+
n03796401: moving_van
|
1698 |
+
n03803284: muzzle
|
1699 |
+
n03804744: nail
|
1700 |
+
n03814639: neck_brace
|
1701 |
+
n03814906: necklace
|
1702 |
+
n03825788: nipple
|
1703 |
+
n03832673: notebook
|
1704 |
+
n03837869: obelisk
|
1705 |
+
n03838899: oboe
|
1706 |
+
n03840681: ocarina
|
1707 |
+
n03841143: odometer
|
1708 |
+
n03843555: oil_filter
|
1709 |
+
n03854065: organ
|
1710 |
+
n03857828: oscilloscope
|
1711 |
+
n03866082: overskirt
|
1712 |
+
n03868242: oxcart
|
1713 |
+
n03868863: oxygen_mask
|
1714 |
+
n03871628: packet
|
1715 |
+
n03873416: paddle
|
1716 |
+
n03874293: paddlewheel
|
1717 |
+
n03874599: padlock
|
1718 |
+
n03876231: paintbrush
|
1719 |
+
n03877472: pajama
|
1720 |
+
n03877845: palace
|
1721 |
+
n03884397: panpipe
|
1722 |
+
n03887697: paper_towel
|
1723 |
+
n03888257: parachute
|
1724 |
+
n03888605: parallel_bars
|
1725 |
+
n03891251: park_bench
|
1726 |
+
n03891332: parking_meter
|
1727 |
+
n03895866: passenger_car
|
1728 |
+
n03899768: patio
|
1729 |
+
n03902125: pay-phone
|
1730 |
+
n03903868: pedestal
|
1731 |
+
n03908618: pencil_box
|
1732 |
+
n03908714: pencil_sharpener
|
1733 |
+
n03916031: perfume
|
1734 |
+
n03920288: Petri_dish
|
1735 |
+
n03924679: photocopier
|
1736 |
+
n03929660: pick
|
1737 |
+
n03929855: pickelhaube
|
1738 |
+
n03930313: picket_fence
|
1739 |
+
n03930630: pickup
|
1740 |
+
n03933933: pier
|
1741 |
+
n03935335: piggy_bank
|
1742 |
+
n03937543: pill_bottle
|
1743 |
+
n03938244: pillow
|
1744 |
+
n03942813: ping-pong_ball
|
1745 |
+
n03944341: pinwheel
|
1746 |
+
n03947888: pirate
|
1747 |
+
n03950228: pitcher
|
1748 |
+
n03954731: plane
|
1749 |
+
n03956157: planetarium
|
1750 |
+
n03958227: plastic_bag
|
1751 |
+
n03961711: plate_rack
|
1752 |
+
n03967562: plow
|
1753 |
+
n03970156: plunger
|
1754 |
+
n03976467: Polaroid_camera
|
1755 |
+
n03976657: pole
|
1756 |
+
n03977966: police_van
|
1757 |
+
n03980874: poncho
|
1758 |
+
n03982430: pool_table
|
1759 |
+
n03983396: pop_bottle
|
1760 |
+
n03991062: pot
|
1761 |
+
n03992509: potter's_wheel
|
1762 |
+
n03995372: power_drill
|
1763 |
+
n03998194: prayer_rug
|
1764 |
+
n04004767: printer
|
1765 |
+
n04005630: prison
|
1766 |
+
n04008634: projectile
|
1767 |
+
n04009552: projector
|
1768 |
+
n04019541: puck
|
1769 |
+
n04023962: punching_bag
|
1770 |
+
n04026417: purse
|
1771 |
+
n04033901: quill
|
1772 |
+
n04033995: quilt
|
1773 |
+
n04037443: racer
|
1774 |
+
n04039381: racket
|
1775 |
+
n04040759: radiator
|
1776 |
+
n04041544: radio
|
1777 |
+
n04044716: radio_telescope
|
1778 |
+
n04049303: rain_barrel
|
1779 |
+
n04065272: recreational_vehicle
|
1780 |
+
n04067472: reel
|
1781 |
+
n04069434: reflex_camera
|
1782 |
+
n04070727: refrigerator
|
1783 |
+
n04074963: remote_control
|
1784 |
+
n04081281: restaurant
|
1785 |
+
n04086273: revolver
|
1786 |
+
n04090263: rifle
|
1787 |
+
n04099969: rocking_chair
|
1788 |
+
n04111531: rotisserie
|
1789 |
+
n04116512: rubber_eraser
|
1790 |
+
n04118538: rugby_ball
|
1791 |
+
n04118776: rule
|
1792 |
+
n04120489: running_shoe
|
1793 |
+
n04125021: safe
|
1794 |
+
n04127249: safety_pin
|
1795 |
+
n04131690: saltshaker
|
1796 |
+
n04133789: sandal
|
1797 |
+
n04136333: sarong
|
1798 |
+
n04141076: sax
|
1799 |
+
n04141327: scabbard
|
1800 |
+
n04141975: scale
|
1801 |
+
n04146614: school_bus
|
1802 |
+
n04147183: schooner
|
1803 |
+
n04149813: scoreboard
|
1804 |
+
n04152593: screen
|
1805 |
+
n04153751: screw
|
1806 |
+
n04154565: screwdriver
|
1807 |
+
n04162706: seat_belt
|
1808 |
+
n04179913: sewing_machine
|
1809 |
+
n04192698: shield
|
1810 |
+
n04200800: shoe_shop
|
1811 |
+
n04201297: shoji
|
1812 |
+
n04204238: shopping_basket
|
1813 |
+
n04204347: shopping_cart
|
1814 |
+
n04208210: shovel
|
1815 |
+
n04209133: shower_cap
|
1816 |
+
n04209239: shower_curtain
|
1817 |
+
n04228054: ski
|
1818 |
+
n04229816: ski_mask
|
1819 |
+
n04235860: sleeping_bag
|
1820 |
+
n04238763: slide_rule
|
1821 |
+
n04239074: sliding_door
|
1822 |
+
n04243546: slot
|
1823 |
+
n04251144: snorkel
|
1824 |
+
n04252077: snowmobile
|
1825 |
+
n04252225: snowplow
|
1826 |
+
n04254120: soap_dispenser
|
1827 |
+
n04254680: soccer_ball
|
1828 |
+
n04254777: sock
|
1829 |
+
n04258138: solar_dish
|
1830 |
+
n04259630: sombrero
|
1831 |
+
n04263257: soup_bowl
|
1832 |
+
n04264628: space_bar
|
1833 |
+
n04265275: space_heater
|
1834 |
+
n04266014: space_shuttle
|
1835 |
+
n04270147: spatula
|
1836 |
+
n04273569: speedboat
|
1837 |
+
n04275548: spider_web
|
1838 |
+
n04277352: spindle
|
1839 |
+
n04285008: sports_car
|
1840 |
+
n04286575: spotlight
|
1841 |
+
n04296562: stage
|
1842 |
+
n04310018: steam_locomotive
|
1843 |
+
n04311004: steel_arch_bridge
|
1844 |
+
n04311174: steel_drum
|
1845 |
+
n04317175: stethoscope
|
1846 |
+
n04325704: stole
|
1847 |
+
n04326547: stone_wall
|
1848 |
+
n04328186: stopwatch
|
1849 |
+
n04330267: stove
|
1850 |
+
n04332243: strainer
|
1851 |
+
n04335435: streetcar
|
1852 |
+
n04336792: stretcher
|
1853 |
+
n04344873: studio_couch
|
1854 |
+
n04346328: stupa
|
1855 |
+
n04347754: submarine
|
1856 |
+
n04350905: suit
|
1857 |
+
n04355338: sundial
|
1858 |
+
n04355933: sunglass
|
1859 |
+
n04356056: sunglasses
|
1860 |
+
n04357314: sunscreen
|
1861 |
+
n04366367: suspension_bridge
|
1862 |
+
n04367480: swab
|
1863 |
+
n04370456: sweatshirt
|
1864 |
+
n04371430: swimming_trunks
|
1865 |
+
n04371774: swing
|
1866 |
+
n04372370: switch
|
1867 |
+
n04376876: syringe
|
1868 |
+
n04380533: table_lamp
|
1869 |
+
n04389033: tank
|
1870 |
+
n04392985: tape_player
|
1871 |
+
n04398044: teapot
|
1872 |
+
n04399382: teddy
|
1873 |
+
n04404412: television
|
1874 |
+
n04409515: tennis_ball
|
1875 |
+
n04417672: thatch
|
1876 |
+
n04418357: theater_curtain
|
1877 |
+
n04423845: thimble
|
1878 |
+
n04428191: thresher
|
1879 |
+
n04429376: throne
|
1880 |
+
n04435653: tile_roof
|
1881 |
+
n04442312: toaster
|
1882 |
+
n04443257: tobacco_shop
|
1883 |
+
n04447861: toilet_seat
|
1884 |
+
n04456115: torch
|
1885 |
+
n04458633: totem_pole
|
1886 |
+
n04461696: tow_truck
|
1887 |
+
n04462240: toyshop
|
1888 |
+
n04465501: tractor
|
1889 |
+
n04467665: trailer_truck
|
1890 |
+
n04476259: tray
|
1891 |
+
n04479046: trench_coat
|
1892 |
+
n04482393: tricycle
|
1893 |
+
n04483307: trimaran
|
1894 |
+
n04485082: tripod
|
1895 |
+
n04486054: triumphal_arch
|
1896 |
+
n04487081: trolleybus
|
1897 |
+
n04487394: trombone
|
1898 |
+
n04493381: tub
|
1899 |
+
n04501370: turnstile
|
1900 |
+
n04505470: typewriter_keyboard
|
1901 |
+
n04507155: umbrella
|
1902 |
+
n04509417: unicycle
|
1903 |
+
n04515003: upright
|
1904 |
+
n04517823: vacuum
|
1905 |
+
n04522168: vase
|
1906 |
+
n04523525: vault
|
1907 |
+
n04525038: velvet
|
1908 |
+
n04525305: vending_machine
|
1909 |
+
n04532106: vestment
|
1910 |
+
n04532670: viaduct
|
1911 |
+
n04536866: violin
|
1912 |
+
n04540053: volleyball
|
1913 |
+
n04542943: waffle_iron
|
1914 |
+
n04548280: wall_clock
|
1915 |
+
n04548362: wallet
|
1916 |
+
n04550184: wardrobe
|
1917 |
+
n04552348: warplane
|
1918 |
+
n04553703: washbasin
|
1919 |
+
n04554684: washer
|
1920 |
+
n04557648: water_bottle
|
1921 |
+
n04560804: water_jug
|
1922 |
+
n04562935: water_tower
|
1923 |
+
n04579145: whiskey_jug
|
1924 |
+
n04579432: whistle
|
1925 |
+
n04584207: wig
|
1926 |
+
n04589890: window_screen
|
1927 |
+
n04590129: window_shade
|
1928 |
+
n04591157: Windsor_tie
|
1929 |
+
n04591713: wine_bottle
|
1930 |
+
n04592741: wing
|
1931 |
+
n04596742: wok
|
1932 |
+
n04597913: wooden_spoon
|
1933 |
+
n04599235: wool
|
1934 |
+
n04604644: worm_fence
|
1935 |
+
n04606251: wreck
|
1936 |
+
n04612504: yawl
|
1937 |
+
n04613696: yurt
|
1938 |
+
n06359193: web_site
|
1939 |
+
n06596364: comic_book
|
1940 |
+
n06785654: crossword_puzzle
|
1941 |
+
n06794110: street_sign
|
1942 |
+
n06874185: traffic_light
|
1943 |
+
n07248320: book_jacket
|
1944 |
+
n07565083: menu
|
1945 |
+
n07579787: plate
|
1946 |
+
n07583066: guacamole
|
1947 |
+
n07584110: consomme
|
1948 |
+
n07590611: hot_pot
|
1949 |
+
n07613480: trifle
|
1950 |
+
n07614500: ice_cream
|
1951 |
+
n07615774: ice_lolly
|
1952 |
+
n07684084: French_loaf
|
1953 |
+
n07693725: bagel
|
1954 |
+
n07695742: pretzel
|
1955 |
+
n07697313: cheeseburger
|
1956 |
+
n07697537: hotdog
|
1957 |
+
n07711569: mashed_potato
|
1958 |
+
n07714571: head_cabbage
|
1959 |
+
n07714990: broccoli
|
1960 |
+
n07715103: cauliflower
|
1961 |
+
n07716358: zucchini
|
1962 |
+
n07716906: spaghetti_squash
|
1963 |
+
n07717410: acorn_squash
|
1964 |
+
n07717556: butternut_squash
|
1965 |
+
n07718472: cucumber
|
1966 |
+
n07718747: artichoke
|
1967 |
+
n07720875: bell_pepper
|
1968 |
+
n07730033: cardoon
|
1969 |
+
n07734744: mushroom
|
1970 |
+
n07742313: Granny_Smith
|
1971 |
+
n07745940: strawberry
|
1972 |
+
n07747607: orange
|
1973 |
+
n07749582: lemon
|
1974 |
+
n07753113: fig
|
1975 |
+
n07753275: pineapple
|
1976 |
+
n07753592: banana
|
1977 |
+
n07754684: jackfruit
|
1978 |
+
n07760859: custard_apple
|
1979 |
+
n07768694: pomegranate
|
1980 |
+
n07802026: hay
|
1981 |
+
n07831146: carbonara
|
1982 |
+
n07836838: chocolate_sauce
|
1983 |
+
n07860988: dough
|
1984 |
+
n07871810: meat_loaf
|
1985 |
+
n07873807: pizza
|
1986 |
+
n07875152: potpie
|
1987 |
+
n07880968: burrito
|
1988 |
+
n07892512: red_wine
|
1989 |
+
n07920052: espresso
|
1990 |
+
n07930864: cup
|
1991 |
+
n07932039: eggnog
|
1992 |
+
n09193705: alp
|
1993 |
+
n09229709: bubble
|
1994 |
+
n09246464: cliff
|
1995 |
+
n09256479: coral_reef
|
1996 |
+
n09288635: geyser
|
1997 |
+
n09332890: lakeside
|
1998 |
+
n09399592: promontory
|
1999 |
+
n09421951: sandbar
|
2000 |
+
n09428293: seashore
|
2001 |
+
n09468604: valley
|
2002 |
+
n09472597: volcano
|
2003 |
+
n09835506: ballplayer
|
2004 |
+
n10148035: groom
|
2005 |
+
n10565667: scuba_diver
|
2006 |
+
n11879895: rapeseed
|
2007 |
+
n11939491: daisy
|
2008 |
+
n12057211: yellow_lady's_slipper
|
2009 |
+
n12144580: corn
|
2010 |
+
n12267677: acorn
|
2011 |
+
n12620546: hip
|
2012 |
+
n12768682: buckeye
|
2013 |
+
n12985857: coral_fungus
|
2014 |
+
n12998815: agaric
|
2015 |
+
n13037406: gyromitra
|
2016 |
+
n13040303: stinkhorn
|
2017 |
+
n13044778: earthstar
|
2018 |
+
n13052670: hen-of-the-woods
|
2019 |
+
n13054560: bolete
|
2020 |
+
n13133613: ear
|
2021 |
+
n15075141: toilet_tissue
|
2022 |
+
|
2023 |
+
|
2024 |
+
# Download script/URL (optional)
|
2025 |
+
download: yolo/data/scripts/get_imagenet.sh
|
ultralytics/datasets/Objects365.yaml
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# Objects365 dataset https://www.objects365.org/ by Megvii
|
3 |
+
# Example usage: yolo train data=Objects365.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── Objects365 ← downloads here (712 GB = 367G data + 345G zips)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/Objects365 # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 1742289 images
|
13 |
+
val: images/val # val images (relative to 'path') 80000 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: Person
|
19 |
+
1: Sneakers
|
20 |
+
2: Chair
|
21 |
+
3: Other Shoes
|
22 |
+
4: Hat
|
23 |
+
5: Car
|
24 |
+
6: Lamp
|
25 |
+
7: Glasses
|
26 |
+
8: Bottle
|
27 |
+
9: Desk
|
28 |
+
10: Cup
|
29 |
+
11: Street Lights
|
30 |
+
12: Cabinet/shelf
|
31 |
+
13: Handbag/Satchel
|
32 |
+
14: Bracelet
|
33 |
+
15: Plate
|
34 |
+
16: Picture/Frame
|
35 |
+
17: Helmet
|
36 |
+
18: Book
|
37 |
+
19: Gloves
|
38 |
+
20: Storage box
|
39 |
+
21: Boat
|
40 |
+
22: Leather Shoes
|
41 |
+
23: Flower
|
42 |
+
24: Bench
|
43 |
+
25: Potted Plant
|
44 |
+
26: Bowl/Basin
|
45 |
+
27: Flag
|
46 |
+
28: Pillow
|
47 |
+
29: Boots
|
48 |
+
30: Vase
|
49 |
+
31: Microphone
|
50 |
+
32: Necklace
|
51 |
+
33: Ring
|
52 |
+
34: SUV
|
53 |
+
35: Wine Glass
|
54 |
+
36: Belt
|
55 |
+
37: Monitor/TV
|
56 |
+
38: Backpack
|
57 |
+
39: Umbrella
|
58 |
+
40: Traffic Light
|
59 |
+
41: Speaker
|
60 |
+
42: Watch
|
61 |
+
43: Tie
|
62 |
+
44: Trash bin Can
|
63 |
+
45: Slippers
|
64 |
+
46: Bicycle
|
65 |
+
47: Stool
|
66 |
+
48: Barrel/bucket
|
67 |
+
49: Van
|
68 |
+
50: Couch
|
69 |
+
51: Sandals
|
70 |
+
52: Basket
|
71 |
+
53: Drum
|
72 |
+
54: Pen/Pencil
|
73 |
+
55: Bus
|
74 |
+
56: Wild Bird
|
75 |
+
57: High Heels
|
76 |
+
58: Motorcycle
|
77 |
+
59: Guitar
|
78 |
+
60: Carpet
|
79 |
+
61: Cell Phone
|
80 |
+
62: Bread
|
81 |
+
63: Camera
|
82 |
+
64: Canned
|
83 |
+
65: Truck
|
84 |
+
66: Traffic cone
|
85 |
+
67: Cymbal
|
86 |
+
68: Lifesaver
|
87 |
+
69: Towel
|
88 |
+
70: Stuffed Toy
|
89 |
+
71: Candle
|
90 |
+
72: Sailboat
|
91 |
+
73: Laptop
|
92 |
+
74: Awning
|
93 |
+
75: Bed
|
94 |
+
76: Faucet
|
95 |
+
77: Tent
|
96 |
+
78: Horse
|
97 |
+
79: Mirror
|
98 |
+
80: Power outlet
|
99 |
+
81: Sink
|
100 |
+
82: Apple
|
101 |
+
83: Air Conditioner
|
102 |
+
84: Knife
|
103 |
+
85: Hockey Stick
|
104 |
+
86: Paddle
|
105 |
+
87: Pickup Truck
|
106 |
+
88: Fork
|
107 |
+
89: Traffic Sign
|
108 |
+
90: Balloon
|
109 |
+
91: Tripod
|
110 |
+
92: Dog
|
111 |
+
93: Spoon
|
112 |
+
94: Clock
|
113 |
+
95: Pot
|
114 |
+
96: Cow
|
115 |
+
97: Cake
|
116 |
+
98: Dinning Table
|
117 |
+
99: Sheep
|
118 |
+
100: Hanger
|
119 |
+
101: Blackboard/Whiteboard
|
120 |
+
102: Napkin
|
121 |
+
103: Other Fish
|
122 |
+
104: Orange/Tangerine
|
123 |
+
105: Toiletry
|
124 |
+
106: Keyboard
|
125 |
+
107: Tomato
|
126 |
+
108: Lantern
|
127 |
+
109: Machinery Vehicle
|
128 |
+
110: Fan
|
129 |
+
111: Green Vegetables
|
130 |
+
112: Banana
|
131 |
+
113: Baseball Glove
|
132 |
+
114: Airplane
|
133 |
+
115: Mouse
|
134 |
+
116: Train
|
135 |
+
117: Pumpkin
|
136 |
+
118: Soccer
|
137 |
+
119: Skiboard
|
138 |
+
120: Luggage
|
139 |
+
121: Nightstand
|
140 |
+
122: Tea pot
|
141 |
+
123: Telephone
|
142 |
+
124: Trolley
|
143 |
+
125: Head Phone
|
144 |
+
126: Sports Car
|
145 |
+
127: Stop Sign
|
146 |
+
128: Dessert
|
147 |
+
129: Scooter
|
148 |
+
130: Stroller
|
149 |
+
131: Crane
|
150 |
+
132: Remote
|
151 |
+
133: Refrigerator
|
152 |
+
134: Oven
|
153 |
+
135: Lemon
|
154 |
+
136: Duck
|
155 |
+
137: Baseball Bat
|
156 |
+
138: Surveillance Camera
|
157 |
+
139: Cat
|
158 |
+
140: Jug
|
159 |
+
141: Broccoli
|
160 |
+
142: Piano
|
161 |
+
143: Pizza
|
162 |
+
144: Elephant
|
163 |
+
145: Skateboard
|
164 |
+
146: Surfboard
|
165 |
+
147: Gun
|
166 |
+
148: Skating and Skiing shoes
|
167 |
+
149: Gas stove
|
168 |
+
150: Donut
|
169 |
+
151: Bow Tie
|
170 |
+
152: Carrot
|
171 |
+
153: Toilet
|
172 |
+
154: Kite
|
173 |
+
155: Strawberry
|
174 |
+
156: Other Balls
|
175 |
+
157: Shovel
|
176 |
+
158: Pepper
|
177 |
+
159: Computer Box
|
178 |
+
160: Toilet Paper
|
179 |
+
161: Cleaning Products
|
180 |
+
162: Chopsticks
|
181 |
+
163: Microwave
|
182 |
+
164: Pigeon
|
183 |
+
165: Baseball
|
184 |
+
166: Cutting/chopping Board
|
185 |
+
167: Coffee Table
|
186 |
+
168: Side Table
|
187 |
+
169: Scissors
|
188 |
+
170: Marker
|
189 |
+
171: Pie
|
190 |
+
172: Ladder
|
191 |
+
173: Snowboard
|
192 |
+
174: Cookies
|
193 |
+
175: Radiator
|
194 |
+
176: Fire Hydrant
|
195 |
+
177: Basketball
|
196 |
+
178: Zebra
|
197 |
+
179: Grape
|
198 |
+
180: Giraffe
|
199 |
+
181: Potato
|
200 |
+
182: Sausage
|
201 |
+
183: Tricycle
|
202 |
+
184: Violin
|
203 |
+
185: Egg
|
204 |
+
186: Fire Extinguisher
|
205 |
+
187: Candy
|
206 |
+
188: Fire Truck
|
207 |
+
189: Billiards
|
208 |
+
190: Converter
|
209 |
+
191: Bathtub
|
210 |
+
192: Wheelchair
|
211 |
+
193: Golf Club
|
212 |
+
194: Briefcase
|
213 |
+
195: Cucumber
|
214 |
+
196: Cigar/Cigarette
|
215 |
+
197: Paint Brush
|
216 |
+
198: Pear
|
217 |
+
199: Heavy Truck
|
218 |
+
200: Hamburger
|
219 |
+
201: Extractor
|
220 |
+
202: Extension Cord
|
221 |
+
203: Tong
|
222 |
+
204: Tennis Racket
|
223 |
+
205: Folder
|
224 |
+
206: American Football
|
225 |
+
207: earphone
|
226 |
+
208: Mask
|
227 |
+
209: Kettle
|
228 |
+
210: Tennis
|
229 |
+
211: Ship
|
230 |
+
212: Swing
|
231 |
+
213: Coffee Machine
|
232 |
+
214: Slide
|
233 |
+
215: Carriage
|
234 |
+
216: Onion
|
235 |
+
217: Green beans
|
236 |
+
218: Projector
|
237 |
+
219: Frisbee
|
238 |
+
220: Washing Machine/Drying Machine
|
239 |
+
221: Chicken
|
240 |
+
222: Printer
|
241 |
+
223: Watermelon
|
242 |
+
224: Saxophone
|
243 |
+
225: Tissue
|
244 |
+
226: Toothbrush
|
245 |
+
227: Ice cream
|
246 |
+
228: Hot-air balloon
|
247 |
+
229: Cello
|
248 |
+
230: French Fries
|
249 |
+
231: Scale
|
250 |
+
232: Trophy
|
251 |
+
233: Cabbage
|
252 |
+
234: Hot dog
|
253 |
+
235: Blender
|
254 |
+
236: Peach
|
255 |
+
237: Rice
|
256 |
+
238: Wallet/Purse
|
257 |
+
239: Volleyball
|
258 |
+
240: Deer
|
259 |
+
241: Goose
|
260 |
+
242: Tape
|
261 |
+
243: Tablet
|
262 |
+
244: Cosmetics
|
263 |
+
245: Trumpet
|
264 |
+
246: Pineapple
|
265 |
+
247: Golf Ball
|
266 |
+
248: Ambulance
|
267 |
+
249: Parking meter
|
268 |
+
250: Mango
|
269 |
+
251: Key
|
270 |
+
252: Hurdle
|
271 |
+
253: Fishing Rod
|
272 |
+
254: Medal
|
273 |
+
255: Flute
|
274 |
+
256: Brush
|
275 |
+
257: Penguin
|
276 |
+
258: Megaphone
|
277 |
+
259: Corn
|
278 |
+
260: Lettuce
|
279 |
+
261: Garlic
|
280 |
+
262: Swan
|
281 |
+
263: Helicopter
|
282 |
+
264: Green Onion
|
283 |
+
265: Sandwich
|
284 |
+
266: Nuts
|
285 |
+
267: Speed Limit Sign
|
286 |
+
268: Induction Cooker
|
287 |
+
269: Broom
|
288 |
+
270: Trombone
|
289 |
+
271: Plum
|
290 |
+
272: Rickshaw
|
291 |
+
273: Goldfish
|
292 |
+
274: Kiwi fruit
|
293 |
+
275: Router/modem
|
294 |
+
276: Poker Card
|
295 |
+
277: Toaster
|
296 |
+
278: Shrimp
|
297 |
+
279: Sushi
|
298 |
+
280: Cheese
|
299 |
+
281: Notepaper
|
300 |
+
282: Cherry
|
301 |
+
283: Pliers
|
302 |
+
284: CD
|
303 |
+
285: Pasta
|
304 |
+
286: Hammer
|
305 |
+
287: Cue
|
306 |
+
288: Avocado
|
307 |
+
289: Hamimelon
|
308 |
+
290: Flask
|
309 |
+
291: Mushroom
|
310 |
+
292: Screwdriver
|
311 |
+
293: Soap
|
312 |
+
294: Recorder
|
313 |
+
295: Bear
|
314 |
+
296: Eggplant
|
315 |
+
297: Board Eraser
|
316 |
+
298: Coconut
|
317 |
+
299: Tape Measure/Ruler
|
318 |
+
300: Pig
|
319 |
+
301: Showerhead
|
320 |
+
302: Globe
|
321 |
+
303: Chips
|
322 |
+
304: Steak
|
323 |
+
305: Crosswalk Sign
|
324 |
+
306: Stapler
|
325 |
+
307: Camel
|
326 |
+
308: Formula 1
|
327 |
+
309: Pomegranate
|
328 |
+
310: Dishwasher
|
329 |
+
311: Crab
|
330 |
+
312: Hoverboard
|
331 |
+
313: Meat ball
|
332 |
+
314: Rice Cooker
|
333 |
+
315: Tuba
|
334 |
+
316: Calculator
|
335 |
+
317: Papaya
|
336 |
+
318: Antelope
|
337 |
+
319: Parrot
|
338 |
+
320: Seal
|
339 |
+
321: Butterfly
|
340 |
+
322: Dumbbell
|
341 |
+
323: Donkey
|
342 |
+
324: Lion
|
343 |
+
325: Urinal
|
344 |
+
326: Dolphin
|
345 |
+
327: Electric Drill
|
346 |
+
328: Hair Dryer
|
347 |
+
329: Egg tart
|
348 |
+
330: Jellyfish
|
349 |
+
331: Treadmill
|
350 |
+
332: Lighter
|
351 |
+
333: Grapefruit
|
352 |
+
334: Game board
|
353 |
+
335: Mop
|
354 |
+
336: Radish
|
355 |
+
337: Baozi
|
356 |
+
338: Target
|
357 |
+
339: French
|
358 |
+
340: Spring Rolls
|
359 |
+
341: Monkey
|
360 |
+
342: Rabbit
|
361 |
+
343: Pencil Case
|
362 |
+
344: Yak
|
363 |
+
345: Red Cabbage
|
364 |
+
346: Binoculars
|
365 |
+
347: Asparagus
|
366 |
+
348: Barbell
|
367 |
+
349: Scallop
|
368 |
+
350: Noddles
|
369 |
+
351: Comb
|
370 |
+
352: Dumpling
|
371 |
+
353: Oyster
|
372 |
+
354: Table Tennis paddle
|
373 |
+
355: Cosmetics Brush/Eyeliner Pencil
|
374 |
+
356: Chainsaw
|
375 |
+
357: Eraser
|
376 |
+
358: Lobster
|
377 |
+
359: Durian
|
378 |
+
360: Okra
|
379 |
+
361: Lipstick
|
380 |
+
362: Cosmetics Mirror
|
381 |
+
363: Curling
|
382 |
+
364: Table Tennis
|
383 |
+
|
384 |
+
|
385 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
386 |
+
download: |
|
387 |
+
from tqdm import tqdm
|
388 |
+
|
389 |
+
from ultralytics.yolo.utils.checks import check_requirements
|
390 |
+
from ultralytics.yolo.utils.downloads import download
|
391 |
+
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
392 |
+
|
393 |
+
import numpy as np
|
394 |
+
from pathlib import Path
|
395 |
+
|
396 |
+
check_requirements(('pycocotools>=2.0',))
|
397 |
+
from pycocotools.coco import COCO
|
398 |
+
|
399 |
+
# Make Directories
|
400 |
+
dir = Path(yaml['path']) # dataset root dir
|
401 |
+
for p in 'images', 'labels':
|
402 |
+
(dir / p).mkdir(parents=True, exist_ok=True)
|
403 |
+
for q in 'train', 'val':
|
404 |
+
(dir / p / q).mkdir(parents=True, exist_ok=True)
|
405 |
+
|
406 |
+
# Train, Val Splits
|
407 |
+
for split, patches in [('train', 50 + 1), ('val', 43 + 1)]:
|
408 |
+
print(f"Processing {split} in {patches} patches ...")
|
409 |
+
images, labels = dir / 'images' / split, dir / 'labels' / split
|
410 |
+
|
411 |
+
# Download
|
412 |
+
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
|
413 |
+
if split == 'train':
|
414 |
+
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
|
415 |
+
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
|
416 |
+
elif split == 'val':
|
417 |
+
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
|
418 |
+
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
|
419 |
+
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
|
420 |
+
|
421 |
+
# Move
|
422 |
+
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
|
423 |
+
f.rename(images / f.name) # move to /images/{split}
|
424 |
+
|
425 |
+
# Labels
|
426 |
+
coco = COCO(dir / f'zhiyuan_objv2_{split}.json')
|
427 |
+
names = [x["name"] for x in coco.loadCats(coco.getCatIds())]
|
428 |
+
for cid, cat in enumerate(names):
|
429 |
+
catIds = coco.getCatIds(catNms=[cat])
|
430 |
+
imgIds = coco.getImgIds(catIds=catIds)
|
431 |
+
for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'):
|
432 |
+
width, height = im["width"], im["height"]
|
433 |
+
path = Path(im["file_name"]) # image filename
|
434 |
+
try:
|
435 |
+
with open(labels / path.with_suffix('.txt').name, 'a') as file:
|
436 |
+
annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None)
|
437 |
+
for a in coco.loadAnns(annIds):
|
438 |
+
x, y, w, h = a['bbox'] # bounding box in xywh (xy top-left corner)
|
439 |
+
xyxy = np.array([x, y, x + w, y + h])[None] # pixels(1,4)
|
440 |
+
x, y, w, h = xyxy2xywhn(xyxy, w=width, h=height, clip=True)[0] # normalized and clipped
|
441 |
+
file.write(f"{cid} {x:.5f} {y:.5f} {w:.5f} {h:.5f}\n")
|
442 |
+
except Exception as e:
|
443 |
+
print(e)
|
ultralytics/datasets/SKU-110K.yaml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail
|
3 |
+
# Example usage: yolo train data=SKU-110K.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── SKU-110K ← downloads here (13.6 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/SKU-110K # dataset root dir
|
12 |
+
train: train.txt # train images (relative to 'path') 8219 images
|
13 |
+
val: val.txt # val images (relative to 'path') 588 images
|
14 |
+
test: test.txt # test images (optional) 2936 images
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: object
|
19 |
+
|
20 |
+
|
21 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
22 |
+
download: |
|
23 |
+
import shutil
|
24 |
+
from pathlib import Path
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import pandas as pd
|
28 |
+
from tqdm import tqdm
|
29 |
+
|
30 |
+
from ultralytics.yolo.utils.downloads import download
|
31 |
+
from ultralytics.yolo.utils.ops import xyxy2xywh
|
32 |
+
|
33 |
+
# Download
|
34 |
+
dir = Path(yaml['path']) # dataset root dir
|
35 |
+
parent = Path(dir.parent) # download dir
|
36 |
+
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
|
37 |
+
download(urls, dir=parent)
|
38 |
+
|
39 |
+
# Rename directories
|
40 |
+
if dir.exists():
|
41 |
+
shutil.rmtree(dir)
|
42 |
+
(parent / 'SKU110K_fixed').rename(dir) # rename dir
|
43 |
+
(dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir
|
44 |
+
|
45 |
+
# Convert labels
|
46 |
+
names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names
|
47 |
+
for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv':
|
48 |
+
x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations
|
49 |
+
images, unique_images = x[:, 0], np.unique(x[:, 0])
|
50 |
+
with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f:
|
51 |
+
f.writelines(f'./images/{s}\n' for s in unique_images)
|
52 |
+
for im in tqdm(unique_images, desc=f'Converting {dir / d}'):
|
53 |
+
cls = 0 # single-class dataset
|
54 |
+
with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f:
|
55 |
+
for r in x[images == im]:
|
56 |
+
w, h = r[6], r[7] # image width, height
|
57 |
+
xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance
|
58 |
+
f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
|
ultralytics/datasets/VOC.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford
|
3 |
+
# Example usage: yolo train data=VOC.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── VOC ← downloads here (2.8 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/VOC
|
12 |
+
train: # train images (relative to 'path') 16551 images
|
13 |
+
- images/train2012
|
14 |
+
- images/train2007
|
15 |
+
- images/val2012
|
16 |
+
- images/val2007
|
17 |
+
val: # val images (relative to 'path') 4952 images
|
18 |
+
- images/test2007
|
19 |
+
test: # test images (optional)
|
20 |
+
- images/test2007
|
21 |
+
|
22 |
+
# Classes
|
23 |
+
names:
|
24 |
+
0: aeroplane
|
25 |
+
1: bicycle
|
26 |
+
2: bird
|
27 |
+
3: boat
|
28 |
+
4: bottle
|
29 |
+
5: bus
|
30 |
+
6: car
|
31 |
+
7: cat
|
32 |
+
8: chair
|
33 |
+
9: cow
|
34 |
+
10: diningtable
|
35 |
+
11: dog
|
36 |
+
12: horse
|
37 |
+
13: motorbike
|
38 |
+
14: person
|
39 |
+
15: pottedplant
|
40 |
+
16: sheep
|
41 |
+
17: sofa
|
42 |
+
18: train
|
43 |
+
19: tvmonitor
|
44 |
+
|
45 |
+
|
46 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
47 |
+
download: |
|
48 |
+
import xml.etree.ElementTree as ET
|
49 |
+
|
50 |
+
from tqdm import tqdm
|
51 |
+
from ultralytics.yolo.utils.downloads import download
|
52 |
+
from pathlib import Path
|
53 |
+
|
54 |
+
def convert_label(path, lb_path, year, image_id):
|
55 |
+
def convert_box(size, box):
|
56 |
+
dw, dh = 1. / size[0], 1. / size[1]
|
57 |
+
x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
|
58 |
+
return x * dw, y * dh, w * dw, h * dh
|
59 |
+
|
60 |
+
in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml')
|
61 |
+
out_file = open(lb_path, 'w')
|
62 |
+
tree = ET.parse(in_file)
|
63 |
+
root = tree.getroot()
|
64 |
+
size = root.find('size')
|
65 |
+
w = int(size.find('width').text)
|
66 |
+
h = int(size.find('height').text)
|
67 |
+
|
68 |
+
names = list(yaml['names'].values()) # names list
|
69 |
+
for obj in root.iter('object'):
|
70 |
+
cls = obj.find('name').text
|
71 |
+
if cls in names and int(obj.find('difficult').text) != 1:
|
72 |
+
xmlbox = obj.find('bndbox')
|
73 |
+
bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
|
74 |
+
cls_id = names.index(cls) # class id
|
75 |
+
out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
|
76 |
+
|
77 |
+
|
78 |
+
# Download
|
79 |
+
dir = Path(yaml['path']) # dataset root dir
|
80 |
+
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
|
81 |
+
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
|
82 |
+
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
|
83 |
+
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
|
84 |
+
download(urls, dir=dir / 'images', curl=True, threads=3)
|
85 |
+
|
86 |
+
# Convert
|
87 |
+
path = dir / 'images/VOCdevkit'
|
88 |
+
for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
|
89 |
+
imgs_path = dir / 'images' / f'{image_set}{year}'
|
90 |
+
lbs_path = dir / 'labels' / f'{image_set}{year}'
|
91 |
+
imgs_path.mkdir(exist_ok=True, parents=True)
|
92 |
+
lbs_path.mkdir(exist_ok=True, parents=True)
|
93 |
+
|
94 |
+
with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
|
95 |
+
image_ids = f.read().strip().split()
|
96 |
+
for id in tqdm(image_ids, desc=f'{image_set}{year}'):
|
97 |
+
f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path
|
98 |
+
lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path
|
99 |
+
f.rename(imgs_path / f.name) # move image
|
100 |
+
convert_label(path, lb_path, year, id) # convert labels to YOLO format
|
ultralytics/datasets/VisDrone.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University
|
3 |
+
# Example usage: yolo train data=VisDrone.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── VisDrone ← downloads here (2.3 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/VisDrone # dataset root dir
|
12 |
+
train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images
|
13 |
+
val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images
|
14 |
+
test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: pedestrian
|
19 |
+
1: people
|
20 |
+
2: bicycle
|
21 |
+
3: car
|
22 |
+
4: van
|
23 |
+
5: truck
|
24 |
+
6: tricycle
|
25 |
+
7: awning-tricycle
|
26 |
+
8: bus
|
27 |
+
9: motor
|
28 |
+
|
29 |
+
|
30 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
31 |
+
download: |
|
32 |
+
import os
|
33 |
+
from pathlib import Path
|
34 |
+
|
35 |
+
from ultralytics.yolo.utils.downloads import download
|
36 |
+
|
37 |
+
def visdrone2yolo(dir):
|
38 |
+
from PIL import Image
|
39 |
+
from tqdm import tqdm
|
40 |
+
|
41 |
+
def convert_box(size, box):
|
42 |
+
# Convert VisDrone box to YOLO xywh box
|
43 |
+
dw = 1. / size[0]
|
44 |
+
dh = 1. / size[1]
|
45 |
+
return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh
|
46 |
+
|
47 |
+
(dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory
|
48 |
+
pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
|
49 |
+
for f in pbar:
|
50 |
+
img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
|
51 |
+
lines = []
|
52 |
+
with open(f, 'r') as file: # read annotation.txt
|
53 |
+
for row in [x.split(',') for x in file.read().strip().splitlines()]:
|
54 |
+
if row[4] == '0': # VisDrone 'ignored regions' class 0
|
55 |
+
continue
|
56 |
+
cls = int(row[5]) - 1
|
57 |
+
box = convert_box(img_size, tuple(map(int, row[:4])))
|
58 |
+
lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
|
59 |
+
with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl:
|
60 |
+
fl.writelines(lines) # write label.txt
|
61 |
+
|
62 |
+
|
63 |
+
# Download
|
64 |
+
dir = Path(yaml['path']) # dataset root dir
|
65 |
+
urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip',
|
66 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip',
|
67 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip',
|
68 |
+
'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip']
|
69 |
+
download(urls, dir=dir, curl=True, threads=4)
|
70 |
+
|
71 |
+
# Convert
|
72 |
+
for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
|
73 |
+
visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels
|
ultralytics/datasets/coco-pose.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO 2017 dataset http://cocodataset.org by Microsoft
|
3 |
+
# Example usage: yolo train data=coco-pose.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco-pose ← downloads here (20.1 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco-pose # dataset root dir
|
12 |
+
train: train2017.txt # train images (relative to 'path') 118287 images
|
13 |
+
val: val2017.txt # val images (relative to 'path') 5000 images
|
14 |
+
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
|
15 |
+
|
16 |
+
# Keypoints
|
17 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
18 |
+
flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
19 |
+
|
20 |
+
# Classes
|
21 |
+
names:
|
22 |
+
0: person
|
23 |
+
|
24 |
+
# Download script/URL (optional)
|
25 |
+
download: |
|
26 |
+
from ultralytics.yolo.utils.downloads import download
|
27 |
+
from pathlib import Path
|
28 |
+
|
29 |
+
# Download labels
|
30 |
+
dir = Path(yaml['path']) # dataset root dir
|
31 |
+
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
|
32 |
+
urls = [url + 'coco2017labels-pose.zip'] # labels
|
33 |
+
download(urls, dir=dir.parent)
|
34 |
+
# Download data
|
35 |
+
urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
|
36 |
+
'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
|
37 |
+
'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
|
38 |
+
download(urls, dir=dir / 'images', threads=3)
|
ultralytics/datasets/coco.yaml
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO 2017 dataset http://cocodataset.org by Microsoft
|
3 |
+
# Example usage: yolo train data=coco.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco ← downloads here (20.1 GB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco # dataset root dir
|
12 |
+
train: train2017.txt # train images (relative to 'path') 118287 images
|
13 |
+
val: val2017.txt # val images (relative to 'path') 5000 images
|
14 |
+
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: |
|
102 |
+
from ultralytics.yolo.utils.downloads import download
|
103 |
+
from pathlib import Path
|
104 |
+
|
105 |
+
# Download labels
|
106 |
+
segments = True # segment or box labels
|
107 |
+
dir = Path(yaml['path']) # dataset root dir
|
108 |
+
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
|
109 |
+
urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
|
110 |
+
download(urls, dir=dir.parent)
|
111 |
+
# Download data
|
112 |
+
urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
|
113 |
+
'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
|
114 |
+
'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
|
115 |
+
download(urls, dir=dir / 'images', threads=3)
|
ultralytics/datasets/coco128-seg.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco128.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco128-seg ← downloads here (7 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco128-seg # dataset root dir
|
12 |
+
train: images/train2017 # train images (relative to 'path') 128 images
|
13 |
+
val: images/train2017 # val images (relative to 'path') 128 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco128-seg.zip
|
ultralytics/datasets/coco128.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco128.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco128 ← downloads here (7 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco128 # dataset root dir
|
12 |
+
train: images/train2017 # train images (relative to 'path') 128 images
|
13 |
+
val: images/train2017 # val images (relative to 'path') 128 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco128.zip
|
ultralytics/datasets/coco8-pose.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco8-pose.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco8-pose ← downloads here (1 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco8-pose # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 4 images
|
13 |
+
val: images/val # val images (relative to 'path') 4 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Keypoints
|
17 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
18 |
+
flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
19 |
+
|
20 |
+
# Classes
|
21 |
+
names:
|
22 |
+
0: person
|
23 |
+
|
24 |
+
# Download script/URL (optional)
|
25 |
+
download: https://ultralytics.com/assets/coco8-pose.zip
|
ultralytics/datasets/coco8-seg.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco8-seg.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco8-seg ← downloads here (1 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco8-seg # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 4 images
|
13 |
+
val: images/val # val images (relative to 'path') 4 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco8-seg.zip
|
ultralytics/datasets/coco8.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
|
3 |
+
# Example usage: yolo train data=coco8.yaml
|
4 |
+
# parent
|
5 |
+
# ├── ultralytics
|
6 |
+
# └── datasets
|
7 |
+
# └── coco8 ← downloads here (1 MB)
|
8 |
+
|
9 |
+
|
10 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
11 |
+
path: ../datasets/coco8 # dataset root dir
|
12 |
+
train: images/train # train images (relative to 'path') 4 images
|
13 |
+
val: images/val # val images (relative to 'path') 4 images
|
14 |
+
test: # test images (optional)
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
+
|
99 |
+
|
100 |
+
# Download script/URL (optional)
|
101 |
+
download: https://ultralytics.com/assets/coco8.zip
|
ultralytics/datasets/xView.yaml
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA)
|
3 |
+
# -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! --------
|
4 |
+
# Example usage: yolo train data=xView.yaml
|
5 |
+
# parent
|
6 |
+
# ├── ultralytics
|
7 |
+
# └── datasets
|
8 |
+
# └── xView ← downloads here (20.7 GB)
|
9 |
+
|
10 |
+
|
11 |
+
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
12 |
+
path: ../datasets/xView # dataset root dir
|
13 |
+
train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images
|
14 |
+
val: images/autosplit_val.txt # train images (relative to 'path') 10% of 847 train images
|
15 |
+
|
16 |
+
# Classes
|
17 |
+
names:
|
18 |
+
0: Fixed-wing Aircraft
|
19 |
+
1: Small Aircraft
|
20 |
+
2: Cargo Plane
|
21 |
+
3: Helicopter
|
22 |
+
4: Passenger Vehicle
|
23 |
+
5: Small Car
|
24 |
+
6: Bus
|
25 |
+
7: Pickup Truck
|
26 |
+
8: Utility Truck
|
27 |
+
9: Truck
|
28 |
+
10: Cargo Truck
|
29 |
+
11: Truck w/Box
|
30 |
+
12: Truck Tractor
|
31 |
+
13: Trailer
|
32 |
+
14: Truck w/Flatbed
|
33 |
+
15: Truck w/Liquid
|
34 |
+
16: Crane Truck
|
35 |
+
17: Railway Vehicle
|
36 |
+
18: Passenger Car
|
37 |
+
19: Cargo Car
|
38 |
+
20: Flat Car
|
39 |
+
21: Tank car
|
40 |
+
22: Locomotive
|
41 |
+
23: Maritime Vessel
|
42 |
+
24: Motorboat
|
43 |
+
25: Sailboat
|
44 |
+
26: Tugboat
|
45 |
+
27: Barge
|
46 |
+
28: Fishing Vessel
|
47 |
+
29: Ferry
|
48 |
+
30: Yacht
|
49 |
+
31: Container Ship
|
50 |
+
32: Oil Tanker
|
51 |
+
33: Engineering Vehicle
|
52 |
+
34: Tower crane
|
53 |
+
35: Container Crane
|
54 |
+
36: Reach Stacker
|
55 |
+
37: Straddle Carrier
|
56 |
+
38: Mobile Crane
|
57 |
+
39: Dump Truck
|
58 |
+
40: Haul Truck
|
59 |
+
41: Scraper/Tractor
|
60 |
+
42: Front loader/Bulldozer
|
61 |
+
43: Excavator
|
62 |
+
44: Cement Mixer
|
63 |
+
45: Ground Grader
|
64 |
+
46: Hut/Tent
|
65 |
+
47: Shed
|
66 |
+
48: Building
|
67 |
+
49: Aircraft Hangar
|
68 |
+
50: Damaged Building
|
69 |
+
51: Facility
|
70 |
+
52: Construction Site
|
71 |
+
53: Vehicle Lot
|
72 |
+
54: Helipad
|
73 |
+
55: Storage Tank
|
74 |
+
56: Shipping container lot
|
75 |
+
57: Shipping Container
|
76 |
+
58: Pylon
|
77 |
+
59: Tower
|
78 |
+
|
79 |
+
|
80 |
+
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
81 |
+
download: |
|
82 |
+
import json
|
83 |
+
import os
|
84 |
+
from pathlib import Path
|
85 |
+
|
86 |
+
import numpy as np
|
87 |
+
from PIL import Image
|
88 |
+
from tqdm import tqdm
|
89 |
+
|
90 |
+
from ultralytics.yolo.data.dataloaders.v5loader import autosplit
|
91 |
+
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
92 |
+
|
93 |
+
|
94 |
+
def convert_labels(fname=Path('xView/xView_train.geojson')):
|
95 |
+
# Convert xView geoJSON labels to YOLO format
|
96 |
+
path = fname.parent
|
97 |
+
with open(fname) as f:
|
98 |
+
print(f'Loading {fname}...')
|
99 |
+
data = json.load(f)
|
100 |
+
|
101 |
+
# Make dirs
|
102 |
+
labels = Path(path / 'labels' / 'train')
|
103 |
+
os.system(f'rm -rf {labels}')
|
104 |
+
labels.mkdir(parents=True, exist_ok=True)
|
105 |
+
|
106 |
+
# xView classes 11-94 to 0-59
|
107 |
+
xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11,
|
108 |
+
12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1,
|
109 |
+
29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46,
|
110 |
+
47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59]
|
111 |
+
|
112 |
+
shapes = {}
|
113 |
+
for feature in tqdm(data['features'], desc=f'Converting {fname}'):
|
114 |
+
p = feature['properties']
|
115 |
+
if p['bounds_imcoords']:
|
116 |
+
id = p['image_id']
|
117 |
+
file = path / 'train_images' / id
|
118 |
+
if file.exists(): # 1395.tif missing
|
119 |
+
try:
|
120 |
+
box = np.array([int(num) for num in p['bounds_imcoords'].split(",")])
|
121 |
+
assert box.shape[0] == 4, f'incorrect box shape {box.shape[0]}'
|
122 |
+
cls = p['type_id']
|
123 |
+
cls = xview_class2index[int(cls)] # xView class to 0-60
|
124 |
+
assert 59 >= cls >= 0, f'incorrect class index {cls}'
|
125 |
+
|
126 |
+
# Write YOLO label
|
127 |
+
if id not in shapes:
|
128 |
+
shapes[id] = Image.open(file).size
|
129 |
+
box = xyxy2xywhn(box[None].astype(np.float), w=shapes[id][0], h=shapes[id][1], clip=True)
|
130 |
+
with open((labels / id).with_suffix('.txt'), 'a') as f:
|
131 |
+
f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n") # write label.txt
|
132 |
+
except Exception as e:
|
133 |
+
print(f'WARNING: skipping one label for {file}: {e}')
|
134 |
+
|
135 |
+
|
136 |
+
# Download manually from https://challenge.xviewdataset.org
|
137 |
+
dir = Path(yaml['path']) # dataset root dir
|
138 |
+
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
|
139 |
+
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
|
140 |
+
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
|
141 |
+
# download(urls, dir=dir)
|
142 |
+
|
143 |
+
# Convert labels
|
144 |
+
convert_labels(dir / 'xView_train.geojson')
|
145 |
+
|
146 |
+
# Move images
|
147 |
+
images = Path(dir / 'images')
|
148 |
+
images.mkdir(parents=True, exist_ok=True)
|
149 |
+
Path(dir / 'train_images').rename(dir / 'images' / 'train')
|
150 |
+
Path(dir / 'val_images').rename(dir / 'images' / 'val')
|
151 |
+
|
152 |
+
# Split
|
153 |
+
autosplit(dir / 'images' / 'train')
|
ultralytics/hub/__init__.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from ultralytics.hub.auth import Auth
|
6 |
+
from ultralytics.hub.utils import PREFIX
|
7 |
+
from ultralytics.yolo.data.utils import HUBDatasetStats
|
8 |
+
from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
|
9 |
+
|
10 |
+
|
11 |
+
def login(api_key=''):
|
12 |
+
"""
|
13 |
+
Log in to the Ultralytics HUB API using the provided API key.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
17 |
+
|
18 |
+
Example:
|
19 |
+
from ultralytics import hub
|
20 |
+
hub.login('API_KEY')
|
21 |
+
"""
|
22 |
+
Auth(api_key, verbose=True)
|
23 |
+
|
24 |
+
|
25 |
+
def logout():
|
26 |
+
"""
|
27 |
+
Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
|
28 |
+
|
29 |
+
Example:
|
30 |
+
from ultralytics import hub
|
31 |
+
hub.logout()
|
32 |
+
"""
|
33 |
+
SETTINGS['api_key'] = ''
|
34 |
+
yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
|
35 |
+
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
|
36 |
+
|
37 |
+
|
38 |
+
def start(key=''):
|
39 |
+
"""
|
40 |
+
Start training models with Ultralytics HUB (DEPRECATED).
|
41 |
+
|
42 |
+
Args:
|
43 |
+
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
|
44 |
+
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
|
45 |
+
"""
|
46 |
+
api_key, model_id = key.split('_')
|
47 |
+
LOGGER.warning(f"""
|
48 |
+
WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
|
49 |
+
|
50 |
+
from ultralytics import YOLO, hub
|
51 |
+
|
52 |
+
hub.login('{api_key}')
|
53 |
+
model = YOLO('https://hub.ultralytics.com/models/{model_id}')
|
54 |
+
model.train()""")
|
55 |
+
|
56 |
+
|
57 |
+
def reset_model(model_id=''):
|
58 |
+
"""Reset a trained model to an untrained state."""
|
59 |
+
r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
|
60 |
+
if r.status_code == 200:
|
61 |
+
LOGGER.info(f'{PREFIX}Model reset successfully')
|
62 |
+
return
|
63 |
+
LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
|
64 |
+
|
65 |
+
|
66 |
+
def export_fmts_hub():
|
67 |
+
"""Returns a list of HUB-supported export formats."""
|
68 |
+
from ultralytics.yolo.engine.exporter import export_formats
|
69 |
+
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
|
70 |
+
|
71 |
+
|
72 |
+
def export_model(model_id='', format='torchscript'):
|
73 |
+
"""Export a model to all formats."""
|
74 |
+
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
75 |
+
r = requests.post(f'https://api.ultralytics.com/v1/models/{model_id}/export',
|
76 |
+
json={'format': format},
|
77 |
+
headers={'x-api-key': Auth().api_key})
|
78 |
+
assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
|
79 |
+
LOGGER.info(f'{PREFIX}{format} export started ✅')
|
80 |
+
|
81 |
+
|
82 |
+
def get_export(model_id='', format='torchscript'):
|
83 |
+
"""Get an exported model dictionary with download URL."""
|
84 |
+
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
85 |
+
r = requests.post('https://api.ultralytics.com/get-export',
|
86 |
+
json={
|
87 |
+
'apiKey': Auth().api_key,
|
88 |
+
'modelId': model_id,
|
89 |
+
'format': format})
|
90 |
+
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
|
91 |
+
return r.json()
|
92 |
+
|
93 |
+
|
94 |
+
def check_dataset(path='', task='detect'):
|
95 |
+
"""
|
96 |
+
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
|
97 |
+
uploaded to the HUB. Usage examples are given below.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
|
101 |
+
task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.
|
102 |
+
|
103 |
+
Example:
|
104 |
+
```python
|
105 |
+
from ultralytics.hub import check_dataset
|
106 |
+
|
107 |
+
check_dataset('path/to/coco8.zip', task='detect') # detect dataset
|
108 |
+
check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
|
109 |
+
check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
|
110 |
+
```
|
111 |
+
"""
|
112 |
+
HUBDatasetStats(path=path, task=task).get_json()
|
113 |
+
LOGGER.info('Checks completed correctly ✅. Upload this dataset to https://hub.ultralytics.com/datasets/.')
|
114 |
+
|
115 |
+
|
116 |
+
if __name__ == '__main__':
|
117 |
+
start()
|
ultralytics/hub/auth.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
|
6 |
+
from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
|
7 |
+
|
8 |
+
API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
|
9 |
+
|
10 |
+
|
11 |
+
class Auth:
|
12 |
+
id_token = api_key = model_key = False
|
13 |
+
|
14 |
+
def __init__(self, api_key='', verbose=False):
|
15 |
+
"""
|
16 |
+
Initialize the Auth class with an optional API key.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
20 |
+
"""
|
21 |
+
# Split the input API key in case it contains a combined key_model and keep only the API key part
|
22 |
+
api_key = api_key.split('_')[0]
|
23 |
+
|
24 |
+
# Set API key attribute as value passed or SETTINGS API key if none passed
|
25 |
+
self.api_key = api_key or SETTINGS.get('api_key', '')
|
26 |
+
|
27 |
+
# If an API key is provided
|
28 |
+
if self.api_key:
|
29 |
+
# If the provided API key matches the API key in the SETTINGS
|
30 |
+
if self.api_key == SETTINGS.get('api_key'):
|
31 |
+
# Log that the user is already logged in
|
32 |
+
if verbose:
|
33 |
+
LOGGER.info(f'{PREFIX}Authenticated ✅')
|
34 |
+
return
|
35 |
+
else:
|
36 |
+
# Attempt to authenticate with the provided API key
|
37 |
+
success = self.authenticate()
|
38 |
+
# If the API key is not provided and the environment is a Google Colab notebook
|
39 |
+
elif is_colab():
|
40 |
+
# Attempt to authenticate using browser cookies
|
41 |
+
success = self.auth_with_cookies()
|
42 |
+
else:
|
43 |
+
# Request an API key
|
44 |
+
success = self.request_api_key()
|
45 |
+
|
46 |
+
# Update SETTINGS with the new API key after successful authentication
|
47 |
+
if success:
|
48 |
+
set_settings({'api_key': self.api_key})
|
49 |
+
# Log that the new login was successful
|
50 |
+
if verbose:
|
51 |
+
LOGGER.info(f'{PREFIX}New authentication successful ✅')
|
52 |
+
elif verbose:
|
53 |
+
LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
|
54 |
+
|
55 |
+
def request_api_key(self, max_attempts=3):
|
56 |
+
"""
|
57 |
+
Prompt the user to input their API key. Returns the model ID.
|
58 |
+
"""
|
59 |
+
import getpass
|
60 |
+
for attempts in range(max_attempts):
|
61 |
+
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
|
62 |
+
input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
|
63 |
+
self.api_key = input_key.split('_')[0] # remove model id if present
|
64 |
+
if self.authenticate():
|
65 |
+
return True
|
66 |
+
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
|
67 |
+
|
68 |
+
def authenticate(self) -> bool:
|
69 |
+
"""
|
70 |
+
Attempt to authenticate with the server using either id_token or API key.
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
bool: True if authentication is successful, False otherwise.
|
74 |
+
"""
|
75 |
+
try:
|
76 |
+
header = self.get_auth_header()
|
77 |
+
if header:
|
78 |
+
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
|
79 |
+
if not r.json().get('success', False):
|
80 |
+
raise ConnectionError('Unable to authenticate.')
|
81 |
+
return True
|
82 |
+
raise ConnectionError('User has not authenticated locally.')
|
83 |
+
except ConnectionError:
|
84 |
+
self.id_token = self.api_key = False # reset invalid
|
85 |
+
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
|
86 |
+
return False
|
87 |
+
|
88 |
+
def auth_with_cookies(self) -> bool:
|
89 |
+
"""
|
90 |
+
Attempt to fetch authentication via cookies and set id_token.
|
91 |
+
User must be logged in to HUB and running in a supported browser.
|
92 |
+
|
93 |
+
Returns:
|
94 |
+
bool: True if authentication is successful, False otherwise.
|
95 |
+
"""
|
96 |
+
if not is_colab():
|
97 |
+
return False # Currently only works with Colab
|
98 |
+
try:
|
99 |
+
authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
|
100 |
+
if authn.get('success', False):
|
101 |
+
self.id_token = authn.get('data', {}).get('idToken', None)
|
102 |
+
self.authenticate()
|
103 |
+
return True
|
104 |
+
raise ConnectionError('Unable to fetch browser authentication details.')
|
105 |
+
except ConnectionError:
|
106 |
+
self.id_token = False # reset invalid
|
107 |
+
return False
|
108 |
+
|
109 |
+
def get_auth_header(self):
|
110 |
+
"""
|
111 |
+
Get the authentication header for making API requests.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
(dict): The authentication header if id_token or API key is set, None otherwise.
|
115 |
+
"""
|
116 |
+
if self.id_token:
|
117 |
+
return {'authorization': f'Bearer {self.id_token}'}
|
118 |
+
elif self.api_key:
|
119 |
+
return {'x-api-key': self.api_key}
|
120 |
+
else:
|
121 |
+
return None
|
122 |
+
|
123 |
+
def get_state(self) -> bool:
|
124 |
+
"""
|
125 |
+
Get the authentication state.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
bool: True if either id_token or API key is set, False otherwise.
|
129 |
+
"""
|
130 |
+
return self.id_token or self.api_key
|
131 |
+
|
132 |
+
def set_api_key(self, key: str):
|
133 |
+
"""
|
134 |
+
Set the API key for authentication.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
key (str): The API key string.
|
138 |
+
"""
|
139 |
+
self.api_key = key
|
ultralytics/hub/session.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
import signal
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
from time import sleep
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
|
10 |
+
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
|
11 |
+
from ultralytics.yolo.utils.errors import HUBModelError
|
12 |
+
|
13 |
+
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
14 |
+
|
15 |
+
|
16 |
+
class HUBTrainingSession:
|
17 |
+
"""
|
18 |
+
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
url (str): Model identifier used to initialize the HUB training session.
|
22 |
+
|
23 |
+
Attributes:
|
24 |
+
agent_id (str): Identifier for the instance communicating with the server.
|
25 |
+
model_id (str): Identifier for the YOLOv5 model being trained.
|
26 |
+
model_url (str): URL for the model in Ultralytics HUB.
|
27 |
+
api_url (str): API URL for the model in Ultralytics HUB.
|
28 |
+
auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
|
29 |
+
rate_limits (Dict): Rate limits for different API calls (in seconds).
|
30 |
+
timers (Dict): Timers for rate limiting.
|
31 |
+
metrics_queue (Dict): Queue for the model's metrics.
|
32 |
+
model (Dict): Model data fetched from Ultralytics HUB.
|
33 |
+
alive (bool): Indicates if the heartbeat loop is active.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, url):
|
37 |
+
"""
|
38 |
+
Initialize the HUBTrainingSession with the provided model identifier.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
url (str): Model identifier used to initialize the HUB training session.
|
42 |
+
It can be a URL string or a model key with specific format.
|
43 |
+
|
44 |
+
Raises:
|
45 |
+
ValueError: If the provided model identifier is invalid.
|
46 |
+
ConnectionError: If connecting with global API key is not supported.
|
47 |
+
"""
|
48 |
+
|
49 |
+
from ultralytics.hub.auth import Auth
|
50 |
+
|
51 |
+
# Parse input
|
52 |
+
if url.startswith('https://hub.ultralytics.com/models/'):
|
53 |
+
url = url.split('https://hub.ultralytics.com/models/')[-1]
|
54 |
+
if [len(x) for x in url.split('_')] == [42, 20]:
|
55 |
+
key, model_id = url.split('_')
|
56 |
+
elif len(url) == 20:
|
57 |
+
key, model_id = '', url
|
58 |
+
else:
|
59 |
+
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
|
60 |
+
f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")
|
61 |
+
|
62 |
+
# Authorize
|
63 |
+
auth = Auth(key)
|
64 |
+
self.agent_id = None # identifies which instance is communicating with server
|
65 |
+
self.model_id = model_id
|
66 |
+
self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
|
67 |
+
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
68 |
+
self.auth_header = auth.get_auth_header()
|
69 |
+
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
70 |
+
self.timers = {} # rate limit timers (seconds)
|
71 |
+
self.metrics_queue = {} # metrics queue
|
72 |
+
self.model = self._get_model()
|
73 |
+
self.alive = True
|
74 |
+
self._start_heartbeat() # start heartbeats
|
75 |
+
self._register_signal_handlers()
|
76 |
+
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
77 |
+
|
78 |
+
def _register_signal_handlers(self):
|
79 |
+
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
|
80 |
+
signal.signal(signal.SIGTERM, self._handle_signal)
|
81 |
+
signal.signal(signal.SIGINT, self._handle_signal)
|
82 |
+
|
83 |
+
def _handle_signal(self, signum, frame):
|
84 |
+
"""
|
85 |
+
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
|
86 |
+
This method does not use frame, it is included as it is passed by signal.
|
87 |
+
"""
|
88 |
+
if self.alive is True:
|
89 |
+
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
|
90 |
+
self._stop_heartbeat()
|
91 |
+
sys.exit(signum)
|
92 |
+
|
93 |
+
def _stop_heartbeat(self):
|
94 |
+
"""Terminate the heartbeat loop."""
|
95 |
+
self.alive = False
|
96 |
+
|
97 |
+
def upload_metrics(self):
|
98 |
+
"""Upload model metrics to Ultralytics HUB."""
|
99 |
+
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
|
100 |
+
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
|
101 |
+
|
102 |
+
def _get_model(self):
|
103 |
+
"""Fetch and return model data from Ultralytics HUB."""
|
104 |
+
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
105 |
+
|
106 |
+
try:
|
107 |
+
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
|
108 |
+
data = response.json().get('data', None)
|
109 |
+
|
110 |
+
if data.get('status', None) == 'trained':
|
111 |
+
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
|
112 |
+
|
113 |
+
if not data.get('data', None):
|
114 |
+
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
115 |
+
self.model_id = data['id']
|
116 |
+
|
117 |
+
if data['status'] == 'new': # new model to start training
|
118 |
+
self.train_args = {
|
119 |
+
# TODO: deprecate 'batch_size' key for 'batch' in 3Q23
|
120 |
+
'batch': data['batch' if ('batch' in data) else 'batch_size'],
|
121 |
+
'epochs': data['epochs'],
|
122 |
+
'imgsz': data['imgsz'],
|
123 |
+
'patience': data['patience'],
|
124 |
+
'device': data['device'],
|
125 |
+
'cache': data['cache'],
|
126 |
+
'data': data['data']}
|
127 |
+
self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
|
128 |
+
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
|
129 |
+
elif data['status'] == 'training': # existing model to resume training
|
130 |
+
self.train_args = {'data': data['data'], 'resume': True}
|
131 |
+
self.model_file = data['resume']
|
132 |
+
|
133 |
+
return data
|
134 |
+
except requests.exceptions.ConnectionError as e:
|
135 |
+
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
|
136 |
+
except Exception:
|
137 |
+
raise
|
138 |
+
|
139 |
+
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
140 |
+
"""
|
141 |
+
Upload a model checkpoint to Ultralytics HUB.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
epoch (int): The current training epoch.
|
145 |
+
weights (str): Path to the model weights file.
|
146 |
+
is_best (bool): Indicates if the current model is the best one so far.
|
147 |
+
map (float): Mean average precision of the model.
|
148 |
+
final (bool): Indicates if the model is the final model after training.
|
149 |
+
"""
|
150 |
+
if Path(weights).is_file():
|
151 |
+
with open(weights, 'rb') as f:
|
152 |
+
file = f.read()
|
153 |
+
else:
|
154 |
+
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
|
155 |
+
file = None
|
156 |
+
url = f'{self.api_url}/upload'
|
157 |
+
# url = 'http://httpbin.org/post' # for debug
|
158 |
+
data = {'epoch': epoch}
|
159 |
+
if final:
|
160 |
+
data.update({'type': 'final', 'map': map})
|
161 |
+
smart_request('post',
|
162 |
+
url,
|
163 |
+
data=data,
|
164 |
+
files={'best.pt': file},
|
165 |
+
headers=self.auth_header,
|
166 |
+
retry=10,
|
167 |
+
timeout=3600,
|
168 |
+
thread=False,
|
169 |
+
progress=True,
|
170 |
+
code=4)
|
171 |
+
else:
|
172 |
+
data.update({'type': 'epoch', 'isBest': bool(is_best)})
|
173 |
+
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
|
174 |
+
|
175 |
+
@threaded
|
176 |
+
def _start_heartbeat(self):
|
177 |
+
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
|
178 |
+
while self.alive:
|
179 |
+
r = smart_request('post',
|
180 |
+
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
181 |
+
json={
|
182 |
+
'agent': AGENT_NAME,
|
183 |
+
'agentId': self.agent_id},
|
184 |
+
headers=self.auth_header,
|
185 |
+
retry=0,
|
186 |
+
code=5,
|
187 |
+
thread=False) # already in a thread
|
188 |
+
self.agent_id = r.json().get('data', {}).get('agentId', None)
|
189 |
+
sleep(self.rate_limits['heartbeat'])
|
ultralytics/hub/utils.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
import random
|
6 |
+
import sys
|
7 |
+
import threading
|
8 |
+
import time
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from ultralytics.yolo.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT,
|
15 |
+
TryExcept, __version__, colorstr, get_git_origin_url, is_colab, is_git_dir,
|
16 |
+
is_pip_package)
|
17 |
+
|
18 |
+
PREFIX = colorstr('Ultralytics HUB: ')
|
19 |
+
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
20 |
+
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
|
21 |
+
|
22 |
+
|
23 |
+
def request_with_credentials(url: str) -> any:
|
24 |
+
"""
|
25 |
+
Make an AJAX request with cookies attached in a Google Colab environment.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
url (str): The URL to make the request to.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
(any): The response data from the AJAX request.
|
32 |
+
|
33 |
+
Raises:
|
34 |
+
OSError: If the function is not run in a Google Colab environment.
|
35 |
+
"""
|
36 |
+
if not is_colab():
|
37 |
+
raise OSError('request_with_credentials() must run in a Colab environment')
|
38 |
+
from google.colab import output # noqa
|
39 |
+
from IPython import display # noqa
|
40 |
+
display.display(
|
41 |
+
display.Javascript("""
|
42 |
+
window._hub_tmp = new Promise((resolve, reject) => {
|
43 |
+
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
|
44 |
+
fetch("%s", {
|
45 |
+
method: 'POST',
|
46 |
+
credentials: 'include'
|
47 |
+
})
|
48 |
+
.then((response) => resolve(response.json()))
|
49 |
+
.then((json) => {
|
50 |
+
clearTimeout(timeout);
|
51 |
+
}).catch((err) => {
|
52 |
+
clearTimeout(timeout);
|
53 |
+
reject(err);
|
54 |
+
});
|
55 |
+
});
|
56 |
+
""" % url))
|
57 |
+
return output.eval_js('_hub_tmp')
|
58 |
+
|
59 |
+
|
60 |
+
def requests_with_progress(method, url, **kwargs):
|
61 |
+
"""
|
62 |
+
Make an HTTP request using the specified method and URL, with an optional progress bar.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
method (str): The HTTP method to use (e.g. 'GET', 'POST').
|
66 |
+
url (str): The URL to send the request to.
|
67 |
+
**kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
(requests.Response): The response object from the HTTP request.
|
71 |
+
|
72 |
+
Note:
|
73 |
+
If 'progress' is set to True, the progress bar will display the download progress
|
74 |
+
for responses with a known content length.
|
75 |
+
"""
|
76 |
+
progress = kwargs.pop('progress', False)
|
77 |
+
if not progress:
|
78 |
+
return requests.request(method, url, **kwargs)
|
79 |
+
response = requests.request(method, url, stream=True, **kwargs)
|
80 |
+
total = int(response.headers.get('content-length', 0)) # total size
|
81 |
+
pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT)
|
82 |
+
for data in response.iter_content(chunk_size=1024):
|
83 |
+
pbar.update(len(data))
|
84 |
+
pbar.close()
|
85 |
+
return response
|
86 |
+
|
87 |
+
|
88 |
+
def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
|
89 |
+
"""
|
90 |
+
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
|
94 |
+
url (str): The URL to make the request to.
|
95 |
+
retry (int, optional): Number of retries to attempt before giving up. Default is 3.
|
96 |
+
timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
|
97 |
+
thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
|
98 |
+
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
|
99 |
+
verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
|
100 |
+
progress (bool, optional): Whether to show a progress bar during the request. Default is False.
|
101 |
+
**kwargs (dict): Keyword arguments to be passed to the requests function specified in method.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
(requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
|
105 |
+
"""
|
106 |
+
retry_codes = (408, 500) # retry only these codes
|
107 |
+
|
108 |
+
@TryExcept(verbose=verbose)
|
109 |
+
def func(func_method, func_url, **func_kwargs):
|
110 |
+
"""Make HTTP requests with retries and timeouts, with optional progress tracking."""
|
111 |
+
r = None # response
|
112 |
+
t0 = time.time() # initial time for timer
|
113 |
+
for i in range(retry + 1):
|
114 |
+
if (time.time() - t0) > timeout:
|
115 |
+
break
|
116 |
+
r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files)
|
117 |
+
if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
|
118 |
+
break
|
119 |
+
try:
|
120 |
+
m = r.json().get('message', 'No JSON message.')
|
121 |
+
except AttributeError:
|
122 |
+
m = 'Unable to read JSON.'
|
123 |
+
if i == 0:
|
124 |
+
if r.status_code in retry_codes:
|
125 |
+
m += f' Retrying {retry}x for {timeout}s.' if retry else ''
|
126 |
+
elif r.status_code == 429: # rate limit
|
127 |
+
h = r.headers # response headers
|
128 |
+
m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
|
129 |
+
f"Please retry after {h['Retry-After']}s."
|
130 |
+
if verbose:
|
131 |
+
LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
|
132 |
+
if r.status_code not in retry_codes:
|
133 |
+
return r
|
134 |
+
time.sleep(2 ** i) # exponential standoff
|
135 |
+
return r
|
136 |
+
|
137 |
+
args = method, url
|
138 |
+
kwargs['progress'] = progress
|
139 |
+
if thread:
|
140 |
+
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
141 |
+
else:
|
142 |
+
return func(*args, **kwargs)
|
143 |
+
|
144 |
+
|
145 |
+
class Events:
|
146 |
+
"""
|
147 |
+
A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
|
148 |
+
disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.
|
149 |
+
|
150 |
+
Attributes:
|
151 |
+
url (str): The URL to send anonymous events.
|
152 |
+
rate_limit (float): The rate limit in seconds for sending events.
|
153 |
+
metadata (dict): A dictionary containing metadata about the environment.
|
154 |
+
enabled (bool): A flag to enable or disable Events based on certain conditions.
|
155 |
+
"""
|
156 |
+
|
157 |
+
url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
|
158 |
+
|
159 |
+
def __init__(self):
|
160 |
+
"""
|
161 |
+
Initializes the Events object with default values for events, rate_limit, and metadata.
|
162 |
+
"""
|
163 |
+
self.events = [] # events list
|
164 |
+
self.rate_limit = 60.0 # rate limit (seconds)
|
165 |
+
self.t = 0.0 # rate limit timer (seconds)
|
166 |
+
self.metadata = {
|
167 |
+
'cli': Path(sys.argv[0]).name == 'yolo',
|
168 |
+
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
169 |
+
'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10
|
170 |
+
'version': __version__,
|
171 |
+
'env': ENVIRONMENT,
|
172 |
+
'session_id': round(random.random() * 1E15),
|
173 |
+
'engagement_time_msec': 1000}
|
174 |
+
self.enabled = \
|
175 |
+
SETTINGS['sync'] and \
|
176 |
+
RANK in (-1, 0) and \
|
177 |
+
not TESTS_RUNNING and \
|
178 |
+
ONLINE and \
|
179 |
+
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
|
180 |
+
|
181 |
+
def __call__(self, cfg):
|
182 |
+
"""
|
183 |
+
Attempts to add a new event to the events list and send events if the rate limit is reached.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
|
187 |
+
"""
|
188 |
+
if not self.enabled:
|
189 |
+
# Events disabled, do nothing
|
190 |
+
return
|
191 |
+
|
192 |
+
# Attempt to add to events
|
193 |
+
if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
|
194 |
+
params = {**self.metadata, **{'task': cfg.task}}
|
195 |
+
if cfg.mode == 'export':
|
196 |
+
params['format'] = cfg.format
|
197 |
+
self.events.append({'name': cfg.mode, 'params': params})
|
198 |
+
|
199 |
+
# Check rate limit
|
200 |
+
t = time.time()
|
201 |
+
if (t - self.t) < self.rate_limit:
|
202 |
+
# Time is under rate limiter, wait to send
|
203 |
+
return
|
204 |
+
|
205 |
+
# Time is over rate limiter, send now
|
206 |
+
data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
|
207 |
+
|
208 |
+
# POST equivalent to requests.post(self.url, json=data)
|
209 |
+
smart_request('post', self.url, json=data, retry=0, verbose=False)
|
210 |
+
|
211 |
+
# Reset events and rate limit timer
|
212 |
+
self.events = []
|
213 |
+
self.t = t
|
214 |
+
|
215 |
+
|
216 |
+
# Run below code on hub/utils init -------------------------------------------------------------------------------------
|
217 |
+
events = Events()
|
ultralytics/models/README.md
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Models
|
2 |
+
|
3 |
+
Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
|
4 |
+
files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
|
5 |
+
and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
|
6 |
+
segmentation tasks.
|
7 |
+
|
8 |
+
These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
|
9 |
+
instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
|
10 |
+
from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
|
11 |
+
directory provides a great starting point for your custom model development needs.
|
12 |
+
|
13 |
+
To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
|
14 |
+
selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
|
15 |
+
details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
|
16 |
+
to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
|
17 |
+
|
18 |
+
### Usage
|
19 |
+
|
20 |
+
Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
|
21 |
+
|
22 |
+
```bash
|
23 |
+
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
|
24 |
+
```
|
25 |
+
|
26 |
+
They may also be used directly in a Python environment, and accepts the same
|
27 |
+
[arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
|
28 |
+
|
29 |
+
```python
|
30 |
+
from ultralytics import YOLO
|
31 |
+
|
32 |
+
model = YOLO("model.yaml") # build a YOLOv8n model from scratch
|
33 |
+
# YOLO("model.pt") use pre-trained model if available
|
34 |
+
model.info() # display model information
|
35 |
+
model.train(data="coco128.yaml", epochs=100) # train the model
|
36 |
+
```
|
37 |
+
|
38 |
+
## Pre-trained Model Architectures
|
39 |
+
|
40 |
+
Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information
|
41 |
+
and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
|
42 |
+
|
43 |
+
## Contributing New Models
|
44 |
+
|
45 |
+
If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
|
ultralytics/models/rt-detr/rtdetr-l.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
l: [1.00, 1.00, 1024]
|
9 |
+
|
10 |
+
backbone:
|
11 |
+
# [from, repeats, module, args]
|
12 |
+
- [-1, 1, HGStem, [32, 48]] # 0-P2/4
|
13 |
+
- [-1, 6, HGBlock, [48, 128, 3]] # stage 1
|
14 |
+
|
15 |
+
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
|
16 |
+
- [-1, 6, HGBlock, [96, 512, 3]] # stage 2
|
17 |
+
|
18 |
+
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
|
19 |
+
- [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
|
20 |
+
- [-1, 6, HGBlock, [192, 1024, 5, True, True]]
|
21 |
+
- [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
|
22 |
+
|
23 |
+
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
|
24 |
+
- [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
|
25 |
+
|
26 |
+
head:
|
27 |
+
- [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
|
28 |
+
- [-1, 1, AIFI, [1024, 8]]
|
29 |
+
- [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
|
30 |
+
|
31 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
32 |
+
- [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
|
33 |
+
- [[-2, -1], 1, Concat, [1]]
|
34 |
+
- [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
|
35 |
+
- [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
|
36 |
+
|
37 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
38 |
+
- [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
|
39 |
+
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
|
40 |
+
- [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
|
43 |
+
- [[-1, 17], 1, Concat, [1]] # cat Y4
|
44 |
+
- [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
|
45 |
+
|
46 |
+
- [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
|
47 |
+
- [[-1, 12], 1, Concat, [1]] # cat Y5
|
48 |
+
- [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
|
49 |
+
|
50 |
+
- [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/rt-detr/rtdetr-x.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
x: [1.00, 1.00, 2048]
|
9 |
+
|
10 |
+
backbone:
|
11 |
+
# [from, repeats, module, args]
|
12 |
+
- [-1, 1, HGStem, [32, 64]] # 0-P2/4
|
13 |
+
- [-1, 6, HGBlock, [64, 128, 3]] # stage 1
|
14 |
+
|
15 |
+
- [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
|
16 |
+
- [-1, 6, HGBlock, [128, 512, 3]]
|
17 |
+
- [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2
|
18 |
+
|
19 |
+
- [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16
|
20 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut
|
21 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
|
22 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
|
23 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]]
|
24 |
+
- [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3
|
25 |
+
|
26 |
+
- [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32
|
27 |
+
- [-1, 6, HGBlock, [512, 2048, 5, True, False]]
|
28 |
+
- [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4
|
29 |
+
|
30 |
+
head:
|
31 |
+
- [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2
|
32 |
+
- [-1, 1, AIFI, [2048, 8]]
|
33 |
+
- [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1
|
37 |
+
- [[-2, -1], 1, Concat, [1]]
|
38 |
+
- [-1, 3, RepC3, [384]] # 20, fpn_blocks.0
|
39 |
+
- [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1
|
40 |
+
|
41 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
42 |
+
- [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0
|
43 |
+
- [[-2, -1], 1, Concat, [1]] # cat backbone P4
|
44 |
+
- [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1
|
45 |
+
|
46 |
+
- [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0
|
47 |
+
- [[-1, 21], 1, Concat, [1]] # cat Y4
|
48 |
+
- [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0
|
49 |
+
|
50 |
+
- [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1
|
51 |
+
- [[-1, 16], 1, Concat, [1]] # cat Y5
|
52 |
+
- [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1
|
53 |
+
|
54 |
+
- [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/v3/yolov3-spp.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
depth_multiple: 1.0 # model depth multiple
|
7 |
+
width_multiple: 1.0 # layer channel multiple
|
8 |
+
|
9 |
+
# darknet53 backbone
|
10 |
+
backbone:
|
11 |
+
# [from, number, module, args]
|
12 |
+
[[-1, 1, Conv, [32, 3, 1]], # 0
|
13 |
+
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
|
14 |
+
[-1, 1, Bottleneck, [64]],
|
15 |
+
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
|
16 |
+
[-1, 2, Bottleneck, [128]],
|
17 |
+
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
|
18 |
+
[-1, 8, Bottleneck, [256]],
|
19 |
+
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
|
20 |
+
[-1, 8, Bottleneck, [512]],
|
21 |
+
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
|
22 |
+
[-1, 4, Bottleneck, [1024]], # 10
|
23 |
+
]
|
24 |
+
|
25 |
+
# YOLOv3-SPP head
|
26 |
+
head:
|
27 |
+
[[-1, 1, Bottleneck, [1024, False]],
|
28 |
+
[-1, 1, SPP, [512, [5, 9, 13]]],
|
29 |
+
[-1, 1, Conv, [1024, 3, 1]],
|
30 |
+
[-1, 1, Conv, [512, 1, 1]],
|
31 |
+
[-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
|
32 |
+
|
33 |
+
[-2, 1, Conv, [256, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
36 |
+
[-1, 1, Bottleneck, [512, False]],
|
37 |
+
[-1, 1, Bottleneck, [512, False]],
|
38 |
+
[-1, 1, Conv, [256, 1, 1]],
|
39 |
+
[-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
|
40 |
+
|
41 |
+
[-2, 1, Conv, [128, 1, 1]],
|
42 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
43 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P3
|
44 |
+
[-1, 1, Bottleneck, [256, False]],
|
45 |
+
[-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
|
46 |
+
|
47 |
+
[[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
|
48 |
+
]
|
ultralytics/models/v3/yolov3-tiny.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
depth_multiple: 1.0 # model depth multiple
|
7 |
+
width_multiple: 1.0 # layer channel multiple
|
8 |
+
|
9 |
+
# YOLOv3-tiny backbone
|
10 |
+
backbone:
|
11 |
+
# [from, number, module, args]
|
12 |
+
[[-1, 1, Conv, [16, 3, 1]], # 0
|
13 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
|
14 |
+
[-1, 1, Conv, [32, 3, 1]],
|
15 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
|
16 |
+
[-1, 1, Conv, [64, 3, 1]],
|
17 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
|
18 |
+
[-1, 1, Conv, [128, 3, 1]],
|
19 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
|
20 |
+
[-1, 1, Conv, [256, 3, 1]],
|
21 |
+
[-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
|
22 |
+
[-1, 1, Conv, [512, 3, 1]],
|
23 |
+
[-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
|
24 |
+
[-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
|
25 |
+
]
|
26 |
+
|
27 |
+
# YOLOv3-tiny head
|
28 |
+
head:
|
29 |
+
[[-1, 1, Conv, [1024, 3, 1]],
|
30 |
+
[-1, 1, Conv, [256, 1, 1]],
|
31 |
+
[-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
|
32 |
+
|
33 |
+
[-2, 1, Conv, [128, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
36 |
+
[-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
|
37 |
+
|
38 |
+
[[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
|
39 |
+
]
|
ultralytics/models/v3/yolov3.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
depth_multiple: 1.0 # model depth multiple
|
7 |
+
width_multiple: 1.0 # layer channel multiple
|
8 |
+
|
9 |
+
# darknet53 backbone
|
10 |
+
backbone:
|
11 |
+
# [from, number, module, args]
|
12 |
+
[[-1, 1, Conv, [32, 3, 1]], # 0
|
13 |
+
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
|
14 |
+
[-1, 1, Bottleneck, [64]],
|
15 |
+
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
|
16 |
+
[-1, 2, Bottleneck, [128]],
|
17 |
+
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
|
18 |
+
[-1, 8, Bottleneck, [256]],
|
19 |
+
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
|
20 |
+
[-1, 8, Bottleneck, [512]],
|
21 |
+
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
|
22 |
+
[-1, 4, Bottleneck, [1024]], # 10
|
23 |
+
]
|
24 |
+
|
25 |
+
# YOLOv3 head
|
26 |
+
head:
|
27 |
+
[[-1, 1, Bottleneck, [1024, False]],
|
28 |
+
[-1, 1, Conv, [512, 1, 1]],
|
29 |
+
[-1, 1, Conv, [1024, 3, 1]],
|
30 |
+
[-1, 1, Conv, [512, 1, 1]],
|
31 |
+
[-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
|
32 |
+
|
33 |
+
[-2, 1, Conv, [256, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P4
|
36 |
+
[-1, 1, Bottleneck, [512, False]],
|
37 |
+
[-1, 1, Bottleneck, [512, False]],
|
38 |
+
[-1, 1, Conv, [256, 1, 1]],
|
39 |
+
[-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
|
40 |
+
|
41 |
+
[-2, 1, Conv, [128, 1, 1]],
|
42 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
43 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P3
|
44 |
+
[-1, 1, Bottleneck, [256, False]],
|
45 |
+
[-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
|
46 |
+
|
47 |
+
[[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
|
48 |
+
]
|
ultralytics/models/v5/yolov5-p6.yaml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 1024]
|
11 |
+
l: [1.00, 1.00, 1024]
|
12 |
+
x: [1.33, 1.25, 1024]
|
13 |
+
|
14 |
+
# YOLOv5 v6.0 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, number, module, args]
|
17 |
+
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
18 |
+
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
19 |
+
[-1, 3, C3, [128]],
|
20 |
+
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
21 |
+
[-1, 6, C3, [256]],
|
22 |
+
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
23 |
+
[-1, 9, C3, [512]],
|
24 |
+
[-1, 1, Conv, [768, 3, 2]], # 7-P5/32
|
25 |
+
[-1, 3, C3, [768]],
|
26 |
+
[-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
|
27 |
+
[-1, 3, C3, [1024]],
|
28 |
+
[-1, 1, SPPF, [1024, 5]], # 11
|
29 |
+
]
|
30 |
+
|
31 |
+
# YOLOv5 v6.0 head
|
32 |
+
head:
|
33 |
+
[[-1, 1, Conv, [768, 1, 1]],
|
34 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
35 |
+
[[-1, 8], 1, Concat, [1]], # cat backbone P5
|
36 |
+
[-1, 3, C3, [768, False]], # 15
|
37 |
+
|
38 |
+
[-1, 1, Conv, [512, 1, 1]],
|
39 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
40 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
41 |
+
[-1, 3, C3, [512, False]], # 19
|
42 |
+
|
43 |
+
[-1, 1, Conv, [256, 1, 1]],
|
44 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
45 |
+
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
46 |
+
[-1, 3, C3, [256, False]], # 23 (P3/8-small)
|
47 |
+
|
48 |
+
[-1, 1, Conv, [256, 3, 2]],
|
49 |
+
[[-1, 20], 1, Concat, [1]], # cat head P4
|
50 |
+
[-1, 3, C3, [512, False]], # 26 (P4/16-medium)
|
51 |
+
|
52 |
+
[-1, 1, Conv, [512, 3, 2]],
|
53 |
+
[[-1, 16], 1, Concat, [1]], # cat head P5
|
54 |
+
[-1, 3, C3, [768, False]], # 29 (P5/32-large)
|
55 |
+
|
56 |
+
[-1, 1, Conv, [768, 3, 2]],
|
57 |
+
[[-1, 12], 1, Concat, [1]], # cat head P6
|
58 |
+
[-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)
|
59 |
+
|
60 |
+
[[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6)
|
61 |
+
]
|
ultralytics/models/v5/yolov5.yaml
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 1024]
|
11 |
+
l: [1.00, 1.00, 1024]
|
12 |
+
x: [1.33, 1.25, 1024]
|
13 |
+
|
14 |
+
# YOLOv5 v6.0 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, number, module, args]
|
17 |
+
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
18 |
+
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
19 |
+
[-1, 3, C3, [128]],
|
20 |
+
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
21 |
+
[-1, 6, C3, [256]],
|
22 |
+
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
23 |
+
[-1, 9, C3, [512]],
|
24 |
+
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
25 |
+
[-1, 3, C3, [1024]],
|
26 |
+
[-1, 1, SPPF, [1024, 5]], # 9
|
27 |
+
]
|
28 |
+
|
29 |
+
# YOLOv5 v6.0 head
|
30 |
+
head:
|
31 |
+
[[-1, 1, Conv, [512, 1, 1]],
|
32 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
33 |
+
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
34 |
+
[-1, 3, C3, [512, False]], # 13
|
35 |
+
|
36 |
+
[-1, 1, Conv, [256, 1, 1]],
|
37 |
+
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
38 |
+
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
39 |
+
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
|
40 |
+
|
41 |
+
[-1, 1, Conv, [256, 3, 2]],
|
42 |
+
[[-1, 14], 1, Concat, [1]], # cat head P4
|
43 |
+
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
|
44 |
+
|
45 |
+
[-1, 1, Conv, [512, 3, 2]],
|
46 |
+
[[-1, 10], 1, Concat, [1]], # cat head P5
|
47 |
+
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
48 |
+
|
49 |
+
[[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
|
50 |
+
]
|
ultralytics/models/v6/yolov6.yaml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
activation: nn.ReLU() # (optional) model default activation function
|
7 |
+
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
|
8 |
+
# [depth, width, max_channels]
|
9 |
+
n: [0.33, 0.25, 1024]
|
10 |
+
s: [0.33, 0.50, 1024]
|
11 |
+
m: [0.67, 0.75, 768]
|
12 |
+
l: [1.00, 1.00, 512]
|
13 |
+
x: [1.00, 1.25, 512]
|
14 |
+
|
15 |
+
# YOLOv6-3.0s backbone
|
16 |
+
backbone:
|
17 |
+
# [from, repeats, module, args]
|
18 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
19 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
20 |
+
- [-1, 6, Conv, [128, 3, 1]]
|
21 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
22 |
+
- [-1, 12, Conv, [256, 3, 1]]
|
23 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
24 |
+
- [-1, 18, Conv, [512, 3, 1]]
|
25 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
26 |
+
- [-1, 6, Conv, [1024, 3, 1]]
|
27 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
28 |
+
|
29 |
+
# YOLOv6-3.0s head
|
30 |
+
head:
|
31 |
+
- [-1, 1, Conv, [256, 1, 1]]
|
32 |
+
- [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
|
33 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
34 |
+
- [-1, 1, Conv, [256, 3, 1]]
|
35 |
+
- [-1, 9, Conv, [256, 3, 1]] # 14
|
36 |
+
|
37 |
+
- [-1, 1, Conv, [128, 1, 1]]
|
38 |
+
- [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
|
39 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
40 |
+
- [-1, 1, Conv, [128, 3, 1]]
|
41 |
+
- [-1, 9, Conv, [128, 3, 1]] # 19
|
42 |
+
|
43 |
+
- [-1, 1, Conv, [128, 3, 2]]
|
44 |
+
- [[-1, 15], 1, Concat, [1]] # cat head P4
|
45 |
+
- [-1, 1, Conv, [256, 3, 1]]
|
46 |
+
- [-1, 9, Conv, [256, 3, 1]] # 23
|
47 |
+
|
48 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
49 |
+
- [[-1, 10], 1, Concat, [1]] # cat head P5
|
50 |
+
- [-1, 1, Conv, [512, 3, 1]]
|
51 |
+
- [-1, 9, Conv, [512, 3, 1]] # 27
|
52 |
+
|
53 |
+
- [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/v8/yolov8-cls.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1000 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 1024]
|
11 |
+
l: [1.00, 1.00, 1024]
|
12 |
+
x: [1.00, 1.25, 1024]
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
|
27 |
+
# YOLOv8.0n head
|
28 |
+
head:
|
29 |
+
- [-1, 1, Classify, [nc]] # Classify
|
ultralytics/models/v8/yolov8-p2.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 768]
|
11 |
+
l: [1.00, 1.00, 512]
|
12 |
+
x: [1.00, 1.25, 512]
|
13 |
+
|
14 |
+
# YOLOv8.0 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0-p2 head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
39 |
+
- [[-1, 2], 1, Concat, [1]] # cat backbone P2
|
40 |
+
- [-1, 3, C2f, [128]] # 18 (P2/4-xsmall)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [128, 3, 2]]
|
43 |
+
- [[-1, 15], 1, Concat, [1]] # cat head P3
|
44 |
+
- [-1, 3, C2f, [256]] # 21 (P3/8-small)
|
45 |
+
|
46 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
47 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
48 |
+
- [-1, 3, C2f, [512]] # 24 (P4/16-medium)
|
49 |
+
|
50 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
51 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
52 |
+
- [-1, 3, C2f, [1024]] # 27 (P5/32-large)
|
53 |
+
|
54 |
+
- [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)
|
ultralytics/models/v8/yolov8-p6.yaml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 768]
|
11 |
+
l: [1.00, 1.00, 512]
|
12 |
+
x: [1.00, 1.25, 512]
|
13 |
+
|
14 |
+
# YOLOv8.0x6 backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [768, True]]
|
26 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
|
27 |
+
- [-1, 3, C2f, [1024, True]]
|
28 |
+
- [-1, 1, SPPF, [1024, 5]] # 11
|
29 |
+
|
30 |
+
# YOLOv8.0x6 head
|
31 |
+
head:
|
32 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
33 |
+
- [[-1, 8], 1, Concat, [1]] # cat backbone P5
|
34 |
+
- [-1, 3, C2, [768, False]] # 14
|
35 |
+
|
36 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
37 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
38 |
+
- [-1, 3, C2, [512, False]] # 17
|
39 |
+
|
40 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
41 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
42 |
+
- [-1, 3, C2, [256, False]] # 20 (P3/8-small)
|
43 |
+
|
44 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
45 |
+
- [[-1, 17], 1, Concat, [1]] # cat head P4
|
46 |
+
- [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
|
47 |
+
|
48 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
49 |
+
- [[-1, 14], 1, Concat, [1]] # cat head P5
|
50 |
+
- [-1, 3, C2, [768, False]] # 26 (P5/32-large)
|
51 |
+
|
52 |
+
- [-1, 1, Conv, [768, 3, 2]]
|
53 |
+
- [[-1, 11], 1, Concat, [1]] # cat head P6
|
54 |
+
- [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
|
55 |
+
|
56 |
+
- [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)
|
ultralytics/models/v8/yolov8-pose-p6.yaml
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1 # number of classes
|
6 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
7 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
|
8 |
+
# [depth, width, max_channels]
|
9 |
+
n: [0.33, 0.25, 1024]
|
10 |
+
s: [0.33, 0.50, 1024]
|
11 |
+
m: [0.67, 0.75, 768]
|
12 |
+
l: [1.00, 1.00, 512]
|
13 |
+
x: [1.00, 1.25, 512]
|
14 |
+
|
15 |
+
# YOLOv8.0x6 backbone
|
16 |
+
backbone:
|
17 |
+
# [from, repeats, module, args]
|
18 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
19 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
20 |
+
- [-1, 3, C2f, [128, True]]
|
21 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
22 |
+
- [-1, 6, C2f, [256, True]]
|
23 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
24 |
+
- [-1, 6, C2f, [512, True]]
|
25 |
+
- [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
|
26 |
+
- [-1, 3, C2f, [768, True]]
|
27 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
|
28 |
+
- [-1, 3, C2f, [1024, True]]
|
29 |
+
- [-1, 1, SPPF, [1024, 5]] # 11
|
30 |
+
|
31 |
+
# YOLOv8.0x6 head
|
32 |
+
head:
|
33 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
34 |
+
- [[-1, 8], 1, Concat, [1]] # cat backbone P5
|
35 |
+
- [-1, 3, C2, [768, False]] # 14
|
36 |
+
|
37 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
38 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
39 |
+
- [-1, 3, C2, [512, False]] # 17
|
40 |
+
|
41 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
42 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
43 |
+
- [-1, 3, C2, [256, False]] # 20 (P3/8-small)
|
44 |
+
|
45 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
46 |
+
- [[-1, 17], 1, Concat, [1]] # cat head P4
|
47 |
+
- [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
|
48 |
+
|
49 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
50 |
+
- [[-1, 14], 1, Concat, [1]] # cat head P5
|
51 |
+
- [-1, 3, C2, [768, False]] # 26 (P5/32-large)
|
52 |
+
|
53 |
+
- [-1, 1, Conv, [768, 3, 2]]
|
54 |
+
- [[-1, 11], 1, Concat, [1]] # cat head P6
|
55 |
+
- [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
|
56 |
+
|
57 |
+
- [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6)
|
ultralytics/models/v8/yolov8-pose.yaml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 1 # number of classes
|
6 |
+
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
7 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
|
8 |
+
# [depth, width, max_channels]
|
9 |
+
n: [0.33, 0.25, 1024]
|
10 |
+
s: [0.33, 0.50, 1024]
|
11 |
+
m: [0.67, 0.75, 768]
|
12 |
+
l: [1.00, 1.00, 512]
|
13 |
+
x: [1.00, 1.25, 512]
|
14 |
+
|
15 |
+
# YOLOv8.0n backbone
|
16 |
+
backbone:
|
17 |
+
# [from, repeats, module, args]
|
18 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
19 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
20 |
+
- [-1, 3, C2f, [128, True]]
|
21 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
22 |
+
- [-1, 6, C2f, [256, True]]
|
23 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
24 |
+
- [-1, 6, C2f, [512, True]]
|
25 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
26 |
+
- [-1, 3, C2f, [1024, True]]
|
27 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
28 |
+
|
29 |
+
# YOLOv8.0n head
|
30 |
+
head:
|
31 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
32 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
33 |
+
- [-1, 3, C2f, [512]] # 12
|
34 |
+
|
35 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
36 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
37 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
38 |
+
|
39 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
40 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
41 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
42 |
+
|
43 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
44 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
45 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
46 |
+
|
47 |
+
- [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
|
ultralytics/models/v8/yolov8-rtdetr.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
39 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
40 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
43 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
44 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
45 |
+
|
46 |
+
- [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
|
ultralytics/models/v8/yolov8-seg.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024]
|
9 |
+
s: [0.33, 0.50, 1024]
|
10 |
+
m: [0.67, 0.75, 768]
|
11 |
+
l: [1.00, 1.00, 512]
|
12 |
+
x: [1.00, 1.25, 512]
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
39 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
40 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
43 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
44 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
45 |
+
|
46 |
+
- [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
|
ultralytics/models/v8/yolov8.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
|
3 |
+
|
4 |
+
# Parameters
|
5 |
+
nc: 80 # number of classes
|
6 |
+
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
7 |
+
# [depth, width, max_channels]
|
8 |
+
n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
|
9 |
+
s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
|
10 |
+
m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
|
11 |
+
l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
|
12 |
+
x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
|
13 |
+
|
14 |
+
# YOLOv8.0n backbone
|
15 |
+
backbone:
|
16 |
+
# [from, repeats, module, args]
|
17 |
+
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
18 |
+
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
19 |
+
- [-1, 3, C2f, [128, True]]
|
20 |
+
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
21 |
+
- [-1, 6, C2f, [256, True]]
|
22 |
+
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
23 |
+
- [-1, 6, C2f, [512, True]]
|
24 |
+
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
25 |
+
- [-1, 3, C2f, [1024, True]]
|
26 |
+
- [-1, 1, SPPF, [1024, 5]] # 9
|
27 |
+
|
28 |
+
# YOLOv8.0n head
|
29 |
+
head:
|
30 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
31 |
+
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
32 |
+
- [-1, 3, C2f, [512]] # 12
|
33 |
+
|
34 |
+
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
|
35 |
+
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
36 |
+
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
|
37 |
+
|
38 |
+
- [-1, 1, Conv, [256, 3, 2]]
|
39 |
+
- [[-1, 12], 1, Concat, [1]] # cat head P4
|
40 |
+
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
|
41 |
+
|
42 |
+
- [-1, 1, Conv, [512, 3, 2]]
|
43 |
+
- [[-1, 9], 1, Concat, [1]] # cat head P5
|
44 |
+
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
|
45 |
+
|
46 |
+
- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
|
ultralytics/nn/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
4 |
+
attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
|
5 |
+
yaml_model_load)
|
6 |
+
|
7 |
+
__all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
|
8 |
+
'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
|
9 |
+
'BaseModel')
|
ultralytics/nn/autobackend.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import ast
|
4 |
+
import contextlib
|
5 |
+
import json
|
6 |
+
import platform
|
7 |
+
import zipfile
|
8 |
+
from collections import OrderedDict, namedtuple
|
9 |
+
from pathlib import Path
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, yaml_load
|
19 |
+
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
20 |
+
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
21 |
+
from ultralytics.yolo.utils.ops import xywh2xyxy
|
22 |
+
|
23 |
+
|
24 |
+
def check_class_names(names):
|
25 |
+
"""Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts."""
|
26 |
+
if isinstance(names, list): # names is a list
|
27 |
+
names = dict(enumerate(names)) # convert to dict
|
28 |
+
if isinstance(names, dict):
|
29 |
+
# Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
|
30 |
+
names = {int(k): str(v) for k, v in names.items()}
|
31 |
+
n = len(names)
|
32 |
+
if max(names.keys()) >= n:
|
33 |
+
raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
|
34 |
+
f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
|
35 |
+
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
|
36 |
+
map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map'] # human-readable names
|
37 |
+
names = {k: map[v] for k, v in names.items()}
|
38 |
+
return names
|
39 |
+
|
40 |
+
|
41 |
+
class AutoBackend(nn.Module):
|
42 |
+
|
43 |
+
def __init__(self,
|
44 |
+
weights='yolov8n.pt',
|
45 |
+
device=torch.device('cpu'),
|
46 |
+
dnn=False,
|
47 |
+
data=None,
|
48 |
+
fp16=False,
|
49 |
+
fuse=True,
|
50 |
+
verbose=True):
|
51 |
+
"""
|
52 |
+
MultiBackend class for python inference on various platforms using Ultralytics YOLO.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
weights (str): The path to the weights file. Default: 'yolov8n.pt'
|
56 |
+
device (torch.device): The device to run the model on.
|
57 |
+
dnn (bool): Use OpenCV DNN module for inference if True, defaults to False.
|
58 |
+
data (str | Path | optional): Additional data.yaml file for class names.
|
59 |
+
fp16 (bool): If True, use half precision. Default: False
|
60 |
+
fuse (bool): Whether to fuse the model or not. Default: True
|
61 |
+
verbose (bool): Whether to run in verbose mode or not. Default: True
|
62 |
+
|
63 |
+
Supported formats and their naming conventions:
|
64 |
+
| Format | Suffix |
|
65 |
+
|-----------------------|------------------|
|
66 |
+
| PyTorch | *.pt |
|
67 |
+
| TorchScript | *.torchscript |
|
68 |
+
| ONNX Runtime | *.onnx |
|
69 |
+
| ONNX OpenCV DNN | *.onnx dnn=True |
|
70 |
+
| OpenVINO | *.xml |
|
71 |
+
| CoreML | *.mlmodel |
|
72 |
+
| TensorRT | *.engine |
|
73 |
+
| TensorFlow SavedModel | *_saved_model |
|
74 |
+
| TensorFlow GraphDef | *.pb |
|
75 |
+
| TensorFlow Lite | *.tflite |
|
76 |
+
| TensorFlow Edge TPU | *_edgetpu.tflite |
|
77 |
+
| PaddlePaddle | *_paddle_model |
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
w = str(weights[0] if isinstance(weights, list) else weights)
|
81 |
+
nn_module = isinstance(weights, torch.nn.Module)
|
82 |
+
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
|
83 |
+
fp16 &= pt or jit or onnx or engine or nn_module or triton # FP16
|
84 |
+
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
|
85 |
+
stride = 32 # default stride
|
86 |
+
model, metadata = None, None
|
87 |
+
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
88 |
+
if not (pt or triton or nn_module):
|
89 |
+
w = attempt_download_asset(w) # download if not local
|
90 |
+
|
91 |
+
# NOTE: special case: in-memory pytorch model
|
92 |
+
if nn_module:
|
93 |
+
model = weights.to(device)
|
94 |
+
model = model.fuse(verbose=verbose) if fuse else model
|
95 |
+
if hasattr(model, 'kpt_shape'):
|
96 |
+
kpt_shape = model.kpt_shape # pose-only
|
97 |
+
stride = max(int(model.stride.max()), 32) # model stride
|
98 |
+
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
99 |
+
model.half() if fp16 else model.float()
|
100 |
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
101 |
+
pt = True
|
102 |
+
elif pt: # PyTorch
|
103 |
+
from ultralytics.nn.tasks import attempt_load_weights
|
104 |
+
model = attempt_load_weights(weights if isinstance(weights, list) else w,
|
105 |
+
device=device,
|
106 |
+
inplace=True,
|
107 |
+
fuse=fuse)
|
108 |
+
if hasattr(model, 'kpt_shape'):
|
109 |
+
kpt_shape = model.kpt_shape # pose-only
|
110 |
+
stride = max(int(model.stride.max()), 32) # model stride
|
111 |
+
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
112 |
+
model.half() if fp16 else model.float()
|
113 |
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
114 |
+
elif jit: # TorchScript
|
115 |
+
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
116 |
+
extra_files = {'config.txt': ''} # model metadata
|
117 |
+
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
118 |
+
model.half() if fp16 else model.float()
|
119 |
+
if extra_files['config.txt']: # load metadata dict
|
120 |
+
metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
|
121 |
+
elif dnn: # ONNX OpenCV DNN
|
122 |
+
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
123 |
+
check_requirements('opencv-python>=4.5.4')
|
124 |
+
net = cv2.dnn.readNetFromONNX(w)
|
125 |
+
elif onnx: # ONNX Runtime
|
126 |
+
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
127 |
+
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
128 |
+
import onnxruntime
|
129 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
130 |
+
session = onnxruntime.InferenceSession(w, providers=providers)
|
131 |
+
output_names = [x.name for x in session.get_outputs()]
|
132 |
+
metadata = session.get_modelmeta().custom_metadata_map # metadata
|
133 |
+
elif xml: # OpenVINO
|
134 |
+
LOGGER.info(f'Loading {w} for OpenVINO inference...')
|
135 |
+
check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
136 |
+
from openvino.runtime import Core, Layout, get_batch # noqa
|
137 |
+
ie = Core()
|
138 |
+
w = Path(w)
|
139 |
+
if not w.is_file(): # if not *.xml
|
140 |
+
w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir
|
141 |
+
network = ie.read_model(model=str(w), weights=w.with_suffix('.bin'))
|
142 |
+
if network.get_parameters()[0].get_layout().empty:
|
143 |
+
network.get_parameters()[0].set_layout(Layout('NCHW'))
|
144 |
+
batch_dim = get_batch(network)
|
145 |
+
if batch_dim.is_static:
|
146 |
+
batch_size = batch_dim.get_length()
|
147 |
+
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for NCS2
|
148 |
+
metadata = w.parent / 'metadata.yaml'
|
149 |
+
elif engine: # TensorRT
|
150 |
+
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
151 |
+
try:
|
152 |
+
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
|
153 |
+
except ImportError:
|
154 |
+
if LINUX:
|
155 |
+
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
|
156 |
+
import tensorrt as trt # noqa
|
157 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
158 |
+
if device.type == 'cpu':
|
159 |
+
device = torch.device('cuda:0')
|
160 |
+
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
161 |
+
logger = trt.Logger(trt.Logger.INFO)
|
162 |
+
# Read file
|
163 |
+
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
164 |
+
meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length
|
165 |
+
metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata
|
166 |
+
model = runtime.deserialize_cuda_engine(f.read()) # read engine
|
167 |
+
context = model.create_execution_context()
|
168 |
+
bindings = OrderedDict()
|
169 |
+
output_names = []
|
170 |
+
fp16 = False # default updated below
|
171 |
+
dynamic = False
|
172 |
+
for i in range(model.num_bindings):
|
173 |
+
name = model.get_binding_name(i)
|
174 |
+
dtype = trt.nptype(model.get_binding_dtype(i))
|
175 |
+
if model.binding_is_input(i):
|
176 |
+
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
177 |
+
dynamic = True
|
178 |
+
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
179 |
+
if dtype == np.float16:
|
180 |
+
fp16 = True
|
181 |
+
else: # output
|
182 |
+
output_names.append(name)
|
183 |
+
shape = tuple(context.get_binding_shape(i))
|
184 |
+
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
185 |
+
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
186 |
+
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
187 |
+
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
188 |
+
elif coreml: # CoreML
|
189 |
+
LOGGER.info(f'Loading {w} for CoreML inference...')
|
190 |
+
import coremltools as ct
|
191 |
+
model = ct.models.MLModel(w)
|
192 |
+
metadata = dict(model.user_defined_metadata)
|
193 |
+
elif saved_model: # TF SavedModel
|
194 |
+
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
195 |
+
import tensorflow as tf
|
196 |
+
keras = False # assume TF1 saved_model
|
197 |
+
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
198 |
+
metadata = Path(w) / 'metadata.yaml'
|
199 |
+
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
200 |
+
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
201 |
+
import tensorflow as tf
|
202 |
+
|
203 |
+
from ultralytics.yolo.engine.exporter import gd_outputs
|
204 |
+
|
205 |
+
def wrap_frozen_graph(gd, inputs, outputs):
|
206 |
+
"""Wrap frozen graphs for deployment."""
|
207 |
+
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
|
208 |
+
ge = x.graph.as_graph_element
|
209 |
+
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
210 |
+
|
211 |
+
gd = tf.Graph().as_graph_def() # TF GraphDef
|
212 |
+
with open(w, 'rb') as f:
|
213 |
+
gd.ParseFromString(f.read())
|
214 |
+
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
|
215 |
+
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
216 |
+
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
217 |
+
from tflite_runtime.interpreter import Interpreter, load_delegate
|
218 |
+
except ImportError:
|
219 |
+
import tensorflow as tf
|
220 |
+
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
|
221 |
+
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
222 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
223 |
+
delegate = {
|
224 |
+
'Linux': 'libedgetpu.so.1',
|
225 |
+
'Darwin': 'libedgetpu.1.dylib',
|
226 |
+
'Windows': 'edgetpu.dll'}[platform.system()]
|
227 |
+
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
|
228 |
+
else: # TFLite
|
229 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
|
230 |
+
interpreter = Interpreter(model_path=w) # load TFLite model
|
231 |
+
interpreter.allocate_tensors() # allocate
|
232 |
+
input_details = interpreter.get_input_details() # inputs
|
233 |
+
output_details = interpreter.get_output_details() # outputs
|
234 |
+
# Load metadata
|
235 |
+
with contextlib.suppress(zipfile.BadZipFile):
|
236 |
+
with zipfile.ZipFile(w, 'r') as model:
|
237 |
+
meta_file = model.namelist()[0]
|
238 |
+
metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
|
239 |
+
elif tfjs: # TF.js
|
240 |
+
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
|
241 |
+
elif paddle: # PaddlePaddle
|
242 |
+
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
|
243 |
+
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
|
244 |
+
import paddle.inference as pdi # noqa
|
245 |
+
w = Path(w)
|
246 |
+
if not w.is_file(): # if not *.pdmodel
|
247 |
+
w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
|
248 |
+
config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
|
249 |
+
if cuda:
|
250 |
+
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
251 |
+
predictor = pdi.create_predictor(config)
|
252 |
+
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
253 |
+
output_names = predictor.get_output_names()
|
254 |
+
metadata = w.parents[1] / 'metadata.yaml'
|
255 |
+
elif triton: # NVIDIA Triton Inference Server
|
256 |
+
LOGGER.info('Triton Inference Server not supported...')
|
257 |
+
'''
|
258 |
+
TODO:
|
259 |
+
check_requirements('tritonclient[all]')
|
260 |
+
from utils.triton import TritonRemoteModel
|
261 |
+
model = TritonRemoteModel(url=w)
|
262 |
+
nhwc = model.runtime.startswith("tensorflow")
|
263 |
+
'''
|
264 |
+
else:
|
265 |
+
from ultralytics.yolo.engine.exporter import export_formats
|
266 |
+
raise TypeError(f"model='{w}' is not a supported model format. "
|
267 |
+
'See https://docs.ultralytics.com/modes/predict for help.'
|
268 |
+
f'\n\n{export_formats()}')
|
269 |
+
|
270 |
+
# Load external metadata YAML
|
271 |
+
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
272 |
+
metadata = yaml_load(metadata)
|
273 |
+
if metadata:
|
274 |
+
for k, v in metadata.items():
|
275 |
+
if k in ('stride', 'batch'):
|
276 |
+
metadata[k] = int(v)
|
277 |
+
elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
|
278 |
+
metadata[k] = eval(v)
|
279 |
+
stride = metadata['stride']
|
280 |
+
task = metadata['task']
|
281 |
+
batch = metadata['batch']
|
282 |
+
imgsz = metadata['imgsz']
|
283 |
+
names = metadata['names']
|
284 |
+
kpt_shape = metadata.get('kpt_shape')
|
285 |
+
elif not (pt or triton or nn_module):
|
286 |
+
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
|
287 |
+
|
288 |
+
# Check names
|
289 |
+
if 'names' not in locals(): # names missing
|
290 |
+
names = self._apply_default_class_names(data)
|
291 |
+
names = check_class_names(names)
|
292 |
+
|
293 |
+
self.__dict__.update(locals()) # assign all variables to self
|
294 |
+
|
295 |
+
def forward(self, im, augment=False, visualize=False):
|
296 |
+
"""
|
297 |
+
Runs inference on the YOLOv8 MultiBackend model.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
im (torch.Tensor): The image tensor to perform inference on.
|
301 |
+
augment (bool): whether to perform data augmentation during inference, defaults to False
|
302 |
+
visualize (bool): whether to visualize the output predictions, defaults to False
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
|
306 |
+
"""
|
307 |
+
b, ch, h, w = im.shape # batch, channel, height, width
|
308 |
+
if self.fp16 and im.dtype != torch.float16:
|
309 |
+
im = im.half() # to FP16
|
310 |
+
if self.nhwc:
|
311 |
+
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
312 |
+
|
313 |
+
if self.pt or self.nn_module: # PyTorch
|
314 |
+
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
315 |
+
elif self.jit: # TorchScript
|
316 |
+
y = self.model(im)
|
317 |
+
elif self.dnn: # ONNX OpenCV DNN
|
318 |
+
im = im.cpu().numpy() # torch to numpy
|
319 |
+
self.net.setInput(im)
|
320 |
+
y = self.net.forward()
|
321 |
+
elif self.onnx: # ONNX Runtime
|
322 |
+
im = im.cpu().numpy() # torch to numpy
|
323 |
+
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
324 |
+
elif self.xml: # OpenVINO
|
325 |
+
im = im.cpu().numpy() # FP32
|
326 |
+
y = list(self.executable_network([im]).values())
|
327 |
+
elif self.engine: # TensorRT
|
328 |
+
if self.dynamic and im.shape != self.bindings['images'].shape:
|
329 |
+
i = self.model.get_binding_index('images')
|
330 |
+
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
331 |
+
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
332 |
+
for name in self.output_names:
|
333 |
+
i = self.model.get_binding_index(name)
|
334 |
+
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
335 |
+
s = self.bindings['images'].shape
|
336 |
+
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
337 |
+
self.binding_addrs['images'] = int(im.data_ptr())
|
338 |
+
self.context.execute_v2(list(self.binding_addrs.values()))
|
339 |
+
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
340 |
+
elif self.coreml: # CoreML
|
341 |
+
im = im[0].cpu().numpy()
|
342 |
+
im_pil = Image.fromarray((im * 255).astype('uint8'))
|
343 |
+
# im = im.resize((192, 320), Image.ANTIALIAS)
|
344 |
+
y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
|
345 |
+
if 'confidence' in y:
|
346 |
+
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
347 |
+
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
|
348 |
+
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
349 |
+
elif len(y) == 1: # classification model
|
350 |
+
y = list(y.values())
|
351 |
+
elif len(y) == 2: # segmentation model
|
352 |
+
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
353 |
+
elif self.paddle: # PaddlePaddle
|
354 |
+
im = im.cpu().numpy().astype(np.float32)
|
355 |
+
self.input_handle.copy_from_cpu(im)
|
356 |
+
self.predictor.run()
|
357 |
+
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
358 |
+
elif self.triton: # NVIDIA Triton Inference Server
|
359 |
+
y = self.model(im)
|
360 |
+
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
361 |
+
im = im.cpu().numpy()
|
362 |
+
if self.saved_model: # SavedModel
|
363 |
+
y = self.model(im, training=False) if self.keras else self.model(im)
|
364 |
+
if not isinstance(y, list):
|
365 |
+
y = [y]
|
366 |
+
elif self.pb: # GraphDef
|
367 |
+
y = self.frozen_func(x=self.tf.constant(im))
|
368 |
+
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
|
369 |
+
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
|
370 |
+
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
371 |
+
self.names = {i: f'class{i}' for i in range(nc)}
|
372 |
+
else: # Lite or Edge TPU
|
373 |
+
input = self.input_details[0]
|
374 |
+
int8 = input['dtype'] == np.int8 # is TFLite quantized int8 model
|
375 |
+
if int8:
|
376 |
+
scale, zero_point = input['quantization']
|
377 |
+
im = (im / scale + zero_point).astype(np.int8) # de-scale
|
378 |
+
self.interpreter.set_tensor(input['index'], im)
|
379 |
+
self.interpreter.invoke()
|
380 |
+
y = []
|
381 |
+
for output in self.output_details:
|
382 |
+
x = self.interpreter.get_tensor(output['index'])
|
383 |
+
if int8:
|
384 |
+
scale, zero_point = output['quantization']
|
385 |
+
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
386 |
+
y.append(x)
|
387 |
+
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
|
388 |
+
if len(y) == 2: # segment with (det, proto) output order reversed
|
389 |
+
if len(y[1].shape) != 4:
|
390 |
+
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
|
391 |
+
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
|
392 |
+
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
|
393 |
+
# y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
394 |
+
|
395 |
+
# for x in y:
|
396 |
+
# print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes
|
397 |
+
if isinstance(y, (list, tuple)):
|
398 |
+
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
399 |
+
else:
|
400 |
+
return self.from_numpy(y)
|
401 |
+
|
402 |
+
def from_numpy(self, x):
|
403 |
+
"""
|
404 |
+
Convert a numpy array to a tensor.
|
405 |
+
|
406 |
+
Args:
|
407 |
+
x (np.ndarray): The array to be converted.
|
408 |
+
|
409 |
+
Returns:
|
410 |
+
(torch.Tensor): The converted tensor
|
411 |
+
"""
|
412 |
+
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
|
413 |
+
|
414 |
+
def warmup(self, imgsz=(1, 3, 640, 640)):
|
415 |
+
"""
|
416 |
+
Warm up the model by running one forward pass with a dummy input.
|
417 |
+
|
418 |
+
Args:
|
419 |
+
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
(None): This method runs the forward pass and don't return any value
|
423 |
+
"""
|
424 |
+
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
|
425 |
+
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
|
426 |
+
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
427 |
+
for _ in range(2 if self.jit else 1): #
|
428 |
+
self.forward(im) # warmup
|
429 |
+
|
430 |
+
@staticmethod
|
431 |
+
def _apply_default_class_names(data):
|
432 |
+
"""Applies default class names to an input YAML file or returns numerical class names."""
|
433 |
+
with contextlib.suppress(Exception):
|
434 |
+
return yaml_load(check_yaml(data))['names']
|
435 |
+
return {i: f'class{i}' for i in range(999)} # return default if above errors
|
436 |
+
|
437 |
+
@staticmethod
|
438 |
+
def _model_type(p='path/to/model.pt'):
|
439 |
+
"""
|
440 |
+
This function takes a path to a model file and returns the model type
|
441 |
+
|
442 |
+
Args:
|
443 |
+
p: path to the model file. Defaults to path/to/model.pt
|
444 |
+
"""
|
445 |
+
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
446 |
+
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
447 |
+
from ultralytics.yolo.engine.exporter import export_formats
|
448 |
+
sf = list(export_formats().Suffix) # export suffixes
|
449 |
+
if not is_url(p, check=False) and not isinstance(p, str):
|
450 |
+
check_suffix(p, sf) # checks
|
451 |
+
url = urlparse(p) # if url may be Triton inference server
|
452 |
+
types = [s in Path(p).name for s in sf]
|
453 |
+
types[8] &= not types[9] # tflite &= not edgetpu
|
454 |
+
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
|
455 |
+
return types + [triton]
|
ultralytics/nn/autoshape.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Common modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
from copy import copy
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from PIL import Image, ImageOps
|
15 |
+
from torch.cuda import amp
|
16 |
+
|
17 |
+
from ultralytics.nn.autobackend import AutoBackend
|
18 |
+
from ultralytics.yolo.data.augment import LetterBox
|
19 |
+
from ultralytics.yolo.utils import LOGGER, colorstr
|
20 |
+
from ultralytics.yolo.utils.files import increment_path
|
21 |
+
from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
|
22 |
+
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
23 |
+
from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
|
24 |
+
|
25 |
+
|
26 |
+
class AutoShape(nn.Module):
|
27 |
+
"""YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS."""
|
28 |
+
conf = 0.25 # NMS confidence threshold
|
29 |
+
iou = 0.45 # NMS IoU threshold
|
30 |
+
agnostic = False # NMS class-agnostic
|
31 |
+
multi_label = False # NMS multiple labels per box
|
32 |
+
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
33 |
+
max_det = 1000 # maximum number of detections per image
|
34 |
+
amp = False # Automatic Mixed Precision (AMP) inference
|
35 |
+
|
36 |
+
def __init__(self, model, verbose=True):
|
37 |
+
"""Initializes object and copies attributes from model object."""
|
38 |
+
super().__init__()
|
39 |
+
if verbose:
|
40 |
+
LOGGER.info('Adding AutoShape... ')
|
41 |
+
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
42 |
+
self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance
|
43 |
+
self.pt = not self.dmb or model.pt # PyTorch model
|
44 |
+
self.model = model.eval()
|
45 |
+
if self.pt:
|
46 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
47 |
+
m.inplace = False # Detect.inplace=False for safe multithread inference
|
48 |
+
m.export = True # do not output loss values
|
49 |
+
|
50 |
+
def _apply(self, fn):
|
51 |
+
"""Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers."""
|
52 |
+
self = super()._apply(fn)
|
53 |
+
if self.pt:
|
54 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
55 |
+
m.stride = fn(m.stride)
|
56 |
+
m.grid = list(map(fn, m.grid))
|
57 |
+
if isinstance(m.anchor_grid, list):
|
58 |
+
m.anchor_grid = list(map(fn, m.anchor_grid))
|
59 |
+
return self
|
60 |
+
|
61 |
+
@smart_inference_mode()
|
62 |
+
def forward(self, ims, size=640, augment=False, profile=False):
|
63 |
+
"""Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:."""
|
64 |
+
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
65 |
+
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
66 |
+
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
67 |
+
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
68 |
+
# numpy: = np.zeros((640,1280,3)) # HWC
|
69 |
+
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
70 |
+
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
71 |
+
|
72 |
+
dt = (Profile(), Profile(), Profile())
|
73 |
+
with dt[0]:
|
74 |
+
if isinstance(size, int): # expand
|
75 |
+
size = (size, size)
|
76 |
+
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
77 |
+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
78 |
+
if isinstance(ims, torch.Tensor): # torch
|
79 |
+
with amp.autocast(autocast):
|
80 |
+
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
81 |
+
|
82 |
+
# Preprocess
|
83 |
+
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
84 |
+
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
85 |
+
for i, im in enumerate(ims):
|
86 |
+
f = f'image{i}' # filename
|
87 |
+
if isinstance(im, (str, Path)): # filename or uri
|
88 |
+
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
89 |
+
im = np.asarray(ImageOps.exif_transpose(im))
|
90 |
+
elif isinstance(im, Image.Image): # PIL Image
|
91 |
+
im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
|
92 |
+
files.append(Path(f).with_suffix('.jpg').name)
|
93 |
+
if im.shape[0] < 5: # image in CHW
|
94 |
+
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
95 |
+
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
96 |
+
s = im.shape[:2] # HWC
|
97 |
+
shape0.append(s) # image shape
|
98 |
+
g = max(size) / max(s) # gain
|
99 |
+
shape1.append([y * g for y in s])
|
100 |
+
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
101 |
+
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
|
102 |
+
x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims] # pad
|
103 |
+
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
104 |
+
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
105 |
+
|
106 |
+
with amp.autocast(autocast):
|
107 |
+
# Inference
|
108 |
+
with dt[1]:
|
109 |
+
y = self.model(x, augment=augment) # forward
|
110 |
+
|
111 |
+
# Postprocess
|
112 |
+
with dt[2]:
|
113 |
+
y = non_max_suppression(y if self.dmb else y[0],
|
114 |
+
self.conf,
|
115 |
+
self.iou,
|
116 |
+
self.classes,
|
117 |
+
self.agnostic,
|
118 |
+
self.multi_label,
|
119 |
+
max_det=self.max_det) # NMS
|
120 |
+
for i in range(n):
|
121 |
+
scale_boxes(shape1, y[i][:, :4], shape0[i])
|
122 |
+
|
123 |
+
return Detections(ims, y, files, dt, self.names, x.shape)
|
124 |
+
|
125 |
+
|
126 |
+
class Detections:
|
127 |
+
""" YOLOv8 detections class for inference results"""
|
128 |
+
|
129 |
+
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
130 |
+
"""Initialize object attributes for YOLO detection results."""
|
131 |
+
super().__init__()
|
132 |
+
d = pred[0].device # device
|
133 |
+
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
134 |
+
self.ims = ims # list of images as numpy arrays
|
135 |
+
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
136 |
+
self.names = names # class names
|
137 |
+
self.files = files # image filenames
|
138 |
+
self.times = times # profiling times
|
139 |
+
self.xyxy = pred # xyxy pixels
|
140 |
+
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
|
141 |
+
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
142 |
+
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
143 |
+
self.n = len(self.pred) # number of images (batch size)
|
144 |
+
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
145 |
+
self.s = tuple(shape) # inference BCHW shape
|
146 |
+
|
147 |
+
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
148 |
+
"""Return performance metrics and optionally cropped/save images or results."""
|
149 |
+
s, crops = '', []
|
150 |
+
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
151 |
+
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
152 |
+
if pred.shape[0]:
|
153 |
+
for c in pred[:, -1].unique():
|
154 |
+
n = (pred[:, -1] == c).sum() # detections per class
|
155 |
+
s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
|
156 |
+
s = s.rstrip(', ')
|
157 |
+
if show or save or render or crop:
|
158 |
+
annotator = Annotator(im, example=str(self.names))
|
159 |
+
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
|
160 |
+
label = f'{self.names[int(cls)]} {conf:.2f}'
|
161 |
+
if crop:
|
162 |
+
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
|
163 |
+
crops.append({
|
164 |
+
'box': box,
|
165 |
+
'conf': conf,
|
166 |
+
'cls': cls,
|
167 |
+
'label': label,
|
168 |
+
'im': save_one_box(box, im, file=file, save=save)})
|
169 |
+
else: # all others
|
170 |
+
annotator.box_label(box, label if labels else '', color=colors(cls))
|
171 |
+
im = annotator.im
|
172 |
+
else:
|
173 |
+
s += '(no detections)'
|
174 |
+
|
175 |
+
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
|
176 |
+
if show:
|
177 |
+
im.show(self.files[i]) # show
|
178 |
+
if save:
|
179 |
+
f = self.files[i]
|
180 |
+
im.save(save_dir / f) # save
|
181 |
+
if i == self.n - 1:
|
182 |
+
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
183 |
+
if render:
|
184 |
+
self.ims[i] = np.asarray(im)
|
185 |
+
if pprint:
|
186 |
+
s = s.lstrip('\n')
|
187 |
+
return f'{s}\nSpeed: %.1fms preprocess, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
|
188 |
+
if crop:
|
189 |
+
if save:
|
190 |
+
LOGGER.info(f'Saved results to {save_dir}\n')
|
191 |
+
return crops
|
192 |
+
|
193 |
+
def show(self, labels=True):
|
194 |
+
"""Displays YOLO results with detected bounding boxes."""
|
195 |
+
self._run(show=True, labels=labels) # show results
|
196 |
+
|
197 |
+
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
|
198 |
+
"""Save detection results with optional labels to specified directory."""
|
199 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
|
200 |
+
self._run(save=True, labels=labels, save_dir=save_dir) # save results
|
201 |
+
|
202 |
+
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
|
203 |
+
"""Crops images into detections and saves them if 'save' is True."""
|
204 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
|
205 |
+
return self._run(crop=True, save=save, save_dir=save_dir) # crop results
|
206 |
+
|
207 |
+
def render(self, labels=True):
|
208 |
+
"""Renders detected objects and returns images."""
|
209 |
+
self._run(render=True, labels=labels) # render results
|
210 |
+
return self.ims
|
211 |
+
|
212 |
+
def pandas(self):
|
213 |
+
"""Return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])."""
|
214 |
+
import pandas
|
215 |
+
new = copy(self) # return copy
|
216 |
+
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
|
217 |
+
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
|
218 |
+
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
|
219 |
+
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
|
220 |
+
setattr(new, k, [pandas.DataFrame(x, columns=c) for x in a])
|
221 |
+
return new
|
222 |
+
|
223 |
+
def tolist(self):
|
224 |
+
"""Return a list of Detections objects, i.e. 'for result in results.tolist():'."""
|
225 |
+
r = range(self.n) # iterable
|
226 |
+
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
227 |
+
# for d in x:
|
228 |
+
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
229 |
+
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
230 |
+
return x
|
231 |
+
|
232 |
+
def print(self):
|
233 |
+
"""Print the results of the `self._run()` function."""
|
234 |
+
LOGGER.info(self.__str__())
|
235 |
+
|
236 |
+
def __len__(self): # override len(results)
|
237 |
+
return self.n
|
238 |
+
|
239 |
+
def __str__(self): # override print(results)
|
240 |
+
return self._run(pprint=True) # print results
|
241 |
+
|
242 |
+
def __repr__(self):
|
243 |
+
"""Returns a printable representation of the object."""
|
244 |
+
return f'YOLOv8 {self.__class__} instance\n' + self.__str__()
|
ultralytics/nn/modules/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Ultralytics modules. Visualize with:
|
4 |
+
|
5 |
+
from ultralytics.nn.modules import *
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
|
9 |
+
x = torch.ones(1, 128, 40, 40)
|
10 |
+
m = Conv(128, 128)
|
11 |
+
f = f'{m._get_name()}.onnx'
|
12 |
+
torch.onnx.export(m, x, f)
|
13 |
+
os.system(f'onnxsim {f} {f} && open {f}')
|
14 |
+
"""
|
15 |
+
|
16 |
+
from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
|
17 |
+
HGBlock, HGStem, Proto, RepC3)
|
18 |
+
from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
|
19 |
+
GhostConv, LightConv, RepConv, SpatialAttention)
|
20 |
+
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
|
21 |
+
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
|
22 |
+
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
|
23 |
+
|
24 |
+
__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
|
25 |
+
'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
|
26 |
+
'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
|
27 |
+
'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
|
28 |
+
'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
|
29 |
+
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
|
ultralytics/nn/modules/block.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Block modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
|
11 |
+
from .transformer import TransformerBlock
|
12 |
+
|
13 |
+
__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
|
14 |
+
'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3')
|
15 |
+
|
16 |
+
|
17 |
+
class DFL(nn.Module):
|
18 |
+
"""
|
19 |
+
Integral module of Distribution Focal Loss (DFL).
|
20 |
+
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, c1=16):
|
24 |
+
"""Initialize a convolutional layer with a given number of input channels."""
|
25 |
+
super().__init__()
|
26 |
+
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
27 |
+
x = torch.arange(c1, dtype=torch.float)
|
28 |
+
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
|
29 |
+
self.c1 = c1
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
"""Applies a transformer layer on input tensor 'x' and returns a tensor."""
|
33 |
+
b, c, a = x.shape # batch, channels, anchors
|
34 |
+
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
|
35 |
+
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
|
36 |
+
|
37 |
+
|
38 |
+
class Proto(nn.Module):
|
39 |
+
"""YOLOv8 mask Proto module for segmentation models."""
|
40 |
+
|
41 |
+
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
|
42 |
+
super().__init__()
|
43 |
+
self.cv1 = Conv(c1, c_, k=3)
|
44 |
+
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
|
45 |
+
self.cv2 = Conv(c_, c_, k=3)
|
46 |
+
self.cv3 = Conv(c_, c2)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""Performs a forward pass through layers using an upsampled input image."""
|
50 |
+
return self.cv3(self.cv2(self.upsample(self.cv1(x))))
|
51 |
+
|
52 |
+
|
53 |
+
class HGStem(nn.Module):
|
54 |
+
"""StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
|
55 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, c1, cm, c2):
|
59 |
+
super().__init__()
|
60 |
+
self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
|
61 |
+
self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
|
62 |
+
self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
|
63 |
+
self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
|
64 |
+
self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
|
65 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
"""Forward pass of a PPHGNetV2 backbone layer."""
|
69 |
+
x = self.stem1(x)
|
70 |
+
x = F.pad(x, [0, 1, 0, 1])
|
71 |
+
x2 = self.stem2a(x)
|
72 |
+
x2 = F.pad(x2, [0, 1, 0, 1])
|
73 |
+
x2 = self.stem2b(x2)
|
74 |
+
x1 = self.pool(x)
|
75 |
+
x = torch.cat([x1, x2], dim=1)
|
76 |
+
x = self.stem3(x)
|
77 |
+
x = self.stem4(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class HGBlock(nn.Module):
|
82 |
+
"""HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
|
83 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
|
87 |
+
super().__init__()
|
88 |
+
block = LightConv if lightconv else Conv
|
89 |
+
self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
|
90 |
+
self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
|
91 |
+
self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
|
92 |
+
self.add = shortcut and c1 == c2
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
"""Forward pass of a PPHGNetV2 backbone layer."""
|
96 |
+
y = [x]
|
97 |
+
y.extend(m(y[-1]) for m in self.m)
|
98 |
+
y = self.ec(self.sc(torch.cat(y, 1)))
|
99 |
+
return y + x if self.add else y
|
100 |
+
|
101 |
+
|
102 |
+
class SPP(nn.Module):
|
103 |
+
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
|
104 |
+
|
105 |
+
def __init__(self, c1, c2, k=(5, 9, 13)):
|
106 |
+
"""Initialize the SPP layer with input/output channels and pooling kernel sizes."""
|
107 |
+
super().__init__()
|
108 |
+
c_ = c1 // 2 # hidden channels
|
109 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
110 |
+
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
|
111 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""Forward pass of the SPP layer, performing spatial pyramid pooling."""
|
115 |
+
x = self.cv1(x)
|
116 |
+
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
117 |
+
|
118 |
+
|
119 |
+
class SPPF(nn.Module):
|
120 |
+
"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
|
121 |
+
|
122 |
+
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
123 |
+
super().__init__()
|
124 |
+
c_ = c1 // 2 # hidden channels
|
125 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
126 |
+
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
127 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
"""Forward pass through Ghost Convolution block."""
|
131 |
+
x = self.cv1(x)
|
132 |
+
y1 = self.m(x)
|
133 |
+
y2 = self.m(y1)
|
134 |
+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
|
135 |
+
|
136 |
+
|
137 |
+
class C1(nn.Module):
|
138 |
+
"""CSP Bottleneck with 1 convolution."""
|
139 |
+
|
140 |
+
def __init__(self, c1, c2, n=1): # ch_in, ch_out, number
|
141 |
+
super().__init__()
|
142 |
+
self.cv1 = Conv(c1, c2, 1, 1)
|
143 |
+
self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
"""Applies cross-convolutions to input in the C3 module."""
|
147 |
+
y = self.cv1(x)
|
148 |
+
return self.m(y) + y
|
149 |
+
|
150 |
+
|
151 |
+
class C2(nn.Module):
|
152 |
+
"""CSP Bottleneck with 2 convolutions."""
|
153 |
+
|
154 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
155 |
+
super().__init__()
|
156 |
+
self.c = int(c2 * e) # hidden channels
|
157 |
+
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
158 |
+
self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
|
159 |
+
# self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
|
160 |
+
self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
"""Forward pass through the CSP bottleneck with 2 convolutions."""
|
164 |
+
a, b = self.cv1(x).chunk(2, 1)
|
165 |
+
return self.cv2(torch.cat((self.m(a), b), 1))
|
166 |
+
|
167 |
+
|
168 |
+
class C2f(nn.Module):
|
169 |
+
"""CSP Bottleneck with 2 convolutions."""
|
170 |
+
|
171 |
+
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
172 |
+
super().__init__()
|
173 |
+
self.c = int(c2 * e) # hidden channels
|
174 |
+
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
|
175 |
+
self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
|
176 |
+
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
"""Forward pass through C2f layer."""
|
180 |
+
y = list(self.cv1(x).chunk(2, 1))
|
181 |
+
y.extend(m(y[-1]) for m in self.m)
|
182 |
+
return self.cv2(torch.cat(y, 1))
|
183 |
+
|
184 |
+
def forward_split(self, x):
|
185 |
+
"""Forward pass using split() instead of chunk()."""
|
186 |
+
y = list(self.cv1(x).split((self.c, self.c), 1))
|
187 |
+
y.extend(m(y[-1]) for m in self.m)
|
188 |
+
return self.cv2(torch.cat(y, 1))
|
189 |
+
|
190 |
+
|
191 |
+
class C3(nn.Module):
|
192 |
+
"""CSP Bottleneck with 3 convolutions."""
|
193 |
+
|
194 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
195 |
+
super().__init__()
|
196 |
+
c_ = int(c2 * e) # hidden channels
|
197 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
198 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
199 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
200 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
"""Forward pass through the CSP bottleneck with 2 convolutions."""
|
204 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
205 |
+
|
206 |
+
|
207 |
+
class C3x(C3):
|
208 |
+
"""C3 module with cross-convolutions."""
|
209 |
+
|
210 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
211 |
+
"""Initialize C3TR instance and set default parameters."""
|
212 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
213 |
+
self.c_ = int(c2 * e)
|
214 |
+
self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
|
215 |
+
|
216 |
+
|
217 |
+
class RepC3(nn.Module):
|
218 |
+
"""Rep C3."""
|
219 |
+
|
220 |
+
def __init__(self, c1, c2, n=3, e=1.0):
|
221 |
+
super().__init__()
|
222 |
+
c_ = int(c2 * e) # hidden channels
|
223 |
+
self.cv1 = Conv(c1, c2, 1, 1)
|
224 |
+
self.cv2 = Conv(c1, c2, 1, 1)
|
225 |
+
self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
|
226 |
+
self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
"""Forward pass of RT-DETR neck layer."""
|
230 |
+
return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
|
231 |
+
|
232 |
+
|
233 |
+
class C3TR(C3):
|
234 |
+
"""C3 module with TransformerBlock()."""
|
235 |
+
|
236 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
237 |
+
"""Initialize C3Ghost module with GhostBottleneck()."""
|
238 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
239 |
+
c_ = int(c2 * e)
|
240 |
+
self.m = TransformerBlock(c_, c_, 4, n)
|
241 |
+
|
242 |
+
|
243 |
+
class C3Ghost(C3):
|
244 |
+
"""C3 module with GhostBottleneck()."""
|
245 |
+
|
246 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
247 |
+
"""Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
|
248 |
+
super().__init__(c1, c2, n, shortcut, g, e)
|
249 |
+
c_ = int(c2 * e) # hidden channels
|
250 |
+
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
|
251 |
+
|
252 |
+
|
253 |
+
class GhostBottleneck(nn.Module):
|
254 |
+
"""Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
|
255 |
+
|
256 |
+
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
|
257 |
+
super().__init__()
|
258 |
+
c_ = c2 // 2
|
259 |
+
self.conv = nn.Sequential(
|
260 |
+
GhostConv(c1, c_, 1, 1), # pw
|
261 |
+
DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
|
262 |
+
GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
|
263 |
+
self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
|
264 |
+
act=False)) if s == 2 else nn.Identity()
|
265 |
+
|
266 |
+
def forward(self, x):
|
267 |
+
"""Applies skip connection and concatenation to input tensor."""
|
268 |
+
return self.conv(x) + self.shortcut(x)
|
269 |
+
|
270 |
+
|
271 |
+
class Bottleneck(nn.Module):
|
272 |
+
"""Standard bottleneck."""
|
273 |
+
|
274 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
|
275 |
+
super().__init__()
|
276 |
+
c_ = int(c2 * e) # hidden channels
|
277 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
278 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
279 |
+
self.add = shortcut and c1 == c2
|
280 |
+
|
281 |
+
def forward(self, x):
|
282 |
+
"""'forward()' applies the YOLOv5 FPN to input data."""
|
283 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
284 |
+
|
285 |
+
|
286 |
+
class BottleneckCSP(nn.Module):
|
287 |
+
"""CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
|
288 |
+
|
289 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
290 |
+
super().__init__()
|
291 |
+
c_ = int(c2 * e) # hidden channels
|
292 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
293 |
+
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
|
294 |
+
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
|
295 |
+
self.cv4 = Conv(2 * c_, c2, 1, 1)
|
296 |
+
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
|
297 |
+
self.act = nn.SiLU()
|
298 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
"""Applies a CSP bottleneck with 3 convolutions."""
|
302 |
+
y1 = self.cv3(self.m(self.cv1(x)))
|
303 |
+
y2 = self.cv2(x)
|
304 |
+
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
|
ultralytics/nn/modules/conv.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Convolution modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
13 |
+
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
|
14 |
+
|
15 |
+
|
16 |
+
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
17 |
+
"""Pad to 'same' shape outputs."""
|
18 |
+
if d > 1:
|
19 |
+
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
20 |
+
if p is None:
|
21 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
22 |
+
return p
|
23 |
+
|
24 |
+
|
25 |
+
class Conv(nn.Module):
|
26 |
+
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
|
27 |
+
default_act = nn.SiLU() # default activation
|
28 |
+
|
29 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
30 |
+
"""Initialize Conv layer with given arguments including activation."""
|
31 |
+
super().__init__()
|
32 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
33 |
+
self.bn = nn.BatchNorm2d(c2)
|
34 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
"""Apply convolution, batch normalization and activation to input tensor."""
|
38 |
+
return self.act(self.bn(self.conv(x)))
|
39 |
+
|
40 |
+
def forward_fuse(self, x):
|
41 |
+
"""Perform transposed convolution of 2D data."""
|
42 |
+
return self.act(self.conv(x))
|
43 |
+
|
44 |
+
|
45 |
+
class Conv2(Conv):
|
46 |
+
"""Simplified RepConv module with Conv fusing."""
|
47 |
+
|
48 |
+
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
|
49 |
+
"""Initialize Conv layer with given arguments including activation."""
|
50 |
+
super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
|
51 |
+
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
"""Apply convolution, batch normalization and activation to input tensor."""
|
55 |
+
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
56 |
+
|
57 |
+
def fuse_convs(self):
|
58 |
+
"""Fuse parallel convolutions."""
|
59 |
+
w = torch.zeros_like(self.conv.weight.data)
|
60 |
+
i = [x // 2 for x in w.shape[2:]]
|
61 |
+
w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
|
62 |
+
self.conv.weight.data += w
|
63 |
+
self.__delattr__('cv2')
|
64 |
+
|
65 |
+
|
66 |
+
class LightConv(nn.Module):
|
67 |
+
"""Light convolution with args(ch_in, ch_out, kernel).
|
68 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self, c1, c2, k=1, act=nn.ReLU()):
|
72 |
+
"""Initialize Conv layer with given arguments including activation."""
|
73 |
+
super().__init__()
|
74 |
+
self.conv1 = Conv(c1, c2, 1, act=False)
|
75 |
+
self.conv2 = DWConv(c2, c2, k, act=act)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
"""Apply 2 convolutions to input tensor."""
|
79 |
+
return self.conv2(self.conv1(x))
|
80 |
+
|
81 |
+
|
82 |
+
class DWConv(Conv):
|
83 |
+
"""Depth-wise convolution."""
|
84 |
+
|
85 |
+
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
|
86 |
+
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
87 |
+
|
88 |
+
|
89 |
+
class DWConvTranspose2d(nn.ConvTranspose2d):
|
90 |
+
"""Depth-wise transpose convolution."""
|
91 |
+
|
92 |
+
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
|
93 |
+
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
|
94 |
+
|
95 |
+
|
96 |
+
class ConvTranspose(nn.Module):
|
97 |
+
"""Convolution transpose 2d layer."""
|
98 |
+
default_act = nn.SiLU() # default activation
|
99 |
+
|
100 |
+
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
|
101 |
+
"""Initialize ConvTranspose2d layer with batch normalization and activation function."""
|
102 |
+
super().__init__()
|
103 |
+
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
|
104 |
+
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
|
105 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
"""Applies transposed convolutions, batch normalization and activation to input."""
|
109 |
+
return self.act(self.bn(self.conv_transpose(x)))
|
110 |
+
|
111 |
+
def forward_fuse(self, x):
|
112 |
+
"""Applies activation and convolution transpose operation to input."""
|
113 |
+
return self.act(self.conv_transpose(x))
|
114 |
+
|
115 |
+
|
116 |
+
class Focus(nn.Module):
|
117 |
+
"""Focus wh information into c-space."""
|
118 |
+
|
119 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
120 |
+
super().__init__()
|
121 |
+
self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
|
122 |
+
# self.contract = Contract(gain=2)
|
123 |
+
|
124 |
+
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
125 |
+
return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
|
126 |
+
# return self.conv(self.contract(x))
|
127 |
+
|
128 |
+
|
129 |
+
class GhostConv(nn.Module):
|
130 |
+
"""Ghost Convolution https://github.com/huawei-noah/ghostnet."""
|
131 |
+
|
132 |
+
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
|
133 |
+
super().__init__()
|
134 |
+
c_ = c2 // 2 # hidden channels
|
135 |
+
self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
|
136 |
+
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward propagation through a Ghost Bottleneck layer with skip connection."""
|
140 |
+
y = self.cv1(x)
|
141 |
+
return torch.cat((y, self.cv2(y)), 1)
|
142 |
+
|
143 |
+
|
144 |
+
class RepConv(nn.Module):
|
145 |
+
"""RepConv is a basic rep-style block, including training and deploy status
|
146 |
+
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
|
147 |
+
"""
|
148 |
+
default_act = nn.SiLU() # default activation
|
149 |
+
|
150 |
+
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
|
151 |
+
super().__init__()
|
152 |
+
assert k == 3 and p == 1
|
153 |
+
self.g = g
|
154 |
+
self.c1 = c1
|
155 |
+
self.c2 = c2
|
156 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
157 |
+
|
158 |
+
self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
|
159 |
+
self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
|
160 |
+
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
|
161 |
+
|
162 |
+
def forward_fuse(self, x):
|
163 |
+
"""Forward process"""
|
164 |
+
return self.act(self.conv(x))
|
165 |
+
|
166 |
+
def forward(self, x):
|
167 |
+
"""Forward process"""
|
168 |
+
id_out = 0 if self.bn is None else self.bn(x)
|
169 |
+
return self.act(self.conv1(x) + self.conv2(x) + id_out)
|
170 |
+
|
171 |
+
def get_equivalent_kernel_bias(self):
|
172 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
|
173 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
|
174 |
+
kernelid, biasid = self._fuse_bn_tensor(self.bn)
|
175 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
176 |
+
|
177 |
+
def _avg_to_3x3_tensor(self, avgp):
|
178 |
+
channels = self.c1
|
179 |
+
groups = self.g
|
180 |
+
kernel_size = avgp.kernel_size
|
181 |
+
input_dim = channels // groups
|
182 |
+
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
|
183 |
+
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
|
184 |
+
return k
|
185 |
+
|
186 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
187 |
+
if kernel1x1 is None:
|
188 |
+
return 0
|
189 |
+
else:
|
190 |
+
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
191 |
+
|
192 |
+
def _fuse_bn_tensor(self, branch):
|
193 |
+
if branch is None:
|
194 |
+
return 0, 0
|
195 |
+
if isinstance(branch, Conv):
|
196 |
+
kernel = branch.conv.weight
|
197 |
+
running_mean = branch.bn.running_mean
|
198 |
+
running_var = branch.bn.running_var
|
199 |
+
gamma = branch.bn.weight
|
200 |
+
beta = branch.bn.bias
|
201 |
+
eps = branch.bn.eps
|
202 |
+
elif isinstance(branch, nn.BatchNorm2d):
|
203 |
+
if not hasattr(self, 'id_tensor'):
|
204 |
+
input_dim = self.c1 // self.g
|
205 |
+
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
|
206 |
+
for i in range(self.c1):
|
207 |
+
kernel_value[i, i % input_dim, 1, 1] = 1
|
208 |
+
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
209 |
+
kernel = self.id_tensor
|
210 |
+
running_mean = branch.running_mean
|
211 |
+
running_var = branch.running_var
|
212 |
+
gamma = branch.weight
|
213 |
+
beta = branch.bias
|
214 |
+
eps = branch.eps
|
215 |
+
std = (running_var + eps).sqrt()
|
216 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
217 |
+
return kernel * t, beta - running_mean * gamma / std
|
218 |
+
|
219 |
+
def fuse_convs(self):
|
220 |
+
if hasattr(self, 'conv'):
|
221 |
+
return
|
222 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
223 |
+
self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
|
224 |
+
out_channels=self.conv1.conv.out_channels,
|
225 |
+
kernel_size=self.conv1.conv.kernel_size,
|
226 |
+
stride=self.conv1.conv.stride,
|
227 |
+
padding=self.conv1.conv.padding,
|
228 |
+
dilation=self.conv1.conv.dilation,
|
229 |
+
groups=self.conv1.conv.groups,
|
230 |
+
bias=True).requires_grad_(False)
|
231 |
+
self.conv.weight.data = kernel
|
232 |
+
self.conv.bias.data = bias
|
233 |
+
for para in self.parameters():
|
234 |
+
para.detach_()
|
235 |
+
self.__delattr__('conv1')
|
236 |
+
self.__delattr__('conv2')
|
237 |
+
if hasattr(self, 'nm'):
|
238 |
+
self.__delattr__('nm')
|
239 |
+
if hasattr(self, 'bn'):
|
240 |
+
self.__delattr__('bn')
|
241 |
+
if hasattr(self, 'id_tensor'):
|
242 |
+
self.__delattr__('id_tensor')
|
243 |
+
|
244 |
+
|
245 |
+
class ChannelAttention(nn.Module):
|
246 |
+
"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
|
247 |
+
|
248 |
+
def __init__(self, channels: int) -> None:
|
249 |
+
super().__init__()
|
250 |
+
self.pool = nn.AdaptiveAvgPool2d(1)
|
251 |
+
self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
|
252 |
+
self.act = nn.Sigmoid()
|
253 |
+
|
254 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
255 |
+
return x * self.act(self.fc(self.pool(x)))
|
256 |
+
|
257 |
+
|
258 |
+
class SpatialAttention(nn.Module):
|
259 |
+
"""Spatial-attention module."""
|
260 |
+
|
261 |
+
def __init__(self, kernel_size=7):
|
262 |
+
"""Initialize Spatial-attention module with kernel size argument."""
|
263 |
+
super().__init__()
|
264 |
+
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
265 |
+
padding = 3 if kernel_size == 7 else 1
|
266 |
+
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
267 |
+
self.act = nn.Sigmoid()
|
268 |
+
|
269 |
+
def forward(self, x):
|
270 |
+
"""Apply channel and spatial attention on input for feature recalibration."""
|
271 |
+
return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
|
272 |
+
|
273 |
+
|
274 |
+
class CBAM(nn.Module):
|
275 |
+
"""Convolutional Block Attention Module."""
|
276 |
+
|
277 |
+
def __init__(self, c1, kernel_size=7): # ch_in, kernels
|
278 |
+
super().__init__()
|
279 |
+
self.channel_attention = ChannelAttention(c1)
|
280 |
+
self.spatial_attention = SpatialAttention(kernel_size)
|
281 |
+
|
282 |
+
def forward(self, x):
|
283 |
+
"""Applies the forward pass through C1 module."""
|
284 |
+
return self.spatial_attention(self.channel_attention(x))
|
285 |
+
|
286 |
+
|
287 |
+
class Concat(nn.Module):
|
288 |
+
"""Concatenate a list of tensors along dimension."""
|
289 |
+
|
290 |
+
def __init__(self, dimension=1):
|
291 |
+
"""Concatenates a list of tensors along a specified dimension."""
|
292 |
+
super().__init__()
|
293 |
+
self.d = dimension
|
294 |
+
|
295 |
+
def forward(self, x):
|
296 |
+
"""Forward pass for the YOLOv8 mask Proto module."""
|
297 |
+
return torch.cat(x, self.d)
|
ultralytics/nn/modules/head.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Model head modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn.init import constant_, xavier_uniform_
|
11 |
+
|
12 |
+
from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
|
13 |
+
|
14 |
+
from .block import DFL, Proto
|
15 |
+
from .conv import Conv
|
16 |
+
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
|
17 |
+
from .utils import bias_init_with_prob, linear_init_
|
18 |
+
|
19 |
+
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
|
20 |
+
|
21 |
+
|
22 |
+
class Detect(nn.Module):
|
23 |
+
"""YOLOv8 Detect head for detection models."""
|
24 |
+
dynamic = False # force grid reconstruction
|
25 |
+
export = False # export mode
|
26 |
+
shape = None
|
27 |
+
anchors = torch.empty(0) # init
|
28 |
+
strides = torch.empty(0) # init
|
29 |
+
|
30 |
+
def __init__(self, nc=80, ch=()): # detection layer
|
31 |
+
super().__init__()
|
32 |
+
self.nc = nc # number of classes
|
33 |
+
self.nl = len(ch) # number of detection layers
|
34 |
+
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
|
35 |
+
self.no = nc + self.reg_max * 4 # number of outputs per anchor
|
36 |
+
self.stride = torch.zeros(self.nl) # strides computed during build
|
37 |
+
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
|
38 |
+
self.cv2 = nn.ModuleList(
|
39 |
+
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
|
40 |
+
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
41 |
+
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
"""Concatenates and returns predicted bounding boxes and class probabilities."""
|
45 |
+
shape = x[0].shape # BCHW
|
46 |
+
for i in range(self.nl):
|
47 |
+
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
|
48 |
+
if self.training:
|
49 |
+
return x
|
50 |
+
elif self.dynamic or self.shape != shape:
|
51 |
+
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
52 |
+
self.shape = shape
|
53 |
+
|
54 |
+
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
|
55 |
+
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
|
56 |
+
box = x_cat[:, :self.reg_max * 4]
|
57 |
+
cls = x_cat[:, self.reg_max * 4:]
|
58 |
+
else:
|
59 |
+
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
60 |
+
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
61 |
+
y = torch.cat((dbox, cls.sigmoid()), 1)
|
62 |
+
return y if self.export else (y, x)
|
63 |
+
|
64 |
+
def bias_init(self):
|
65 |
+
"""Initialize Detect() biases, WARNING: requires stride availability."""
|
66 |
+
m = self # self.model[-1] # Detect() module
|
67 |
+
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
68 |
+
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
69 |
+
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
|
70 |
+
a[-1].bias.data[:] = 1.0 # box
|
71 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
|
72 |
+
|
73 |
+
|
74 |
+
class Segment(Detect):
|
75 |
+
"""YOLOv8 Segment head for segmentation models."""
|
76 |
+
|
77 |
+
def __init__(self, nc=80, nm=32, npr=256, ch=()):
|
78 |
+
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
|
79 |
+
super().__init__(nc, ch)
|
80 |
+
self.nm = nm # number of masks
|
81 |
+
self.npr = npr # number of protos
|
82 |
+
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
83 |
+
self.detect = Detect.forward
|
84 |
+
|
85 |
+
c4 = max(ch[0] // 4, self.nm)
|
86 |
+
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
"""Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
|
90 |
+
p = self.proto(x[0]) # mask protos
|
91 |
+
bs = p.shape[0] # batch size
|
92 |
+
|
93 |
+
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
94 |
+
x = self.detect(self, x)
|
95 |
+
if self.training:
|
96 |
+
return x, mc, p
|
97 |
+
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
98 |
+
|
99 |
+
|
100 |
+
class Pose(Detect):
|
101 |
+
"""YOLOv8 Pose head for keypoints models."""
|
102 |
+
|
103 |
+
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
|
104 |
+
"""Initialize YOLO network with default parameters and Convolutional Layers."""
|
105 |
+
super().__init__(nc, ch)
|
106 |
+
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
107 |
+
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
|
108 |
+
self.detect = Detect.forward
|
109 |
+
|
110 |
+
c4 = max(ch[0] // 4, self.nk)
|
111 |
+
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""Perform forward pass through YOLO model and return predictions."""
|
115 |
+
bs = x[0].shape[0] # batch size
|
116 |
+
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
117 |
+
x = self.detect(self, x)
|
118 |
+
if self.training:
|
119 |
+
return x, kpt
|
120 |
+
pred_kpt = self.kpts_decode(bs, kpt)
|
121 |
+
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
|
122 |
+
|
123 |
+
def kpts_decode(self, bs, kpts):
|
124 |
+
"""Decodes keypoints."""
|
125 |
+
ndim = self.kpt_shape[1]
|
126 |
+
if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
|
127 |
+
y = kpts.view(bs, *self.kpt_shape, -1)
|
128 |
+
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
|
129 |
+
if ndim == 3:
|
130 |
+
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
|
131 |
+
return a.view(bs, self.nk, -1)
|
132 |
+
else:
|
133 |
+
y = kpts.clone()
|
134 |
+
if ndim == 3:
|
135 |
+
y[:, 2::3].sigmoid_() # inplace sigmoid
|
136 |
+
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
|
137 |
+
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
|
138 |
+
return y
|
139 |
+
|
140 |
+
|
141 |
+
class Classify(nn.Module):
|
142 |
+
"""YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
143 |
+
|
144 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
145 |
+
super().__init__()
|
146 |
+
c_ = 1280 # efficientnet_b0 size
|
147 |
+
self.conv = Conv(c1, c_, k, s, p, g)
|
148 |
+
self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
|
149 |
+
self.drop = nn.Dropout(p=0.0, inplace=True)
|
150 |
+
self.linear = nn.Linear(c_, c2) # to x(b,c2)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
"""Performs a forward pass of the YOLO model on input image data."""
|
154 |
+
if isinstance(x, list):
|
155 |
+
x = torch.cat(x, 1)
|
156 |
+
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
157 |
+
return x if self.training else x.softmax(1)
|
158 |
+
|
159 |
+
|
160 |
+
class RTDETRDecoder(nn.Module):
|
161 |
+
|
162 |
+
def __init__(
|
163 |
+
self,
|
164 |
+
nc=80,
|
165 |
+
ch=(512, 1024, 2048),
|
166 |
+
hd=256, # hidden dim
|
167 |
+
nq=300, # num queries
|
168 |
+
ndp=4, # num decoder points
|
169 |
+
nh=8, # num head
|
170 |
+
ndl=6, # num decoder layers
|
171 |
+
d_ffn=1024, # dim of feedforward
|
172 |
+
dropout=0.,
|
173 |
+
act=nn.ReLU(),
|
174 |
+
eval_idx=-1,
|
175 |
+
# training args
|
176 |
+
nd=100, # num denoising
|
177 |
+
label_noise_ratio=0.5,
|
178 |
+
box_noise_scale=1.0,
|
179 |
+
learnt_init_query=False):
|
180 |
+
super().__init__()
|
181 |
+
self.hidden_dim = hd
|
182 |
+
self.nhead = nh
|
183 |
+
self.nl = len(ch) # num level
|
184 |
+
self.nc = nc
|
185 |
+
self.num_queries = nq
|
186 |
+
self.num_decoder_layers = ndl
|
187 |
+
|
188 |
+
# backbone feature projection
|
189 |
+
self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
|
190 |
+
# NOTE: simplified version but it's not consistent with .pt weights.
|
191 |
+
# self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
|
192 |
+
|
193 |
+
# Transformer module
|
194 |
+
decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
|
195 |
+
self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
|
196 |
+
|
197 |
+
# denoising part
|
198 |
+
self.denoising_class_embed = nn.Embedding(nc, hd)
|
199 |
+
self.num_denoising = nd
|
200 |
+
self.label_noise_ratio = label_noise_ratio
|
201 |
+
self.box_noise_scale = box_noise_scale
|
202 |
+
|
203 |
+
# decoder embedding
|
204 |
+
self.learnt_init_query = learnt_init_query
|
205 |
+
if learnt_init_query:
|
206 |
+
self.tgt_embed = nn.Embedding(nq, hd)
|
207 |
+
self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
|
208 |
+
|
209 |
+
# encoder head
|
210 |
+
self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
|
211 |
+
self.enc_score_head = nn.Linear(hd, nc)
|
212 |
+
self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
|
213 |
+
|
214 |
+
# decoder head
|
215 |
+
self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
|
216 |
+
self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
|
217 |
+
|
218 |
+
self._reset_parameters()
|
219 |
+
|
220 |
+
def forward(self, x, batch=None):
|
221 |
+
from ultralytics.vit.utils.ops import get_cdn_group
|
222 |
+
|
223 |
+
# input projection and embedding
|
224 |
+
feats, shapes = self._get_encoder_input(x)
|
225 |
+
|
226 |
+
# prepare denoising training
|
227 |
+
dn_embed, dn_bbox, attn_mask, dn_meta = \
|
228 |
+
get_cdn_group(batch,
|
229 |
+
self.nc,
|
230 |
+
self.num_queries,
|
231 |
+
self.denoising_class_embed.weight,
|
232 |
+
self.num_denoising,
|
233 |
+
self.label_noise_ratio,
|
234 |
+
self.box_noise_scale,
|
235 |
+
self.training)
|
236 |
+
|
237 |
+
embed, refer_bbox, enc_bboxes, enc_scores = \
|
238 |
+
self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
|
239 |
+
|
240 |
+
# decoder
|
241 |
+
dec_bboxes, dec_scores = self.decoder(embed,
|
242 |
+
refer_bbox,
|
243 |
+
feats,
|
244 |
+
shapes,
|
245 |
+
self.dec_bbox_head,
|
246 |
+
self.dec_score_head,
|
247 |
+
self.query_pos_head,
|
248 |
+
attn_mask=attn_mask)
|
249 |
+
if not self.training:
|
250 |
+
dec_scores = dec_scores.sigmoid_()
|
251 |
+
return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
252 |
+
|
253 |
+
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
254 |
+
anchors = []
|
255 |
+
for i, (h, w) in enumerate(shapes):
|
256 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
|
257 |
+
torch.arange(end=w, dtype=dtype, device=device),
|
258 |
+
indexing='ij')
|
259 |
+
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
|
260 |
+
|
261 |
+
valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
|
262 |
+
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
|
263 |
+
wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
|
264 |
+
anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
|
265 |
+
|
266 |
+
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
|
267 |
+
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
|
268 |
+
anchors = torch.log(anchors / (1 - anchors))
|
269 |
+
anchors = torch.where(valid_mask, anchors, torch.inf)
|
270 |
+
return anchors, valid_mask
|
271 |
+
|
272 |
+
def _get_encoder_input(self, x):
|
273 |
+
# get projection features
|
274 |
+
x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
|
275 |
+
# get encoder inputs
|
276 |
+
feats = []
|
277 |
+
shapes = []
|
278 |
+
for feat in x:
|
279 |
+
h, w = feat.shape[2:]
|
280 |
+
# [b, c, h, w] -> [b, h*w, c]
|
281 |
+
feats.append(feat.flatten(2).permute(0, 2, 1))
|
282 |
+
# [nl, 2]
|
283 |
+
shapes.append([h, w])
|
284 |
+
|
285 |
+
# [b, h*w, c]
|
286 |
+
feats = torch.cat(feats, 1)
|
287 |
+
return feats, shapes
|
288 |
+
|
289 |
+
def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
|
290 |
+
bs = len(feats)
|
291 |
+
# prepare input for decoder
|
292 |
+
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
|
293 |
+
features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256
|
294 |
+
|
295 |
+
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
|
296 |
+
# dynamic anchors + static content
|
297 |
+
enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
|
298 |
+
|
299 |
+
# query selection
|
300 |
+
# (bs, num_queries)
|
301 |
+
topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
|
302 |
+
# (bs, num_queries)
|
303 |
+
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
|
304 |
+
|
305 |
+
# Unsigmoided
|
306 |
+
refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
307 |
+
# refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
|
308 |
+
|
309 |
+
enc_bboxes = refer_bbox.sigmoid()
|
310 |
+
if dn_bbox is not None:
|
311 |
+
refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
|
312 |
+
if self.training:
|
313 |
+
refer_bbox = refer_bbox.detach()
|
314 |
+
enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
315 |
+
|
316 |
+
if self.learnt_init_query:
|
317 |
+
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
|
318 |
+
else:
|
319 |
+
embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
320 |
+
if self.training:
|
321 |
+
embeddings = embeddings.detach()
|
322 |
+
if dn_embed is not None:
|
323 |
+
embeddings = torch.cat([dn_embed, embeddings], 1)
|
324 |
+
|
325 |
+
return embeddings, refer_bbox, enc_bboxes, enc_scores
|
326 |
+
|
327 |
+
# TODO
|
328 |
+
def _reset_parameters(self):
|
329 |
+
# class and bbox head init
|
330 |
+
bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
|
331 |
+
# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
|
332 |
+
# linear_init_(self.enc_score_head)
|
333 |
+
constant_(self.enc_score_head.bias, bias_cls)
|
334 |
+
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
|
335 |
+
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
|
336 |
+
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
|
337 |
+
# linear_init_(cls_)
|
338 |
+
constant_(cls_.bias, bias_cls)
|
339 |
+
constant_(reg_.layers[-1].weight, 0.)
|
340 |
+
constant_(reg_.layers[-1].bias, 0.)
|
341 |
+
|
342 |
+
linear_init_(self.enc_output[0])
|
343 |
+
xavier_uniform_(self.enc_output[0].weight)
|
344 |
+
if self.learnt_init_query:
|
345 |
+
xavier_uniform_(self.tgt_embed.weight)
|
346 |
+
xavier_uniform_(self.query_pos_head.layers[0].weight)
|
347 |
+
xavier_uniform_(self.query_pos_head.layers[1].weight)
|
348 |
+
for layer in self.input_proj:
|
349 |
+
xavier_uniform_(layer[0].weight)
|
ultralytics/nn/modules/transformer.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Transformer modules
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.nn.init import constant_, xavier_uniform_
|
12 |
+
|
13 |
+
from .conv import Conv
|
14 |
+
from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
|
15 |
+
|
16 |
+
__all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI',
|
17 |
+
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
|
18 |
+
|
19 |
+
|
20 |
+
class TransformerEncoderLayer(nn.Module):
|
21 |
+
"""Transformer Encoder."""
|
22 |
+
|
23 |
+
def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
|
24 |
+
super().__init__()
|
25 |
+
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
|
26 |
+
# Implementation of Feedforward model
|
27 |
+
self.fc1 = nn.Linear(c1, cm)
|
28 |
+
self.fc2 = nn.Linear(cm, c1)
|
29 |
+
|
30 |
+
self.norm1 = nn.LayerNorm(c1)
|
31 |
+
self.norm2 = nn.LayerNorm(c1)
|
32 |
+
self.dropout = nn.Dropout(dropout)
|
33 |
+
self.dropout1 = nn.Dropout(dropout)
|
34 |
+
self.dropout2 = nn.Dropout(dropout)
|
35 |
+
|
36 |
+
self.act = act
|
37 |
+
self.normalize_before = normalize_before
|
38 |
+
|
39 |
+
def with_pos_embed(self, tensor, pos=None):
|
40 |
+
"""Add position embeddings if given."""
|
41 |
+
return tensor if pos is None else tensor + pos
|
42 |
+
|
43 |
+
def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
44 |
+
q = k = self.with_pos_embed(src, pos)
|
45 |
+
src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
46 |
+
src = src + self.dropout1(src2)
|
47 |
+
src = self.norm1(src)
|
48 |
+
src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
|
49 |
+
src = src + self.dropout2(src2)
|
50 |
+
src = self.norm2(src)
|
51 |
+
return src
|
52 |
+
|
53 |
+
def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
54 |
+
src2 = self.norm1(src)
|
55 |
+
q = k = self.with_pos_embed(src2, pos)
|
56 |
+
src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
57 |
+
src = src + self.dropout1(src2)
|
58 |
+
src2 = self.norm2(src)
|
59 |
+
src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
|
60 |
+
src = src + self.dropout2(src2)
|
61 |
+
return src
|
62 |
+
|
63 |
+
def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
|
64 |
+
"""Forward propagates the input through the encoder module."""
|
65 |
+
if self.normalize_before:
|
66 |
+
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
67 |
+
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
68 |
+
|
69 |
+
|
70 |
+
class AIFI(TransformerEncoderLayer):
|
71 |
+
|
72 |
+
def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
|
73 |
+
super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
c, h, w = x.shape[1:]
|
77 |
+
pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
|
78 |
+
# flatten [B, C, H, W] to [B, HxW, C]
|
79 |
+
x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
|
80 |
+
return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
|
84 |
+
grid_w = torch.arange(int(w), dtype=torch.float32)
|
85 |
+
grid_h = torch.arange(int(h), dtype=torch.float32)
|
86 |
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
|
87 |
+
assert embed_dim % 4 == 0, \
|
88 |
+
'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
89 |
+
pos_dim = embed_dim // 4
|
90 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
91 |
+
omega = 1. / (temperature ** omega)
|
92 |
+
|
93 |
+
out_w = grid_w.flatten()[..., None] @ omega[None]
|
94 |
+
out_h = grid_h.flatten()[..., None] @ omega[None]
|
95 |
+
|
96 |
+
return torch.concat([torch.sin(out_w), torch.cos(out_w),
|
97 |
+
torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
|
98 |
+
|
99 |
+
|
100 |
+
class TransformerLayer(nn.Module):
|
101 |
+
"""Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
|
102 |
+
|
103 |
+
def __init__(self, c, num_heads):
|
104 |
+
"""Initializes a self-attention mechanism using linear transformations and multi-head attention."""
|
105 |
+
super().__init__()
|
106 |
+
self.q = nn.Linear(c, c, bias=False)
|
107 |
+
self.k = nn.Linear(c, c, bias=False)
|
108 |
+
self.v = nn.Linear(c, c, bias=False)
|
109 |
+
self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
|
110 |
+
self.fc1 = nn.Linear(c, c, bias=False)
|
111 |
+
self.fc2 = nn.Linear(c, c, bias=False)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
"""Apply a transformer block to the input x and return the output."""
|
115 |
+
x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
|
116 |
+
x = self.fc2(self.fc1(x)) + x
|
117 |
+
return x
|
118 |
+
|
119 |
+
|
120 |
+
class TransformerBlock(nn.Module):
|
121 |
+
"""Vision Transformer https://arxiv.org/abs/2010.11929."""
|
122 |
+
|
123 |
+
def __init__(self, c1, c2, num_heads, num_layers):
|
124 |
+
"""Initialize a Transformer module with position embedding and specified number of heads and layers."""
|
125 |
+
super().__init__()
|
126 |
+
self.conv = None
|
127 |
+
if c1 != c2:
|
128 |
+
self.conv = Conv(c1, c2)
|
129 |
+
self.linear = nn.Linear(c2, c2) # learnable position embedding
|
130 |
+
self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
|
131 |
+
self.c2 = c2
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
"""Forward propagates the input through the bottleneck module."""
|
135 |
+
if self.conv is not None:
|
136 |
+
x = self.conv(x)
|
137 |
+
b, _, w, h = x.shape
|
138 |
+
p = x.flatten(2).permute(2, 0, 1)
|
139 |
+
return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
|
140 |
+
|
141 |
+
|
142 |
+
class MLPBlock(nn.Module):
|
143 |
+
|
144 |
+
def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
|
145 |
+
super().__init__()
|
146 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
147 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
148 |
+
self.act = act()
|
149 |
+
|
150 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
151 |
+
return self.lin2(self.act(self.lin1(x)))
|
152 |
+
|
153 |
+
|
154 |
+
class MLP(nn.Module):
|
155 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
156 |
+
|
157 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
158 |
+
super().__init__()
|
159 |
+
self.num_layers = num_layers
|
160 |
+
h = [hidden_dim] * (num_layers - 1)
|
161 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
for i, layer in enumerate(self.layers):
|
165 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
166 |
+
return x
|
167 |
+
|
168 |
+
|
169 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
170 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
171 |
+
class LayerNorm2d(nn.Module):
|
172 |
+
|
173 |
+
def __init__(self, num_channels, eps=1e-6):
|
174 |
+
super().__init__()
|
175 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
176 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
177 |
+
self.eps = eps
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
u = x.mean(1, keepdim=True)
|
181 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
182 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
183 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
184 |
+
return x
|
185 |
+
|
186 |
+
|
187 |
+
class MSDeformAttn(nn.Module):
|
188 |
+
"""
|
189 |
+
Original Multi-Scale Deformable Attention Module.
|
190 |
+
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
|
194 |
+
super().__init__()
|
195 |
+
if d_model % n_heads != 0:
|
196 |
+
raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
|
197 |
+
_d_per_head = d_model // n_heads
|
198 |
+
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
|
199 |
+
assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`'
|
200 |
+
|
201 |
+
self.im2col_step = 64
|
202 |
+
|
203 |
+
self.d_model = d_model
|
204 |
+
self.n_levels = n_levels
|
205 |
+
self.n_heads = n_heads
|
206 |
+
self.n_points = n_points
|
207 |
+
|
208 |
+
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
|
209 |
+
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
|
210 |
+
self.value_proj = nn.Linear(d_model, d_model)
|
211 |
+
self.output_proj = nn.Linear(d_model, d_model)
|
212 |
+
|
213 |
+
self._reset_parameters()
|
214 |
+
|
215 |
+
def _reset_parameters(self):
|
216 |
+
constant_(self.sampling_offsets.weight.data, 0.)
|
217 |
+
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
|
218 |
+
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
219 |
+
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(
|
220 |
+
1, self.n_levels, self.n_points, 1)
|
221 |
+
for i in range(self.n_points):
|
222 |
+
grid_init[:, :, i, :] *= i + 1
|
223 |
+
with torch.no_grad():
|
224 |
+
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
|
225 |
+
constant_(self.attention_weights.weight.data, 0.)
|
226 |
+
constant_(self.attention_weights.bias.data, 0.)
|
227 |
+
xavier_uniform_(self.value_proj.weight.data)
|
228 |
+
constant_(self.value_proj.bias.data, 0.)
|
229 |
+
xavier_uniform_(self.output_proj.weight.data)
|
230 |
+
constant_(self.output_proj.bias.data, 0.)
|
231 |
+
|
232 |
+
def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
|
233 |
+
"""
|
234 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
235 |
+
Args:
|
236 |
+
query (torch.Tensor): [bs, query_length, C]
|
237 |
+
refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
|
238 |
+
bottom-right (1, 1), including padding area
|
239 |
+
value (torch.Tensor): [bs, value_length, C]
|
240 |
+
value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
241 |
+
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
output (Tensor): [bs, Length_{query}, C]
|
245 |
+
"""
|
246 |
+
bs, len_q = query.shape[:2]
|
247 |
+
len_v = value.shape[1]
|
248 |
+
assert sum(s[0] * s[1] for s in value_shapes) == len_v
|
249 |
+
|
250 |
+
value = self.value_proj(value)
|
251 |
+
if value_mask is not None:
|
252 |
+
value = value.masked_fill(value_mask[..., None], float(0))
|
253 |
+
value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
|
254 |
+
sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
|
255 |
+
attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
|
256 |
+
attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
|
257 |
+
# N, Len_q, n_heads, n_levels, n_points, 2
|
258 |
+
num_points = refer_bbox.shape[-1]
|
259 |
+
if num_points == 2:
|
260 |
+
offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
|
261 |
+
add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
262 |
+
sampling_locations = refer_bbox[:, :, None, :, None, :] + add
|
263 |
+
elif num_points == 4:
|
264 |
+
add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
|
265 |
+
sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
|
266 |
+
else:
|
267 |
+
raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
|
268 |
+
output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
|
269 |
+
output = self.output_proj(output)
|
270 |
+
return output
|
271 |
+
|
272 |
+
|
273 |
+
class DeformableTransformerDecoderLayer(nn.Module):
|
274 |
+
"""
|
275 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
276 |
+
https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
|
277 |
+
"""
|
278 |
+
|
279 |
+
def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4):
|
280 |
+
super().__init__()
|
281 |
+
|
282 |
+
# self attention
|
283 |
+
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
284 |
+
self.dropout1 = nn.Dropout(dropout)
|
285 |
+
self.norm1 = nn.LayerNorm(d_model)
|
286 |
+
|
287 |
+
# cross attention
|
288 |
+
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
289 |
+
self.dropout2 = nn.Dropout(dropout)
|
290 |
+
self.norm2 = nn.LayerNorm(d_model)
|
291 |
+
|
292 |
+
# ffn
|
293 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
294 |
+
self.act = act
|
295 |
+
self.dropout3 = nn.Dropout(dropout)
|
296 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
297 |
+
self.dropout4 = nn.Dropout(dropout)
|
298 |
+
self.norm3 = nn.LayerNorm(d_model)
|
299 |
+
|
300 |
+
@staticmethod
|
301 |
+
def with_pos_embed(tensor, pos):
|
302 |
+
return tensor if pos is None else tensor + pos
|
303 |
+
|
304 |
+
def forward_ffn(self, tgt):
|
305 |
+
tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
|
306 |
+
tgt = tgt + self.dropout4(tgt2)
|
307 |
+
tgt = self.norm3(tgt)
|
308 |
+
return tgt
|
309 |
+
|
310 |
+
def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
|
311 |
+
# self attention
|
312 |
+
q = k = self.with_pos_embed(embed, query_pos)
|
313 |
+
tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
|
314 |
+
attn_mask=attn_mask)[0].transpose(0, 1)
|
315 |
+
embed = embed + self.dropout1(tgt)
|
316 |
+
embed = self.norm1(embed)
|
317 |
+
|
318 |
+
# cross attention
|
319 |
+
tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
|
320 |
+
padding_mask)
|
321 |
+
embed = embed + self.dropout2(tgt)
|
322 |
+
embed = self.norm2(embed)
|
323 |
+
|
324 |
+
# ffn
|
325 |
+
embed = self.forward_ffn(embed)
|
326 |
+
|
327 |
+
return embed
|
328 |
+
|
329 |
+
|
330 |
+
class DeformableTransformerDecoder(nn.Module):
|
331 |
+
"""
|
332 |
+
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
|
336 |
+
super().__init__()
|
337 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
338 |
+
self.num_layers = num_layers
|
339 |
+
self.hidden_dim = hidden_dim
|
340 |
+
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
341 |
+
|
342 |
+
def forward(
|
343 |
+
self,
|
344 |
+
embed, # decoder embeddings
|
345 |
+
refer_bbox, # anchor
|
346 |
+
feats, # image features
|
347 |
+
shapes, # feature shapes
|
348 |
+
bbox_head,
|
349 |
+
score_head,
|
350 |
+
pos_mlp,
|
351 |
+
attn_mask=None,
|
352 |
+
padding_mask=None):
|
353 |
+
output = embed
|
354 |
+
dec_bboxes = []
|
355 |
+
dec_cls = []
|
356 |
+
last_refined_bbox = None
|
357 |
+
refer_bbox = refer_bbox.sigmoid()
|
358 |
+
for i, layer in enumerate(self.layers):
|
359 |
+
output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
|
360 |
+
|
361 |
+
# refine bboxes, (bs, num_queries+num_denoising, 4)
|
362 |
+
refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
|
363 |
+
|
364 |
+
if self.training:
|
365 |
+
dec_cls.append(score_head[i](output))
|
366 |
+
if i == 0:
|
367 |
+
dec_bboxes.append(refined_bbox)
|
368 |
+
else:
|
369 |
+
dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
|
370 |
+
elif i == self.eval_idx:
|
371 |
+
dec_cls.append(score_head[i](output))
|
372 |
+
dec_bboxes.append(refined_bbox)
|
373 |
+
break
|
374 |
+
|
375 |
+
last_refined_bbox = refined_bbox
|
376 |
+
refer_bbox = refined_bbox.detach() if self.training else refined_bbox
|
377 |
+
|
378 |
+
return torch.stack(dec_bboxes), torch.stack(dec_cls)
|
ultralytics/nn/modules/utils.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
"""
|
3 |
+
Module utils
|
4 |
+
"""
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import math
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn.init import uniform_
|
14 |
+
|
15 |
+
__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
|
16 |
+
|
17 |
+
|
18 |
+
def _get_clones(module, n):
|
19 |
+
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
20 |
+
|
21 |
+
|
22 |
+
def bias_init_with_prob(prior_prob=0.01):
|
23 |
+
"""initialize conv/fc bias value according to a given probability value."""
|
24 |
+
return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init
|
25 |
+
|
26 |
+
|
27 |
+
def linear_init_(module):
|
28 |
+
bound = 1 / math.sqrt(module.weight.shape[0])
|
29 |
+
uniform_(module.weight, -bound, bound)
|
30 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
31 |
+
uniform_(module.bias, -bound, bound)
|
32 |
+
|
33 |
+
|
34 |
+
def inverse_sigmoid(x, eps=1e-5):
|
35 |
+
x = x.clamp(min=0, max=1)
|
36 |
+
x1 = x.clamp(min=eps)
|
37 |
+
x2 = (1 - x).clamp(min=eps)
|
38 |
+
return torch.log(x1 / x2)
|
39 |
+
|
40 |
+
|
41 |
+
def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
|
42 |
+
sampling_locations: torch.Tensor,
|
43 |
+
attention_weights: torch.Tensor) -> torch.Tensor:
|
44 |
+
"""
|
45 |
+
Multi-scale deformable attention.
|
46 |
+
https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
|
47 |
+
"""
|
48 |
+
|
49 |
+
bs, _, num_heads, embed_dims = value.shape
|
50 |
+
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
51 |
+
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
52 |
+
sampling_grids = 2 * sampling_locations - 1
|
53 |
+
sampling_value_list = []
|
54 |
+
for level, (H_, W_) in enumerate(value_spatial_shapes):
|
55 |
+
# bs, H_*W_, num_heads, embed_dims ->
|
56 |
+
# bs, H_*W_, num_heads*embed_dims ->
|
57 |
+
# bs, num_heads*embed_dims, H_*W_ ->
|
58 |
+
# bs*num_heads, embed_dims, H_, W_
|
59 |
+
value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
|
60 |
+
# bs, num_queries, num_heads, num_points, 2 ->
|
61 |
+
# bs, num_heads, num_queries, num_points, 2 ->
|
62 |
+
# bs*num_heads, num_queries, num_points, 2
|
63 |
+
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
|
64 |
+
# bs*num_heads, embed_dims, num_queries, num_points
|
65 |
+
sampling_value_l_ = F.grid_sample(value_l_,
|
66 |
+
sampling_grid_l_,
|
67 |
+
mode='bilinear',
|
68 |
+
padding_mode='zeros',
|
69 |
+
align_corners=False)
|
70 |
+
sampling_value_list.append(sampling_value_l_)
|
71 |
+
# (bs, num_queries, num_heads, num_levels, num_points) ->
|
72 |
+
# (bs, num_heads, num_queries, num_levels, num_points) ->
|
73 |
+
# (bs, num_heads, 1, num_queries, num_levels*num_points)
|
74 |
+
attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
|
75 |
+
num_levels * num_points)
|
76 |
+
output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
|
77 |
+
bs, num_heads * embed_dims, num_queries))
|
78 |
+
return output.transpose(1, 2).contiguous()
|
ultralytics/nn/tasks.py
ADDED
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
2 |
+
|
3 |
+
import contextlib
|
4 |
+
from copy import deepcopy
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
|
11 |
+
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
12 |
+
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
13 |
+
RTDETRDecoder, Segment)
|
14 |
+
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
15 |
+
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
16 |
+
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
17 |
+
from ultralytics.yolo.utils.plotting import feature_visualization
|
18 |
+
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
19 |
+
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
20 |
+
|
21 |
+
try:
|
22 |
+
import thop
|
23 |
+
except ImportError:
|
24 |
+
thop = None
|
25 |
+
|
26 |
+
|
27 |
+
class BaseModel(nn.Module):
|
28 |
+
"""
|
29 |
+
The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def forward(self, x, *args, **kwargs):
|
33 |
+
"""
|
34 |
+
Forward pass of the model on a single scale.
|
35 |
+
Wrapper for `_forward_once` method.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
(torch.Tensor): The output of the network.
|
42 |
+
"""
|
43 |
+
if isinstance(x, dict): # for cases of training and validating while training.
|
44 |
+
return self.loss(x, *args, **kwargs)
|
45 |
+
return self.predict(x, *args, **kwargs)
|
46 |
+
|
47 |
+
def predict(self, x, profile=False, visualize=False, augment=False):
|
48 |
+
"""
|
49 |
+
Perform a forward pass through the network.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): The input tensor to the model.
|
53 |
+
profile (bool): Print the computation time of each layer if True, defaults to False.
|
54 |
+
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
55 |
+
augment (bool): Augment image during prediction, defaults to False.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
(torch.Tensor): The last output of the model.
|
59 |
+
"""
|
60 |
+
if augment:
|
61 |
+
return self._predict_augment(x)
|
62 |
+
return self._predict_once(x, profile, visualize)
|
63 |
+
|
64 |
+
def _predict_once(self, x, profile=False, visualize=False):
|
65 |
+
"""
|
66 |
+
Perform a forward pass through the network.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
x (torch.Tensor): The input tensor to the model.
|
70 |
+
profile (bool): Print the computation time of each layer if True, defaults to False.
|
71 |
+
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
(torch.Tensor): The last output of the model.
|
75 |
+
"""
|
76 |
+
y, dt = [], [] # outputs
|
77 |
+
for m in self.model:
|
78 |
+
if m.f != -1: # if not from previous layer
|
79 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
80 |
+
if profile:
|
81 |
+
self._profile_one_layer(m, x, dt)
|
82 |
+
x = m(x) # run
|
83 |
+
y.append(x if m.i in self.save else None) # save output
|
84 |
+
if visualize:
|
85 |
+
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
86 |
+
return x
|
87 |
+
|
88 |
+
def _predict_augment(self, x):
|
89 |
+
"""Perform augmentations on input image x and return augmented inference."""
|
90 |
+
LOGGER.warning(
|
91 |
+
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
92 |
+
)
|
93 |
+
return self._predict_once(x)
|
94 |
+
|
95 |
+
def _profile_one_layer(self, m, x, dt):
|
96 |
+
"""
|
97 |
+
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
98 |
+
Appends the results to the provided list.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
m (nn.Module): The layer to be profiled.
|
102 |
+
x (torch.Tensor): The input data to the layer.
|
103 |
+
dt (list): A list to store the computation time of the layer.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
None
|
107 |
+
"""
|
108 |
+
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
109 |
+
o = thop.profile(m, inputs=[x.clone() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
110 |
+
t = time_sync()
|
111 |
+
for _ in range(10):
|
112 |
+
m(x.clone() if c else x)
|
113 |
+
dt.append((time_sync() - t) * 100)
|
114 |
+
if m == self.model[0]:
|
115 |
+
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
116 |
+
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
117 |
+
if c:
|
118 |
+
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
119 |
+
|
120 |
+
def fuse(self, verbose=True):
|
121 |
+
"""
|
122 |
+
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
|
123 |
+
computation efficiency.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
(nn.Module): The fused model is returned.
|
127 |
+
"""
|
128 |
+
if not self.is_fused():
|
129 |
+
for m in self.model.modules():
|
130 |
+
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
|
131 |
+
if isinstance(m, Conv2):
|
132 |
+
m.fuse_convs()
|
133 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
134 |
+
delattr(m, 'bn') # remove batchnorm
|
135 |
+
m.forward = m.forward_fuse # update forward
|
136 |
+
if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
|
137 |
+
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
|
138 |
+
delattr(m, 'bn') # remove batchnorm
|
139 |
+
m.forward = m.forward_fuse # update forward
|
140 |
+
if isinstance(m, RepConv):
|
141 |
+
m.fuse_convs()
|
142 |
+
m.forward = m.forward_fuse # update forward
|
143 |
+
self.info(verbose=verbose)
|
144 |
+
|
145 |
+
return self
|
146 |
+
|
147 |
+
def is_fused(self, thresh=10):
|
148 |
+
"""
|
149 |
+
Check if the model has less than a certain threshold of BatchNorm layers.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
156 |
+
"""
|
157 |
+
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
158 |
+
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
159 |
+
|
160 |
+
def info(self, detailed=False, verbose=True, imgsz=640):
|
161 |
+
"""
|
162 |
+
Prints model information
|
163 |
+
|
164 |
+
Args:
|
165 |
+
verbose (bool): if True, prints out the model information. Defaults to False
|
166 |
+
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
|
167 |
+
"""
|
168 |
+
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
|
169 |
+
|
170 |
+
def _apply(self, fn):
|
171 |
+
"""
|
172 |
+
`_apply()` is a function that applies a function to all the tensors in the model that are not
|
173 |
+
parameters or registered buffers
|
174 |
+
|
175 |
+
Args:
|
176 |
+
fn: the function to apply to the model
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
A model that is a Detect() object.
|
180 |
+
"""
|
181 |
+
self = super()._apply(fn)
|
182 |
+
m = self.model[-1] # Detect()
|
183 |
+
if isinstance(m, (Detect, Segment)):
|
184 |
+
m.stride = fn(m.stride)
|
185 |
+
m.anchors = fn(m.anchors)
|
186 |
+
m.strides = fn(m.strides)
|
187 |
+
return self
|
188 |
+
|
189 |
+
def load(self, weights, verbose=True):
|
190 |
+
"""Load the weights into the model.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
|
194 |
+
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
|
195 |
+
"""
|
196 |
+
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
197 |
+
csd = model.float().state_dict() # checkpoint state_dict as FP32
|
198 |
+
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
199 |
+
self.load_state_dict(csd, strict=False) # load
|
200 |
+
if verbose:
|
201 |
+
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
202 |
+
|
203 |
+
def loss(self, batch, preds=None):
|
204 |
+
"""
|
205 |
+
Compute loss
|
206 |
+
|
207 |
+
Args:
|
208 |
+
batch (dict): Batch to compute loss on
|
209 |
+
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
210 |
+
"""
|
211 |
+
if not hasattr(self, 'criterion'):
|
212 |
+
self.criterion = self.init_criterion()
|
213 |
+
|
214 |
+
preds = self.forward(batch['img']) if preds is None else preds
|
215 |
+
return self.criterion(preds, batch)
|
216 |
+
|
217 |
+
def init_criterion(self):
|
218 |
+
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
|
219 |
+
|
220 |
+
|
221 |
+
class DetectionModel(BaseModel):
|
222 |
+
"""YOLOv8 detection model."""
|
223 |
+
|
224 |
+
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
225 |
+
super().__init__()
|
226 |
+
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
227 |
+
|
228 |
+
# Define model
|
229 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
230 |
+
if nc and nc != self.yaml['nc']:
|
231 |
+
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
232 |
+
self.yaml['nc'] = nc # override yaml value
|
233 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
234 |
+
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
235 |
+
self.inplace = self.yaml.get('inplace', True)
|
236 |
+
|
237 |
+
# Build strides
|
238 |
+
m = self.model[-1] # Detect()
|
239 |
+
if isinstance(m, (Detect, Segment, Pose)):
|
240 |
+
s = 256 # 2x min stride
|
241 |
+
m.inplace = self.inplace
|
242 |
+
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
|
243 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
244 |
+
self.stride = m.stride
|
245 |
+
m.bias_init() # only run once
|
246 |
+
else:
|
247 |
+
self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR
|
248 |
+
|
249 |
+
# Init weights, biases
|
250 |
+
initialize_weights(self)
|
251 |
+
if verbose:
|
252 |
+
self.info()
|
253 |
+
LOGGER.info('')
|
254 |
+
|
255 |
+
def _predict_augment(self, x):
|
256 |
+
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
257 |
+
img_size = x.shape[-2:] # height, width
|
258 |
+
s = [1, 0.83, 0.67] # scales
|
259 |
+
f = [None, 3, None] # flips (2-ud, 3-lr)
|
260 |
+
y = [] # outputs
|
261 |
+
for si, fi in zip(s, f):
|
262 |
+
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
263 |
+
yi = super().predict(xi)[0] # forward
|
264 |
+
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
265 |
+
yi = self._descale_pred(yi, fi, si, img_size)
|
266 |
+
y.append(yi)
|
267 |
+
y = self._clip_augmented(y) # clip augmented tails
|
268 |
+
return torch.cat(y, -1), None # augmented inference, train
|
269 |
+
|
270 |
+
@staticmethod
|
271 |
+
def _descale_pred(p, flips, scale, img_size, dim=1):
|
272 |
+
"""De-scale predictions following augmented inference (inverse operation)."""
|
273 |
+
p[:, :4] /= scale # de-scale
|
274 |
+
x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
|
275 |
+
if flips == 2:
|
276 |
+
y = img_size[0] - y # de-flip ud
|
277 |
+
elif flips == 3:
|
278 |
+
x = img_size[1] - x # de-flip lr
|
279 |
+
return torch.cat((x, y, wh, cls), dim)
|
280 |
+
|
281 |
+
def _clip_augmented(self, y):
|
282 |
+
"""Clip YOLOv5 augmented inference tails."""
|
283 |
+
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
284 |
+
g = sum(4 ** x for x in range(nl)) # grid points
|
285 |
+
e = 1 # exclude layer count
|
286 |
+
i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
|
287 |
+
y[0] = y[0][..., :-i] # large
|
288 |
+
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
289 |
+
y[-1] = y[-1][..., i:] # small
|
290 |
+
return y
|
291 |
+
|
292 |
+
def init_criterion(self):
|
293 |
+
return v8DetectionLoss(self)
|
294 |
+
|
295 |
+
|
296 |
+
class SegmentationModel(DetectionModel):
|
297 |
+
"""YOLOv8 segmentation model."""
|
298 |
+
|
299 |
+
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
300 |
+
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
301 |
+
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
302 |
+
|
303 |
+
def init_criterion(self):
|
304 |
+
return v8SegmentationLoss(self)
|
305 |
+
|
306 |
+
def _predict_augment(self, x):
|
307 |
+
"""Perform augmentations on input image x and return augmented inference."""
|
308 |
+
LOGGER.warning(
|
309 |
+
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
310 |
+
)
|
311 |
+
return self._predict_once(x)
|
312 |
+
|
313 |
+
|
314 |
+
class PoseModel(DetectionModel):
|
315 |
+
"""YOLOv8 pose model."""
|
316 |
+
|
317 |
+
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
318 |
+
"""Initialize YOLOv8 Pose model."""
|
319 |
+
if not isinstance(cfg, dict):
|
320 |
+
cfg = yaml_model_load(cfg) # load model YAML
|
321 |
+
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
|
322 |
+
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
|
323 |
+
cfg['kpt_shape'] = data_kpt_shape
|
324 |
+
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
325 |
+
|
326 |
+
def init_criterion(self):
|
327 |
+
return v8PoseLoss(self)
|
328 |
+
|
329 |
+
def _predict_augment(self, x):
|
330 |
+
"""Perform augmentations on input image x and return augmented inference."""
|
331 |
+
LOGGER.warning(
|
332 |
+
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
333 |
+
)
|
334 |
+
return self._predict_once(x)
|
335 |
+
|
336 |
+
|
337 |
+
class ClassificationModel(BaseModel):
|
338 |
+
"""YOLOv8 classification model."""
|
339 |
+
|
340 |
+
def __init__(self,
|
341 |
+
cfg=None,
|
342 |
+
model=None,
|
343 |
+
ch=3,
|
344 |
+
nc=None,
|
345 |
+
cutoff=10,
|
346 |
+
verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
|
347 |
+
super().__init__()
|
348 |
+
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
349 |
+
|
350 |
+
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
351 |
+
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
|
352 |
+
from ultralytics.nn.autobackend import AutoBackend
|
353 |
+
if isinstance(model, AutoBackend):
|
354 |
+
model = model.model # unwrap DetectMultiBackend
|
355 |
+
model.model = model.model[:cutoff] # backbone
|
356 |
+
m = model.model[-1] # last layer
|
357 |
+
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
358 |
+
c = Classify(ch, nc) # Classify()
|
359 |
+
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
360 |
+
model.model[-1] = c # replace
|
361 |
+
self.model = model.model
|
362 |
+
self.stride = model.stride
|
363 |
+
self.save = []
|
364 |
+
self.nc = nc
|
365 |
+
|
366 |
+
def _from_yaml(self, cfg, ch, nc, verbose):
|
367 |
+
"""Set YOLOv8 model configurations and define the model architecture."""
|
368 |
+
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
369 |
+
|
370 |
+
# Define model
|
371 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
372 |
+
if nc and nc != self.yaml['nc']:
|
373 |
+
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
374 |
+
self.yaml['nc'] = nc # override yaml value
|
375 |
+
elif not nc and not self.yaml.get('nc', None):
|
376 |
+
raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
|
377 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
378 |
+
self.stride = torch.Tensor([1]) # no stride constraints
|
379 |
+
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
380 |
+
self.info()
|
381 |
+
|
382 |
+
@staticmethod
|
383 |
+
def reshape_outputs(model, nc):
|
384 |
+
"""Update a TorchVision classification model to class count 'n' if required."""
|
385 |
+
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
386 |
+
if isinstance(m, Classify): # YOLO Classify() head
|
387 |
+
if m.linear.out_features != nc:
|
388 |
+
m.linear = nn.Linear(m.linear.in_features, nc)
|
389 |
+
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
390 |
+
if m.out_features != nc:
|
391 |
+
setattr(model, name, nn.Linear(m.in_features, nc))
|
392 |
+
elif isinstance(m, nn.Sequential):
|
393 |
+
types = [type(x) for x in m]
|
394 |
+
if nn.Linear in types:
|
395 |
+
i = types.index(nn.Linear) # nn.Linear index
|
396 |
+
if m[i].out_features != nc:
|
397 |
+
m[i] = nn.Linear(m[i].in_features, nc)
|
398 |
+
elif nn.Conv2d in types:
|
399 |
+
i = types.index(nn.Conv2d) # nn.Conv2d index
|
400 |
+
if m[i].out_channels != nc:
|
401 |
+
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
402 |
+
|
403 |
+
def init_criterion(self):
|
404 |
+
"""Compute the classification loss between predictions and true labels."""
|
405 |
+
return v8ClassificationLoss()
|
406 |
+
|
407 |
+
|
408 |
+
class RTDETRDetectionModel(DetectionModel):
|
409 |
+
|
410 |
+
def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
|
411 |
+
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
412 |
+
|
413 |
+
def init_criterion(self):
|
414 |
+
"""Compute the classification loss between predictions and true labels."""
|
415 |
+
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
|
416 |
+
|
417 |
+
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
|
418 |
+
|
419 |
+
def loss(self, batch, preds=None):
|
420 |
+
if not hasattr(self, 'criterion'):
|
421 |
+
self.criterion = self.init_criterion()
|
422 |
+
|
423 |
+
img = batch['img']
|
424 |
+
# NOTE: preprocess gt_bbox and gt_labels to list.
|
425 |
+
bs = len(img)
|
426 |
+
batch_idx = batch['batch_idx']
|
427 |
+
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
|
428 |
+
targets = {
|
429 |
+
'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
|
430 |
+
'bboxes': batch['bboxes'].to(device=img.device),
|
431 |
+
'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
|
432 |
+
'gt_groups': gt_groups}
|
433 |
+
|
434 |
+
preds = self.predict(img, batch=targets) if preds is None else preds
|
435 |
+
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
|
436 |
+
if dn_meta is None:
|
437 |
+
dn_bboxes, dn_scores = None, None
|
438 |
+
else:
|
439 |
+
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
|
440 |
+
dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
|
441 |
+
|
442 |
+
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
|
443 |
+
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
|
444 |
+
|
445 |
+
loss = self.criterion((dec_bboxes, dec_scores),
|
446 |
+
targets,
|
447 |
+
dn_bboxes=dn_bboxes,
|
448 |
+
dn_scores=dn_scores,
|
449 |
+
dn_meta=dn_meta)
|
450 |
+
# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
|
451 |
+
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
|
452 |
+
device=img.device)
|
453 |
+
|
454 |
+
def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
|
455 |
+
"""
|
456 |
+
Perform a forward pass through the network.
|
457 |
+
|
458 |
+
Args:
|
459 |
+
x (torch.Tensor): The input tensor to the model
|
460 |
+
profile (bool): Print the computation time of each layer if True, defaults to False.
|
461 |
+
visualize (bool): Save the feature maps of the model if True, defaults to False
|
462 |
+
batch (dict): A dict including gt boxes and labels from dataloader.
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
(torch.Tensor): The last output of the model.
|
466 |
+
"""
|
467 |
+
y, dt = [], [] # outputs
|
468 |
+
for m in self.model[:-1]: # except the head part
|
469 |
+
if m.f != -1: # if not from previous layer
|
470 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
471 |
+
if profile:
|
472 |
+
self._profile_one_layer(m, x, dt)
|
473 |
+
x = m(x) # run
|
474 |
+
y.append(x if m.i in self.save else None) # save output
|
475 |
+
if visualize:
|
476 |
+
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
477 |
+
head = self.model[-1]
|
478 |
+
x = head([y[j] for j in head.f], batch) # head inference
|
479 |
+
return x
|
480 |
+
|
481 |
+
|
482 |
+
class Ensemble(nn.ModuleList):
|
483 |
+
"""Ensemble of models."""
|
484 |
+
|
485 |
+
def __init__(self):
|
486 |
+
"""Initialize an ensemble of models."""
|
487 |
+
super().__init__()
|
488 |
+
|
489 |
+
def forward(self, x, augment=False, profile=False, visualize=False):
|
490 |
+
"""Function generates the YOLOv5 network's final layer."""
|
491 |
+
y = [module(x, augment, profile, visualize)[0] for module in self]
|
492 |
+
# y = torch.stack(y).max(0)[0] # max ensemble
|
493 |
+
# y = torch.stack(y).mean(0) # mean ensemble
|
494 |
+
y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
|
495 |
+
return y, None # inference, train output
|
496 |
+
|
497 |
+
|
498 |
+
# Functions ------------------------------------------------------------------------------------------------------------
|
499 |
+
|
500 |
+
|
501 |
+
def torch_safe_load(weight):
|
502 |
+
"""
|
503 |
+
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
|
504 |
+
it catches the error, logs a warning message, and attempts to install the missing module via the
|
505 |
+
check_requirements() function. After installation, the function again attempts to load the model using torch.load().
|
506 |
+
|
507 |
+
Args:
|
508 |
+
weight (str): The file path of the PyTorch model.
|
509 |
+
|
510 |
+
Returns:
|
511 |
+
(dict): The loaded PyTorch model.
|
512 |
+
"""
|
513 |
+
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
514 |
+
|
515 |
+
check_suffix(file=weight, suffix='.pt')
|
516 |
+
file = attempt_download_asset(weight) # search online if missing locally
|
517 |
+
try:
|
518 |
+
return torch.load(file, map_location='cpu'), file # load
|
519 |
+
except ModuleNotFoundError as e: # e.name is missing module name
|
520 |
+
if e.name == 'models':
|
521 |
+
raise TypeError(
|
522 |
+
emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
|
523 |
+
f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
|
524 |
+
f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
|
525 |
+
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
526 |
+
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
|
527 |
+
LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
528 |
+
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
529 |
+
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
530 |
+
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
|
531 |
+
check_requirements(e.name) # install missing module
|
532 |
+
|
533 |
+
return torch.load(file, map_location='cpu'), file # load
|
534 |
+
|
535 |
+
|
536 |
+
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
537 |
+
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
|
538 |
+
|
539 |
+
ensemble = Ensemble()
|
540 |
+
for w in weights if isinstance(weights, list) else [weights]:
|
541 |
+
ckpt, w = torch_safe_load(w) # load ckpt
|
542 |
+
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
|
543 |
+
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
544 |
+
|
545 |
+
# Model compatibility updates
|
546 |
+
model.args = args # attach args to model
|
547 |
+
model.pt_path = w # attach *.pt file path to model
|
548 |
+
model.task = guess_model_task(model)
|
549 |
+
if not hasattr(model, 'stride'):
|
550 |
+
model.stride = torch.tensor([32.])
|
551 |
+
|
552 |
+
# Append
|
553 |
+
ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
|
554 |
+
|
555 |
+
# Module compatibility updates
|
556 |
+
for m in ensemble.modules():
|
557 |
+
t = type(m)
|
558 |
+
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
559 |
+
m.inplace = inplace # torch 1.7.0 compatibility
|
560 |
+
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
561 |
+
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
562 |
+
|
563 |
+
# Return model
|
564 |
+
if len(ensemble) == 1:
|
565 |
+
return ensemble[-1]
|
566 |
+
|
567 |
+
# Return ensemble
|
568 |
+
LOGGER.info(f'Ensemble created with {weights}\n')
|
569 |
+
for k in 'names', 'nc', 'yaml':
|
570 |
+
setattr(ensemble, k, getattr(ensemble[0], k))
|
571 |
+
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
572 |
+
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
|
573 |
+
return ensemble
|
574 |
+
|
575 |
+
|
576 |
+
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
577 |
+
"""Loads a single model weights."""
|
578 |
+
ckpt, weight = torch_safe_load(weight) # load ckpt
|
579 |
+
args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
|
580 |
+
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
581 |
+
|
582 |
+
# Model compatibility updates
|
583 |
+
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
584 |
+
model.pt_path = weight # attach *.pt file path to model
|
585 |
+
model.task = guess_model_task(model)
|
586 |
+
if not hasattr(model, 'stride'):
|
587 |
+
model.stride = torch.tensor([32.])
|
588 |
+
|
589 |
+
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
|
590 |
+
|
591 |
+
# Module compatibility updates
|
592 |
+
for m in model.modules():
|
593 |
+
t = type(m)
|
594 |
+
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
595 |
+
m.inplace = inplace # torch 1.7.0 compatibility
|
596 |
+
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
597 |
+
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
598 |
+
|
599 |
+
# Return model and ckpt
|
600 |
+
return model, ckpt
|
601 |
+
|
602 |
+
|
603 |
+
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
604 |
+
# Parse a YOLO model.yaml dictionary into a PyTorch model
|
605 |
+
import ast
|
606 |
+
|
607 |
+
# Args
|
608 |
+
max_channels = float('inf')
|
609 |
+
nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
|
610 |
+
depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
|
611 |
+
if scales:
|
612 |
+
scale = d.get('scale')
|
613 |
+
if not scale:
|
614 |
+
scale = tuple(scales.keys())[0]
|
615 |
+
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
|
616 |
+
depth, width, max_channels = scales[scale]
|
617 |
+
|
618 |
+
if act:
|
619 |
+
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
620 |
+
if verbose:
|
621 |
+
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
622 |
+
|
623 |
+
if verbose:
|
624 |
+
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
625 |
+
ch = [ch]
|
626 |
+
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
627 |
+
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
628 |
+
m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
|
629 |
+
for j, a in enumerate(args):
|
630 |
+
if isinstance(a, str):
|
631 |
+
with contextlib.suppress(ValueError):
|
632 |
+
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
633 |
+
|
634 |
+
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
635 |
+
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
636 |
+
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
|
637 |
+
c1, c2 = ch[f], args[0]
|
638 |
+
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
639 |
+
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
640 |
+
|
641 |
+
args = [c1, c2, *args[1:]]
|
642 |
+
if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
|
643 |
+
args.insert(2, n) # number of repeats
|
644 |
+
n = 1
|
645 |
+
elif m is AIFI:
|
646 |
+
args = [ch[f], *args]
|
647 |
+
elif m in (HGStem, HGBlock):
|
648 |
+
c1, cm, c2 = ch[f], args[0], args[1]
|
649 |
+
args = [c1, cm, c2, *args[2:]]
|
650 |
+
if m is HGBlock:
|
651 |
+
args.insert(4, n) # number of repeats
|
652 |
+
n = 1
|
653 |
+
|
654 |
+
elif m is nn.BatchNorm2d:
|
655 |
+
args = [ch[f]]
|
656 |
+
elif m is Concat:
|
657 |
+
c2 = sum(ch[x] for x in f)
|
658 |
+
elif m in (Detect, Segment, Pose, RTDETRDecoder):
|
659 |
+
args.append([ch[x] for x in f])
|
660 |
+
if m is Segment:
|
661 |
+
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
662 |
+
else:
|
663 |
+
c2 = ch[f]
|
664 |
+
|
665 |
+
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
666 |
+
t = str(m)[8:-2].replace('__main__.', '') # module type
|
667 |
+
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
668 |
+
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
669 |
+
if verbose:
|
670 |
+
LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
|
671 |
+
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
672 |
+
layers.append(m_)
|
673 |
+
if i == 0:
|
674 |
+
ch = []
|
675 |
+
ch.append(c2)
|
676 |
+
return nn.Sequential(*layers), sorted(save)
|
677 |
+
|
678 |
+
|
679 |
+
def yaml_model_load(path):
|
680 |
+
"""Load a YOLOv8 model from a YAML file."""
|
681 |
+
import re
|
682 |
+
|
683 |
+
path = Path(path)
|
684 |
+
if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
|
685 |
+
new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
|
686 |
+
LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
|
687 |
+
path = path.with_stem(new_stem)
|
688 |
+
|
689 |
+
unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
|
690 |
+
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
|
691 |
+
d = yaml_load(yaml_file) # model dict
|
692 |
+
d['scale'] = guess_model_scale(path)
|
693 |
+
d['yaml_file'] = str(path)
|
694 |
+
return d
|
695 |
+
|
696 |
+
|
697 |
+
def guess_model_scale(model_path):
|
698 |
+
"""
|
699 |
+
Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale.
|
700 |
+
The function uses regular expression matching to find the pattern of the model scale in the YAML file name,
|
701 |
+
which is denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string.
|
702 |
+
|
703 |
+
Args:
|
704 |
+
model_path (str | Path): The path to the YOLO model's YAML file.
|
705 |
+
|
706 |
+
Returns:
|
707 |
+
(str): The size character of the model's scale, which can be n, s, m, l, or x.
|
708 |
+
"""
|
709 |
+
with contextlib.suppress(AttributeError):
|
710 |
+
import re
|
711 |
+
return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x
|
712 |
+
return ''
|
713 |
+
|
714 |
+
|
715 |
+
def guess_model_task(model):
|
716 |
+
"""
|
717 |
+
Guess the task of a PyTorch model from its architecture or configuration.
|
718 |
+
|
719 |
+
Args:
|
720 |
+
model (nn.Module | dict): PyTorch model or model configuration in YAML format.
|
721 |
+
|
722 |
+
Returns:
|
723 |
+
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
|
724 |
+
|
725 |
+
Raises:
|
726 |
+
SyntaxError: If the task of the model could not be determined.
|
727 |
+
"""
|
728 |
+
|
729 |
+
def cfg2task(cfg):
|
730 |
+
"""Guess from YAML dictionary."""
|
731 |
+
m = cfg['head'][-1][-2].lower() # output module name
|
732 |
+
if m in ('classify', 'classifier', 'cls', 'fc'):
|
733 |
+
return 'classify'
|
734 |
+
if m == 'detect':
|
735 |
+
return 'detect'
|
736 |
+
if m == 'segment':
|
737 |
+
return 'segment'
|
738 |
+
if m == 'pose':
|
739 |
+
return 'pose'
|
740 |
+
|
741 |
+
# Guess from model cfg
|
742 |
+
if isinstance(model, dict):
|
743 |
+
with contextlib.suppress(Exception):
|
744 |
+
return cfg2task(model)
|
745 |
+
|
746 |
+
# Guess from PyTorch model
|
747 |
+
if isinstance(model, nn.Module): # PyTorch model
|
748 |
+
for x in 'model.args', 'model.model.args', 'model.model.model.args':
|
749 |
+
with contextlib.suppress(Exception):
|
750 |
+
return eval(x)['task']
|
751 |
+
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
752 |
+
with contextlib.suppress(Exception):
|
753 |
+
return cfg2task(eval(x))
|
754 |
+
|
755 |
+
for m in model.modules():
|
756 |
+
if isinstance(m, Detect):
|
757 |
+
return 'detect'
|
758 |
+
elif isinstance(m, Segment):
|
759 |
+
return 'segment'
|
760 |
+
elif isinstance(m, Classify):
|
761 |
+
return 'classify'
|
762 |
+
elif isinstance(m, Pose):
|
763 |
+
return 'pose'
|
764 |
+
|
765 |
+
# Guess from model filename
|
766 |
+
if isinstance(model, (str, Path)):
|
767 |
+
model = Path(model)
|
768 |
+
if '-seg' in model.stem or 'segment' in model.parts:
|
769 |
+
return 'segment'
|
770 |
+
elif '-cls' in model.stem or 'classify' in model.parts:
|
771 |
+
return 'classify'
|
772 |
+
elif '-pose' in model.stem or 'pose' in model.parts:
|
773 |
+
return 'pose'
|
774 |
+
elif 'detect' in model.parts:
|
775 |
+
return 'detect'
|
776 |
+
|
777 |
+
# Unable to determine task from model
|
778 |
+
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
779 |
+
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
|
780 |
+
return 'detect' # assume detect
|