# Hugging Face Space app: CISEN remote-sensing image-text retrieval demo (Gradio + faiss).
import argparse
import os
import pickle
import time

# Intel OpenMP setting that tolerates more than one OpenMP runtime being loaded
# (presumably needed because faiss and torch each ship one). It must be set
# before those libraries are imported, so it stays ahead of the imports below.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import faiss
import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

import cisen.utils.config as config
from cisen.model.segmenter import CISEN_rsvit_hug
from cisen.utils.dataset import tokenize
from get_data_by_image_id import read_json
# Preprocessing applied to query images before model.image_encode: match the
# shorter side to 224, center-crop to 224x224, convert to a tensor, then
# normalize each channel (the mean/std values are the ones published with
# OpenAI CLIP's preprocessing).
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def get_parser():
    """Parse the command line and return the loaded CISEN config.

    Despite its name, this returns the config object built by
    ``cisen.utils.config.load_cfg_from_cfg_file`` from ``--config``. Any
    ``--opts`` overrides are merged in via ``merge_cfg_from_list``.
    """
    # Bind the config helpers locally instead of through the module-level name
    # ``config``, which is easy to shadow or rebind at module scope.
    import cisen.utils.config as cfg_utils

    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='./cisen_r0.9_fpn.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    args = parser.parse_args()
    # Validate with parser.error rather than `assert`, which `python -O` strips.
    if args.config is None:
        parser.error('--config is required')
    cfg = cfg_utils.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = cfg_utils.merge_cfg_from_list(cfg, args.opts)
    return cfg
args = get_parser()

# Dataset metadata and asset locations.
data_dir = './LuojiaHOG(best)_.json'
imgs_folder = './image/'
feature_folder = './georsclip_21_r0.9_fpn/'

# Pre-computed embeddings keyed by sample name. Those names also identify the
# image files under imgs_folder and the records in data_info.
# NOTE: pickle can execute arbitrary code on load; only use bundled, trusted files.
with open('image_features_best.pkl', 'rb') as f:
    image_dict = pickle.load(f)
with open('text_features_best.pkl', 'rb') as f:
    text_dict = pickle.load(f)

# faiss indexes only accept float32 matrices.
image_feat = np.asarray(list(image_dict.values()), dtype=np.float32)
text_feat = np.asarray(list(text_dict.values()), dtype=np.float32)

# Row i of BOTH feature matrices is mapped back to a sample through
# sample_info[i]. This assumes text_dict iterates its keys in the same order
# as image_dict. TODO: confirm.
sample_info = np.array(list(image_dict))
data_info = read_json(data_dir)
# Architecture hyper-parameters passed to CISEN_rsvit_hug. They are kept under
# their own name so they do not shadow the ``cisen.utils.config`` module, which
# is imported as ``config``.
model_config = {
    "embed_dim": 512, "image_resolution": 224, "vision_layers": 12, "vision_width": 768,
    "vision_patch_size": 32, "context_length": 328, "txt_length": 328, "vocab_size": 49408,
    "transformer_width": 512, "transformer_heads": 8, "transformer_layers": 12, "patch_size": 32,
    "output_dim": 512, "ratio": 0.9, "emb_dim": 768, "fpn_in": [512, 768, 768],
    "fpn_out": [768, 768, 768, 512],
}
model = CISEN_rsvit_hug(**model_config)
# Load the released weights ("aleo1/cisen", presumably a Hugging Face Hub repo id).
# NOTE(review): if from_pretrained is a classmethod (as in huggingface_hub's
# ModelHubMixin), the instance built above is not reused. Also confirm that the
# loaded model is in eval mode before serving queries.
model = model.from_pretrained("aleo1/cisen")
def _flat_l2_index(vectors, dim=512):
    """Return a faiss ``IndexFlatL2`` of width ``dim`` populated with ``vectors``."""
    flat = faiss.IndexFlatL2(dim)
    flat.add(vectors)
    return flat


