File size: 5,328 Bytes
7c1eee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8f1d3f
 
 
 
7c1eee1
f8f1d3f
 
3e3ca46
f8f1d3f
 
 
7c1eee1
 
f8f1d3f
7c1eee1
f8f1d3f
7c1eee1
 
f8f1d3f
 
 
7c1eee1
 
 
f8f1d3f
7c1eee1
 
 
 
 
 
f8f1d3f
 
 
7c1eee1
f8f1d3f
 
 
 
 
 
7c1eee1
f8f1d3f
7c1eee1
 
f8f1d3f
7c1eee1
 
 
 
f8f1d3f
7c1eee1
 
f8f1d3f
7c1eee1
 
f8f1d3f
7c1eee1
f8f1d3f
 
7c1eee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import concurrent.futures 
import random
import gradio as gr
import requests
import io, base64, json
# import spaces
from PIL import Image

from .model_config import model_config
from .model_worker import BaseModelWorker

class ModelManager:
    """Registry and dispatcher for shape-generation model workers.

    Builds one ``BaseModelWorker`` per entry in ``model_config`` and exposes
    helpers to list models by capability (text-to-shape vs image-to-shape,
    online vs offline) and to run inference/rendering on one or two models,
    optionally in parallel for side-by-side (arena) comparison.
    """

    def __init__(self):
        self.models_config = model_config
        # Maps model_name -> worker. (Previously mis-annotated as a list.)
        self.models_worker: dict[str, BaseModelWorker] = {}

        self.build_model_workers()

    def build_model_workers(self):
        """Instantiate a worker for every configured model, keyed by name."""
        for cfg in self.models_config.values():
            worker = BaseModelWorker(cfg.model_name, cfg.i2s_model, cfg.online_model, cfg.model_path)
            self.models_worker[cfg.model_name] = worker

    def get_all_models(self):
        """Return the names of all configured models, in config order."""
        return list(self.models_config.keys())

    def get_t2s_models(self):
        """Return names of text-to-shape models (i.e. not image-to-shape)."""
        return [cfg.model_name for cfg in self.models_config.values() if not cfg.i2s_model]

    def get_i2s_models(self):
        """Return names of image-to-shape models."""
        return [cfg.model_name for cfg in self.models_config.values() if cfg.i2s_model]

    def get_online_models(self):
        """Return names of models served online."""
        return [cfg.model_name for cfg in self.models_config.values() if cfg.online_model]

    def get_models(self, i2s_model: bool, online_model: bool):
        """Return names of models matching both capability flags exactly."""
        return [
            cfg.model_name
            for cfg in self.models_config.values()
            if cfg.i2s_model == i2s_model and cfg.online_model == online_model
        ]

    def check_online(self, name):
        # NOTE(review): stub — always falls through and returns None, whether
        # or not the worker is an online model. Callers use
        # worker.check_online() instead; confirm whether this method is needed.
        worker = self.models_worker[name]
        if not worker.online_model:
            return

    # @spaces.GPU(duration=120)
    def inference(self,
                  prompt, model_name,
                  offline=False, offline_idx=None):
        """Run one model on *prompt*.

        When ``offline`` is true, first try a precomputed result via
        ``load_offline(offline_idx)``; fall back to live inference (if the
        worker reports itself online) when no offline result exists.
        Returns the worker's result, or None if neither path produced one.
        """
        result = None
        worker = self.models_worker[model_name]

        if offline:
            result = worker.load_offline(offline_idx)
        if not offline or result is None:  # was `== None`
            if worker.check_online():
                result = worker.inference(prompt)
        return result

    def render(self, shape, model_name):
        """Render *shape* with the named model's worker."""
        worker = self.models_worker[model_name]
        return worker.render(shape)

    def inference_parallel(self,
                           prompt, model_A, model_B,
                           offline=False, offline_idx=None):
        """Run models A and B on the same prompt concurrently.

        Returns ``(result_A, result_B)`` — results are collected in
        submission order. (Bug fix: the previous implementation appended
        results in ``as_completed`` order, so A's and B's outputs could be
        swapped nondeterministically.)
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.inference, prompt, model, offline, offline_idx)
                for model in (model_A, model_B)
            ]
            result_A, result_B = [future.result() for future in futures]
        return result_A, result_B

    def inference_parallel_anony(self,
                                 prompt, model_A, model_B,
                                 i2s_model: bool, offline: bool = False, offline_idx=None):
        """Anonymous-battle variant: sample two distinct models when both
        names are empty, then run them concurrently.

        Returns ``(result_A, result_B, model_A, model_B)`` with results in
        submission order (same ordering bug fix as ``inference_parallel``).
        """
        if model_A == model_B == "":
            if offline and i2s_model:
                model_A, model_B = random.sample(self.get_i2s_models(), 2)
            elif offline and not i2s_model:
                model_A, model_B = random.sample(self.get_t2s_models(), 2)
            else:
                model_A, model_B = random.sample(self.get_models(i2s_model=i2s_model, online_model=True), 2)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.inference, prompt, model, offline, offline_idx)
                for model in (model_A, model_B)
            ]
            result_A, result_B = [future.result() for future in futures]
        return result_A, result_B, model_A, model_B

    def render_parallel(self, shape_A, model_A, shape_B, model_B):
        """Render two shapes with their respective models concurrently.

        Returns ``(render_A, render_B)`` in argument order (ordering bug
        fixed as in ``inference_parallel``).
        """
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.render, shape, model)
                for model, shape in ((model_A, shape_A), (model_B, shape_B))
            ]
            result_A, result_B = [future.result() for future in futures]
        return result_A, result_B