chandrakalagowda committed on
Commit
86d1dd3
1 Parent(s): 9e6eff9

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +173 -0
  2. requirements.txt +127 -0
  3. reverse_image_search.zip +3 -0
  4. teddy.png +0 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import time
3
+ from zipfile import ZipFile
4
+
5
# Unpack the bundled archive into the current working directory.
# The handle is deliberately named `archive` rather than `zip` so that the
# built-in `zip()` is not shadowed for the rest of the module.
with ZipFile('reverse_image_search.zip', 'r') as archive:
    print('Extracting all the files now...')
    archive.extractall()
    print('Done!')
11
+
12
+ df = pd.read_csv('reverse_image_search.csv')
13
+ df.head()
14
+
15
+ import cv2
16
+ from towhee.types.image import Image
17
+
18
+ id_img = df.set_index('id')['path'].to_dict()
19
def read_images(results):
    """Load the image that belongs to every search hit.

    Each element of ``results`` must expose an ``id`` attribute. That id is
    looked up in the module-level ``id_img`` mapping to obtain a file path;
    the file is read with ``cv2.imread`` and wrapped in a towhee ``Image``
    tagged with mode ``'BGR'``.

    Returns:
        A list of wrapped images, in the same order as ``results``.
    """
    return [Image(cv2.imread(id_img[hit.id]), 'BGR') for hit in results]