# Exact (brute-force) L2 indexes over the pre-computed 512-d embeddings: one
# for image features and one for text features.
image_index = _flat_l2_index(image_feat)
text_index = _flat_l2_index(text_feat)
# Example text queries offered under "Text examples" in the UI.
text1 = "A rectangular sports field with green artificial turf is visible. The field has white boundary lines and a bright blue surrounding track. Adjacent buildings with flat, gray roofs are visible. Roads with marked lanes run alongside the buildings. A red-roofed structure stands near the sports field. Vegetation includes small, scattered trees with green foliage. Cars are parked along the roads. Shadows cast by the buildings indicate different heights. Pedestrian pathways are present alongside the roads. The image contains a mix of recreational and residential zones. The layout suggests a planned urban environment."
text2 = "The picture shows a wetland full of diverse plants. The area has a network of waterways and thick vegetation, mainly tall reeds and cattails with slender, bamboo-like stalks. The scene is mostly green, with touches of blue and brown, creating a peaceful vibe. The land is mostly flat, allowing a wide view of the wetland. The image is taken from high above, giving a bird's-eye view of the waterways and aquatic plants in the wetland ecosystem."
text3 = "The residential area is depicted in the color remote sensing image with a bird's-eye view. The scene shows a heterogeneous mix of houses with varying shapes and sizes, spread across the area. The houses are painted in different colors, with some having white walls and red-tiled roofs, while others have blue or green exteriors. The residential area also contains several green spaces, including small front yards, larger parks, and gardens with various types of trees and shrubs. A broad road runs through the center of the residential area, connecting different parts of the community. The road is lined with trees on both sides and has a designated sidewalk for pedestrians. The image also captures various other elements of the urban landscape, including utility poles, streetlights, and a few commercial buildings on the outskirts of the residential area."
text4 = "The image depicts a nature reserve on an island, with a landscape dominated by sparse shrubs and meadows in the interior. The color of the image is predominantly green, with hints of brown and yellow, representing the different types of vegetation and soil. The reserve is characterized by rolling hills and gentle valleys, with some areas of flat terrain interspersed throughout. The landscape is dotted with trees, which are scattered randomly and have a relatively low density."
text5 = "The image shows a nature reserve on an island, featuring a landscape mainly covered with sparse shrubs and meadows. The dominant color is green, with touches of brown and yellow indicating various vegetation types and soil. The reserve has rolling hills and gentle valleys, along with some flat areas. Trees are scattered randomly across the landscape, with a relatively low density."
text6 = "The image describes a scene of a residential area with several houses situated next to a large, open stretch of land, which serves as a waste land. The houses are single-story structures with rectangular shapes and are evenly distributed across the scene. They have a pale blue color with a hint of white, which suggests that they are constructed using plastered walls. The waste land is covered in a mixture of brown and green colors, with patches of dry grass and scattered shrubs. The trees around the houses are slender and tall, with a lush green canopy that provides shade to the area. The scene is captured from a high altitude, offering a bird's-eye view of the area. The houses and trees are clearly distinguishable, and the waste land appears as a large, empty space in the center of the image."
text7 = "The neighborhood in the picture has streets arranged in a grid pattern with consistent, low-rise buildings. The buildings are mainly brown, red, and yellow, suggesting a blend of modern and traditional styles. The houses are surrounded by different kinds of greenery, such as small trees, bushes, and tall grasses."
text8 = "The color remote sensing image shows a residential area from a bird's-eye view, revealing a mix of houses of different shapes and sizes spread throughout the area. The houses are painted in various colors, with some having white walls and red-tiled roofs, while others feature blue or green exteriors. The neighborhood includes several green spaces, such as small front yards, larger parks, and gardens with different types of trees and shrubs. A wide road runs through the center, linking different parts of the community. This road is lined with trees and has sidewalks for pedestrians. The image also shows other urban features like utility poles, streetlights, and a few commercial buildings on the edges of the residential area."
text9 = "The image shows a barren and desolate landscape with little to no vegetation, dominated by a uniform color palette. The primary color is white, with some patches of gray and black. The terrain is mostly flat, with minimal changes in elevation. A white road is the only noticeable feature in the scene."
text10 = "The image shows a residential area with several single-story houses next to a large open stretch of wasteland. The houses are rectangular, evenly spaced, and have pale blue walls with hints of white, suggesting they are plastered. The wasteland is a mix of brown and green, featuring patches of dry grass and scattered shrubs. Surrounding the houses are tall, slender trees with lush green canopies providing shade. The scene is captured from a high altitude, giving a bird's-eye view where the houses, trees, and the central wasteland are clearly visible."
text11 = "The color remote sensing image shows an urban city street from a high altitude. The street is flanked by tall, sleek buildings featuring a mix of modern and traditional architecture, mostly in white and beige, with some more colorful facades. The street is busy with cars in white, black, silver, and gold, and pedestrians of diverse ethnicities wearing both modern and traditional clothing. Tall, lush trees with various shades of green line the street. The sky is bright blue with a few fluffy clouds. The image is high quality, with clear, visible details."
text12 = "The image displays a residential area with houses arranged in a grid-like pattern, each having a small yard. The houses are mostly uniform in size and shape, featuring pitched roofs and rectangular windows. They come in a variety of colors, from bright ones like yellow and pink to more neutral tones like white and gray. Trees of different sizes and shapes are scattered throughout the area. A parking lot next to the houses is mostly filled with cars, though some spots are empty. Sidewalks and streets connect the houses to other parts of the island."

# Example query images: every .jpg in ./example_image/. Each image is wrapped
# in a one-element list because gr.Examples takes one list per example row.
# Sorted so the gallery order does not depend on os.listdir's arbitrary order.
image_folder = './example_image/'
image_files = [
    os.path.join(image_folder, name)
    for name in sorted(os.listdir(image_folder))
    if name.endswith('.jpg')
]
image_list = [[Image.open(path)] for path in image_files]
def search(text_query, image_query, top_k: int = 10):
    """Retrieve the ``top_k`` samples nearest to a text or image query.

    The image query takes priority. When ``image_query`` is not None it is
    preprocessed with ``transform``, embedded with ``model.image_encode`` and
    searched against ``image_index``. Otherwise ``text_query`` is tokenized,
    embedded with ``model.text_encode`` and searched against ``text_index``.

    Args:
        text_query: Query string (only used when ``image_query`` is None).
        image_query: Query image as an array accepted by ``Image.fromarray``, or None.
        top_k: Number of neighbours to return (None falls back to 10).

    Returns:
        A tuple ``(df, images, timings)``:
        - ``df``: a DataFrame with Distance / Latitude / Longitude / Description
          columns, nearest first.
        - ``images``: the matching images, loaded only when ``text_query`` is
          not None.
        - ``timings``: a dict of formatted timings.
    """
    # UI widgets may deliver floats or None; faiss needs an integer k.
    top_k = 10 if top_k is None else int(top_k)

    # 1. Embed the query. This is inference only, so skip autograd bookkeeping.
    start_time = time.time()
    with torch.no_grad():
        if image_query is None:
            # 328 matches context_length / txt_length in the model config.
            query_vector = model.text_encode(tokenize(text_query, 328))
            index = text_index
        else:
            pixels = transform(Image.fromarray(image_query))
            query_vector = model.image_encode(pixels.unsqueeze(0))
            index = image_index
    embed_time = time.time() - start_time
    # faiss expects a C-contiguous float32 matrix with one row per query.
    query_vector = np.ascontiguousarray(
        query_vector.detach().cpu().numpy(), dtype=np.float32
    ).reshape(1, -1)

    # 2. Exact L2 search. faiss returns each row's hits ordered nearest first.
    start_time = time.time()
    distances, ids = index.search(query_vector, top_k)
    search_time = time.time() - start_time

    # 3. Map the hits to sample metadata. faiss pads with id -1 when the index
    #    holds fewer than top_k vectors; drop those instead of letting -1 wrap
    #    around to the last sample.
    start_time = time.time()
    hits = [(dist, idx) for dist, idx in zip(distances[0], ids[0]) if idx >= 0]
    names = [str(sample_info[idx]) for _, idx in hits]
    records = [data_info[name] for name in names]
    if text_query is not None:
        image_output = [Image.open(imgs_folder + name.replace('_', '/')) for name in names]
    else:
        image_output = []
    df = pd.DataFrame(
        {
            "Distance": [float(dist) for dist, _ in hits],
            "Latitude": [rec["lat"] for rec in records],
            "Longitude": [rec["lon"] for rec in records],
            "Description": [rec["description"] for rec in records],
        }
    )
    # DataFrame.round returns a new frame, so it must be reassigned to take effect.
    df = df.round({"Distance": 2, "Latitude": 4, "Longitude": 4})
    sort_time = time.time() - start_time
    return df, image_output, {
        "Embed Time": f"{embed_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
    }