25
+
26
+
27
+ time.sleep(60)
28
+ from milvus import default_server
29
+ from pymilvus import connections, utility
30
+ default_server.start()
31
+ time.sleep(60)
32
+ connections.connect(host='127.0.0.1', port=default_server.listen_port)
33
+ time.sleep(60)
34
+ default_server.listen_port
35
+
36
+ time.sleep(20)
37
+ print(utility.get_server_version())
38
+
39
+
40
+ from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
41
+
42
def create_milvus_collection(collection_name, dim, host='127.0.0.1', port='19530'):
    """Create a fresh Milvus collection with an IVF_FLAT index and return it.

    If a collection named ``collection_name`` already exists it is dropped
    first, so calling this function discards any previously inserted data.

    Args:
        collection_name: Name of the collection to (re)create.
        dim: Dimensionality of the ``embedding`` vector field.
        host: Milvus server host. Defaults to the previously hard-coded value.
        port: Milvus server port. Defaults to the previously hard-coded value.

    Returns:
        The newly created ``Collection``. It has an ``id`` field (INT64,
        primary key, not auto-generated) and an ``embedding`` field
        (FLOAT_VECTOR of size ``dim``) indexed with IVF_FLAT / L2.
    """
    connections.connect(host=host, port=port)

    # Drop any existing collection of the same name so the new one starts empty.
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        # The keyword is `description`; it was previously misspelled
        # `descrition`, so the field descriptions were never applied.
        FieldSchema(name='id', dtype=DataType.INT64, description='ids',
                    is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR,
                    description='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='text image search')
    collection = Collection(name=collection_name, schema=schema)

    # Create an IVF_FLAT index (L2 metric) on the vector field.
    index_params = {
        'metric_type': 'L2',
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 512},
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection
63
+
64
+ collection = create_milvus_collection('text_image_search', 512)
65
+
66
+
67
+ from towhee import ops, pipe, DataCollection
68
+ import numpy as np
69
+
70
+ ### This section requires teddy.png to be present in the working folder; otherwise it will throw an error.
71
+ p = (
72
+ pipe.input('path')
73
+ .map('path', 'img', ops.image_decode.cv2('rgb'))
74
+ .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
75
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
76
+ .output('img', 'vec')
77
+ )
78
+
79
+ DataCollection(p('./teddy.png')).show()
80
+
81
+ p2 = (
82
+ pipe.input('text')
83
+ .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
84
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
85
+ .output('text', 'vec')
86
+ )
87
+
88
+ DataCollection(p2("A teddybear on a skateboard in Times Square.")).show()
89
+
90
+ time.sleep(60)
91
+ collection = create_milvus_collection('text_image_search', 512)
92
+
93
def read_csv(csv_path, encoding='utf-8-sig'):
    """Yield ``(id, path)`` pairs from a CSV file.

    The file must have a header row containing at least the columns ``id``
    and ``path``; any other columns are ignored.

    Args:
        csv_path: Path of the CSV file to read.
        encoding: Text encoding used to open the file.

    Yields:
        Tuples ``(int(id), path)``, one per data row, in file order.

    Raises:
        KeyError: If a row lacks the ``id`` or ``path`` column.
        ValueError: If an ``id`` value cannot be converted to ``int``.
    """
    import csv

    # newline='' hands line-ending handling to the csv module, as its
    # documentation requires; otherwise line breaks inside quoted fields
    # are translated before the parser sees them.
    with open(csv_path, 'r', encoding=encoding, newline='') as f:
        for row in csv.DictReader(f):
            yield int(row['id']), row['path']
99
+
100
+ p3 = (
101
+ pipe.input('csv_file')
102
+ .flat_map('csv_file', ('id', 'path'), read_csv)
103
+ .map('path', 'img', ops.image_decode.cv2('rgb'))
104
+ .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
105
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
106
+ .map(('id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search'))
107
+ .output()
108
+ )
109
+
110
+ ret = p3('reverse_image_search.csv')
111
+
112
+
113
+ time.sleep(120)
114
+ collection.load()
115
+
116
+
117
+ time.sleep(120)
118
+ print('Total number of inserted data is {}.'.format(collection.num_entities))
119
+
120
+
121
+ import pandas as pd
122
+ import cv2
123
+
124
def read_image(image_ids):
    """Decode the images that correspond to ``image_ids``.

    The ``id`` -> ``path`` mapping is rebuilt from ``reverse_image_search.csv``
    on every call, and each path is decoded with the towhee
    ``image_decode.cv2('rgb')`` operator.

    Returns:
        A list of decoded images, in the same order as ``image_ids``.
    """
    frame = pd.read_csv('reverse_image_search.csv')
    path_by_id = frame.set_index('id')['path'].to_dict()
    decoder = ops.image_decode.cv2('rgb')
    return [decoder(path_by_id[image_id]) for image_id in image_ids]
133
+
134
+
135
+ p4 = (
136
+ pipe.input('text')
137
+ .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
138
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
139
+ .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
140
+ .map('result', 'image_ids', lambda x: [item[0] for item in x])
141
+ .map('image_ids', 'images', read_image)
142
+ .output('text', 'images')
143
+ )
144
+
145
+ DataCollection(p4("A white dog")).show()
146
+ DataCollection(p4("A black dog")).show()
147
+
148
+ search_pipeline = (
149
+ pipe.input('text')
150
+ .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
151
+ .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
152
+ .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
153
+ .map('result', 'image_ids', lambda x: [item[0] for item in x])
154
+ .output('image_ids')
155
+ )
156
+
157
def search(text):
    """Return the file paths of the images that best match ``text``.

    Runs the module-level ``search_pipeline`` on ``text``, takes the list of
    image ids from the first column of its first output row, and maps each id
    to a path using the ``id`` and ``path`` columns of
    ``reverse_image_search.csv``.

    Args:
        text: The query string passed to the search pipeline.

    Returns:
        A list of image file paths, in the order the pipeline returned the ids.
    """
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    image_ids = search_pipeline(text).to_list()[0][0]
    return [id_img[image_id] for image_id in image_ids]
163
+
164
+
165
+ import gradio
166
+
167
# Web UI: a single-line text box as input and five image slots as output,
# filled by `search` with file paths. The top-level `gradio.Textbox` /
# `gradio.Image` components replace the deprecated `gradio.inputs.*` and
# `gradio.outputs.*` namespaces.
interface = gradio.Interface(
    search,
    gradio.Textbox(lines=1),
    [gradio.Image(type="filepath", label=None) for _ in range(5)],
)
171
+
172
+
173
+ interface.launch(inline=True, share=True)
requirements.txt ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.1
5
+ annotated-types==0.5.0
6
+ anyio==3.7.0
7
+ appnope==0.1.3
8
+ asttokens==2.2.1
9
+ async-timeout==4.0.2
10
+ attrs==23.1.0
11
+ backcall==0.2.0
12
+ bleach==6.0.0
13
+ certifi==2023.5.7
14
+ charset-normalizer==3.1.0
15
+ click==8.1.3
16
+ comm==0.1.3
17
+ contourpy==1.1.0
18
+ cycler==0.11.0
19
+ debugpy==1.6.7
20
+ decorator==5.1.1
21
+ docutils==0.20.1
22
+ environs==9.5.0
23
+ executing==1.2.0
24
+ fastapi==0.99.1
25
+ ffmpy==0.3.0
26
+ filelock==3.12.2
27
+ fonttools==4.40.0
28
+ frozenlist==1.3.3
29
+ fsspec==2023.6.0
30
+ gradio==3.35.2
31
+ gradio_client==0.2.7
32
+ grpcio==1.53.0
33
+ h11==0.14.0
34
+ httpcore==0.17.2
35
+ httpx==0.24.1
36
+ huggingface-hub==0.15.1
37
+ idna==3.4
38
+ importlib-metadata==6.7.0
39
+ ipykernel==6.24.0
40
+ ipython==8.14.0
41
+ jaraco.classes==3.2.3
42
+ jedi==0.18.2
43
+ Jinja2==3.1.2
44
+ jsonschema==4.17.3
45
+ jupyter_client==8.3.0
46
+ jupyter_core==5.3.1
47
+ keyring==24.2.0
48
+ kiwisolver==1.4.4
49
+ linkify-it-py==2.0.2
50
+ markdown-it-py==2.2.0
51
+ MarkupSafe==2.1.3
52
+ marshmallow==3.19.0
53
+ matplotlib==3.7.1
54
+ matplotlib-inline==0.1.6
55
+ mdit-py-plugins==0.3.3
56
+ mdurl==0.1.2
57
+ milvus==2.2.10
58
+ more-itertools==9.1.0
59
+ mpmath==1.3.0
60
+ multidict==6.0.4
61
+ nest-asyncio==1.5.6
62
+ networkx==3.1
63
+ numpy==1.25.0
64
+ opencv-python==4.8.0.74
65
+ orjson==3.9.1
66
+ packaging==23.1
67
+ pandas==2.0.3
68
+ parso==0.8.3
69
+ pexpect==4.8.0
70
+ pickleshare==0.7.5
71
+ Pillow==10.0.0
72
+ pkginfo==1.9.6
73
+ platformdirs==3.8.0
74
+ prompt-toolkit==3.0.38
75
+ protobuf==4.23.3
76
+ psutil==5.9.5
77
+ ptyprocess==0.7.0
78
+ pure-eval==0.2.2
79
+ pydantic==1.10.10
80
+ pydantic_core==2.0.1
81
+ pydub==0.25.1
82
+ Pygments==2.15.1
83
+ pymilvus==2.2.11
84
+ pyparsing==3.1.0
85
+ pyrsistent==0.19.3
86
+ python-dateutil==2.8.2
87
+ python-dotenv==1.0.0
88
+ python-multipart==0.0.6
89
+ pytz==2023.3
90
+ PyYAML==6.0
91
+ pyzmq==25.1.0
92
+ readme-renderer==40.0
93
+ regex==2023.6.3
94
+ requests==2.31.0
95
+ requests-toolbelt==1.0.0
96
+ rfc3986==2.0.0
97
+ rich==13.4.2
98
+ safetensors==0.3.1
99
+ semantic-version==2.10.0
100
+ six==1.16.0
101
+ sniffio==1.3.0
102
+ stack-data==0.6.2
103
+ starlette==0.27.0
104
+ sympy==1.12
105
+ tabulate==0.9.0
106
+ tenacity==8.2.2
107
+ tokenizers==0.13.3
108
+ toolz==0.12.0
109
+ torch==2.0.1
110
+ torchvision==0.15.2
111
+ tornado==6.3.2
112
+ towhee==1.1.0
113
+ tqdm==4.65.0
114
+ traitlets==5.9.0
115
+ transformers==4.30.2
116
+ twine==4.0.2
117
+ typing_extensions==4.7.1
118
+ tzdata==2023.3
119
+ uc-micro-py==1.0.2
120
+ ujson==5.8.0
121
+ urllib3==2.0.3
122
+ uvicorn==0.22.0
123
+ wcwidth==0.2.6
124
+ webencodings==0.5.1
125
+ websockets==11.0.3
126
+ yarl==1.9.2
127
+ zipp==3.15.0
reverse_image_search.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:736813a3307070aae31c41fee6fad93fd4a86b2dcee012754f2c4b7cdb8b9464
3
+ size 125643445
teddy.png ADDED