def img_search(image_query, top_k: int = 10):
    """Retrieve the ``top_k`` samples whose image embeddings are nearest an image.

    The query image is preprocessed with ``transform``, embedded with
    ``model.image_encode`` and searched against ``image_index`` by L2 distance.

    Args:
        image_query: Query image as an array accepted by ``Image.fromarray``.
        top_k: Number of neighbours to return (None falls back to 10).

    Returns:
        A tuple ``(df, images, timings)``:
        - ``df``: a DataFrame with Distance / Latitude / Longitude / Description
          columns, nearest first.
        - ``images``: the matching images from ``imgs_folder``.
        - ``timings``: a dict of formatted timings.
    """
    # UI widgets may deliver floats or None; faiss needs an integer k.
    top_k = 10 if top_k is None else int(top_k)

    # 1. Embed the query image. This is inference only, so skip autograd bookkeeping.
    start_time = time.time()
    with torch.no_grad():
        pixels = transform(Image.fromarray(image_query))
        query_vector = model.image_encode(pixels.unsqueeze(0))
    embed_time = time.time() - start_time
    # faiss expects a C-contiguous float32 matrix with one row per query.
    query_vector = np.ascontiguousarray(
        query_vector.detach().cpu().numpy(), dtype=np.float32
    ).reshape(1, -1)

    # 2. Exact L2 search. faiss returns each row's hits ordered nearest first.
    start_time = time.time()
    distances, ids = image_index.search(query_vector, top_k)
    search_time = time.time() - start_time

    # 3. Build the result table and load the matching images. faiss pads with
    #    id -1 when the index holds fewer than top_k vectors; skip those.
    start_time = time.time()
    hits = [(dist, idx) for dist, idx in zip(distances[0], ids[0]) if idx >= 0]
    names = [str(sample_info[idx]) for _, idx in hits]
    records = [data_info[name] for name in names]
    image_output = [Image.open(imgs_folder + name.replace('_', '/')) for name in names]
    df = pd.DataFrame(
        {
            "Distance": [float(dist) for dist, _ in hits],
            "Latitude": [rec["lat"] for rec in records],
            "Longitude": [rec["lon"] for rec in records],
            "Description": [rec["description"] for rec in records],
        }
    )
    # DataFrame.round returns a new frame, so it must be reassigned to take effect.
    df = df.round({"Distance": 2, "Latitude": 4, "Longitude": 4})
    sort_time = time.time() - start_time
    return df, image_output, {
        "Embed Time": f"{embed_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
    }
def txt_search(txt_query, top_k: int = 10):
    """Retrieve the ``top_k`` samples whose text embeddings are nearest a query string.

    ``txt_query`` is tokenized, embedded with ``model.text_encode`` and searched
    against ``text_index`` by L2 distance.

    Args:
        txt_query: Query string.
        top_k: Number of neighbours to return (None falls back to 10).

    Returns:
        A tuple ``(df, images, timings)``:
        - ``df``: a DataFrame with Distance / Latitude / Longitude / Description
          columns, nearest first.
        - ``images``: the matching images from ``imgs_folder``.
        - ``timings``: a dict of formatted timings.
    """
    # UI widgets may deliver floats or None; faiss needs an integer k.
    top_k = 10 if top_k is None else int(top_k)

    # 1. Embed this function's own argument. This is inference only, so skip
    #    autograd bookkeeping. 328 matches context_length / txt_length in the
    #    model config.
    start_time = time.time()
    with torch.no_grad():
        query_vector = model.text_encode(tokenize(txt_query, 328))
    embed_time = time.time() - start_time
    # faiss expects a C-contiguous float32 matrix with one row per query.
    query_vector = np.ascontiguousarray(
        query_vector.detach().cpu().numpy(), dtype=np.float32
    ).reshape(1, -1)

    # 2. Exact L2 search. faiss returns each row's hits ordered nearest first.
    start_time = time.time()
    distances, ids = text_index.search(query_vector, top_k)
    search_time = time.time() - start_time

    # 3. Build the result table and load the matching images. faiss pads with
    #    id -1 when the index holds fewer than top_k vectors; skip those.
    start_time = time.time()
    hits = [(dist, idx) for dist, idx in zip(distances[0], ids[0]) if idx >= 0]
    names = [str(sample_info[idx]) for _, idx in hits]
    records = [data_info[name] for name in names]
    image_output = [Image.open(imgs_folder + name.replace('_', '/')) for name in names]
    df = pd.DataFrame(
        {
            "Distance": [float(dist) for dist, _ in hits],
            "Latitude": [rec["lat"] for rec in records],
            "Longitude": [rec["lon"] for rec in records],
            "Description": [rec["description"] for rec in records],
        }
    )
    # DataFrame.round returns a new frame, so it must be reassigned to take effect.
    df = df.round({"Distance": 2, "Latitude": 4, "Longitude": 4})
    sort_time = time.time() - start_time
    return df, image_output, {
        "Embed Time": f"{embed_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{search_time + sort_time:.4f} s",
    }
def update_visible(choice):
    """Return (Textbox, Image) component updates for the selected search mode.

    ``choice`` is the value of the "Search Index" radio:
    - True ("Text-to-Image") shows only the text box.
    - False ("Image-to-Text") shows only the image upload.
    - Any other value shows both inputs.
    """
    text_props = {
        "label": "Text query for remote sensing images",
        "placeholder": "Enter a query to search for relevant images.",
    }
    image_props = {"label": "Upload an image"}

    if choice == True:  # noqa: E712 - radio values are True / False / None
        text_props.update(visible=True, interactive=True)
        image_props.update(visible=False)
    elif choice == False:  # noqa: E712
        text_props.update(visible=False)
        image_props.update(visible=True, interactive=True)
    else:
        text_props.update(visible=True)
        image_props.update(visible=True, interactive=True)

    return gr.Textbox(**text_props), gr.Image(**image_props)
with gr.Blocks(title="Image-Text Retrieval") as demo:
    # Search-mode selector. update_visible shows or hides the query inputs to match.
    search_index = gr.Radio(
        choices=[("Examples", None), ("Image-to-Text", False), ("Text-to-Image", True)],
        value=None,
        label="Search Index",
    )
    # Text query input.
    text_query = gr.Textbox(
        label="Text query for remote sensing images",
        placeholder="Enter a query to search for relevant images.",
        visible=True,
        interactive=True,
    )
    # Image query input.
    image_query = gr.Image(
        label="Upload an image",
        visible=True,
        interactive=True,
    )
    search_index.change(update_visible, search_index, [text_query, image_query])

    # Number of results to retrieve, next to the per-stage timing report.
    with gr.Row():
        with gr.Column(scale=2):
            top_k = gr.Slider(
                minimum=10,
                maximum=100,
                step=5,
                value=10,
                interactive=True,
                label="Number of images/texts to retrieve",
                info="Number of images/texts to retrieve",
            )
        with gr.Column(scale=2):
            timing_json = gr.JSON(label='retrieval time')

    with gr.Row():
        search_button = gr.Button(value="Search", variant='primary')
        clear_button = gr.ClearButton(value='Clear Before Next Search')

    # Results: a metadata table (nearest first) and a gallery of matching images.
    with gr.Column():
        output = gr.Dataframe(
            headers=["Distance", "Latitude", "Longitude", "Description"],
            label="Text outputs",
        )
    with gr.Row():
        image_output = gr.Gallery(label="Image outputs")

    exp_txt = gr.Examples(
        examples=[[text1], [text2], [text3], [text4], [text5], [text6],
                  [text7], [text8], [text9], [text10], [text11], [text12]],
        inputs=[text_query, top_k],
        outputs=[output, image_output, timing_json],
        fn=txt_search,
        run_on_click=True,
        examples_per_page=4,
        label="Text examples",
        cache_examples='lazy',
    )
    exp_img = gr.Examples(
        examples=image_list,
        inputs=[image_query, top_k],
        outputs=[output, image_output, timing_json],
        fn=img_search,
        run_on_click=True,
        examples_per_page=4,
        label="Image examples",
        cache_examples='lazy',
    )

    def _search_in_mode(mode, text, image, k):
        """Run ``search`` for the selected search mode.

        Hiding the image box in "Text-to-Image" mode does not clear its value,
        and ``search`` uses any image it receives in preference to the text.
        So a stale image is dropped in that mode.
        """
        if mode is True:
            image = None
        return search(text, image, k)

    search_button.click(
        _search_in_mode,
        inputs=[search_index, text_query, image_query, top_k],
        outputs=[output, image_output, timing_json],
    )
    clear_button.add(components=[text_query, image_query, output, image_output, timing_json])

demo.queue()
demo.launch()