Wakka2905 commited on
Commit
491ca71
1 Parent(s): bc8875e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +658 -3
app.py CHANGED
@@ -1,7 +1,662 @@
 
1
  import streamlit as st
2
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
- pipeline = pipeline(task="reinforcement-learning", model="Wakka2905/Dino-Ai-Model")
 
6
 
7
- st.title("Proyecto Final: Entrenamiento por Refuerzo del Dinosaurio de Chrome by Wakka")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  import streamlit as st
3
+ import pygame
4
+ import os
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ import pandas as pd
8
+ import numpy as np
9
+ from collections import deque
10
+ import random
11
+ from typing import List
12
+ from argparse import Action
13
+ import random
14
+ import sys
15
+ from sqlalchemy import asc
16
+ import math
17
+ import time
18
+ from tqdm import tqdm
19
+ from datetime import datetime
20
 
21
 
22
+ SCREEN_HEIGHT = 600
23
+ SCREEN_WIDTH = 1100
24
 
25
+ INIT_GAME_SPEED = 14
26
+ X_POS_BG_INIT = 0
27
+ Y_POS_BG = 380
28
+
29
+ INIT_REPLAY_MEM_SIZE = 5_000
30
+ REPLAY_MEMORY_SIZE = 45_000
31
+ MODEL_NAME = "DINO"
32
+ MIN_REPLAY_MEMORY_SIZE = 1_000
33
+ MINIBATCH_SIZE = 64
34
+ DISCOUNT = 0.95
35
+ UPDATE_TARGET_THRESH = 5
36
+ #EPSILON_INIT = 0.45 epsilon inicial
37
+ EPSILON_INIT = 0.25 #modificamos para que sea menos exploratorio, menor epsilon menos exploratorio
38
+ #EPSILON_DECAY = 0.997 epsilon inicial
39
+ EPSILON_DECAY = 0.75 #modificamos para que sea menos exploratorio, menor epsilon menos exploratorio
40
+ NUM_EPISODES = 100
41
+ MIN_EPSILON = 0.05
42
+
43
+ RUNNING = [pygame.image.load(os.path.join("Assets/Dino", "DinoRun1.png")),
44
+ pygame.image.load(os.path.join("Assets/Dino", "DinoRun2.png"))]
45
+
46
+ DUCKING = [pygame.image.load(os.path.join("Assets/Dino", "DinoDuck1.png")),
47
+ pygame.image.load(os.path.join("Assets/Dino", "DinoDuck2.png"))]
48
+
49
+
50
+ JUMPING = pygame.image.load(os.path.join("Assets/Dino", "DinoJump.png"))
51
+
52
+ SMALL_CACTUS = [pygame.image.load(os.path.join("Assets/Cactus", "SmallCactus1.png")),
53
+ pygame.image.load(os.path.join("Assets/Cactus", "SmallCactus2.png")),
54
+ pygame.image.load(os.path.join("Assets/Cactus", "SmallCactus3.png"))]
55
+
56
+
57
+ LARGE_CACTUS = [pygame.image.load(os.path.join("Assets/Cactus", "LargeCactus1.png")),
58
+ pygame.image.load(os.path.join("Assets/Cactus", "LargeCactus2.png")),
59
+ pygame.image.load(os.path.join("Assets/Cactus", "LargeCactus3.png"))]
60
+
61
+ BIRD = [pygame.image.load(os.path.join("Assets/Bird", "Bird1.png")), pygame.image.load(os.path.join("Assets/Bird", "Bird2.png"))]
62
+
63
+ CLOUD = pygame.image.load(os.path.join("Assets/Other", "Cloud.png"))
64
+
65
+ BACKGROUND = pygame.image.load(os.path.join("Assets/Other", "Track.png"))
66
+
67
+ RUNNING = [pygame.image.load(os.path.join("Assets/Dino", "DinoRun1.png")),
68
+ pygame.image.load(os.path.join("Assets/Dino", "DinoRun2.png"))]
69
+
70
+ DUCKING = [pygame.image.load(os.path.join("Assets/Dino", "DinoDuck1.png")),
71
+ pygame.image.load(os.path.join("Assets/Dino", "DinoDuck2.png"))]
72
+
73
+
74
+ JUMPING = pygame.image.load(os.path.join("Assets/Dino", "DinoJump.png"))
75
+
76
+ SMALL_CACTUS = [pygame.image.load(os.path.join("Assets/Cactus", "SmallCactus1.png")),
77
+ pygame.image.load(os.path.join("Assets/Cactus", "SmallCactus2.png")),
78
+ pygame.image.load(os.path.join("Assets/Cactus", "SmallCactus3.png"))]
79
+
80
+
81
+ LARGE_CACTUS = [pygame.image.load(os.path.join("Assets/Cactus", "LargeCactus1.png")),
82
+ pygame.image.load(os.path.join("Assets/Cactus", "LargeCactus2.png")),
83
+ pygame.image.load(os.path.join("Assets/Cactus", "LargeCactus3.png"))]
84
+
85
+ BIRD = [pygame.image.load(os.path.join("Assets/Bird", "Bird1.png")), pygame.image.load(os.path.join("Assets/Bird", "Bird2.png"))]
86
+
87
+ CLOUD = pygame.image.load(os.path.join("Assets/Other", "Cloud.png"))
88
+
89
+ BACKGROUND = pygame.image.load(os.path.join("Assets/Other", "Track.png"))
90
+
91
+ class NeuralNetwork(nn.Module):
92
+ def __init__(self):
93
+ super(NeuralNetwork, self).__init__()
94
+ self.fc1 = nn.Linear(7, 4) # 7 input features, 4 output features
95
+ self.fc2 = nn.Linear(4, 3) # 4 input features, 3 output features
96
+
97
+ def forward(self, x):
98
+ x = torch.relu(self.fc1(x))
99
+ x = self.fc2(x)
100
+ return x
101
+
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #Para poder usar GPU
103
+
104
+ class DQNAgent:
105
+ def __init__(self):
106
+ self.model = NeuralNetwork().to(device) # Mover el modelo a la GPU si está disponible
107
+ self.target_model = NeuralNetwork().to(device) # Mover el modelo a la GPU si está disponible
108
+ self.target_model.load_state_dict(self.model.state_dict())
109
+ self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
110
+ self.loss_function = nn.MSELoss()
111
+
112
+ self.init_replay_memory = deque(maxlen=INIT_REPLAY_MEM_SIZE)
113
+ self.late_replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)
114
+ self.target_update_counter = 0
115
+ # Update the memory store
116
+ def update_replay_memory(self, transition):
117
+ # if len(self.replay_memory) > 50_000:
118
+ # self.replay_memory.clear()
119
+ if len(self.init_replay_memory) < INIT_REPLAY_MEM_SIZE:
120
+ self.init_replay_memory.append(transition)
121
+ else:
122
+ self.late_replay_memory.append(transition)
123
+
124
+ # Método get_qs dentro de la clase DQNAgent
125
+ def get_qs(self, state):
126
+ state_tensor = torch.Tensor(state).to(device) # Asegúrate de mover el tensor al dispositivo correcto
127
+ with torch.no_grad():
128
+ return self.model(state_tensor).cpu().numpy() # Luego mueve el resultado de vuelta a la CPU si es necesario
129
+
130
+ def train(self, terminal_state, step):
131
+ if len(self.init_replay_memory) < MIN_REPLAY_MEMORY_SIZE:
132
+ return
133
+
134
+ total_mem = list(self.init_replay_memory)
135
+ total_mem.extend(self.late_replay_memory)
136
+ minibatch = random.sample(total_mem, MINIBATCH_SIZE)
137
+
138
+ # Asegurarse de que los tensores estén en el dispositivo correcto
139
+ current_states = torch.Tensor([transition[0] for transition in minibatch]).to(device)
140
+ current_qs_list = self.model(current_states)
141
+ new_current_states = torch.Tensor([transition[3] for transition in minibatch]).to(device)
142
+ future_qs_list = self.target_model(new_current_states)
143
+
144
+ X = []
145
+ y = []
146
+
147
+ for index, (current_state, action, reward, new_current_state, done) in enumerate(minibatch):
148
+ if not done:
149
+ max_future_q = torch.max(future_qs_list[index])
150
+ new_q = reward + DISCOUNT * max_future_q
151
+ else:
152
+ new_q = reward
153
+
154
+ current_qs = current_qs_list[index]
155
+ current_qs[action] = new_q
156
+
157
+ X.append(current_state)
158
+ y.append(current_qs)
159
+
160
+ X = torch.tensor(np.array(X, dtype=np.float32)).to(device) # Mover X a la GPU
161
+ y = torch.tensor(np.array([y_item.detach().cpu().numpy() if isinstance(y_item, torch.Tensor) else y_item for y_item in y], dtype=np.float32)).to(device) # Mover y a la GPU
162
+
163
+ self.optimizer.zero_grad()
164
+ output = self.model(X) # X ya está en el dispositivo correcto
165
+ loss = self.loss_function(output, y) # y ya está en el dispositivo correcto
166
+ loss.backward()
167
+ self.optimizer.step()
168
+
169
+ if terminal_state:
170
+ self.target_update_counter += 1
171
+
172
+ if self.target_update_counter > UPDATE_TARGET_THRESH:
173
+ self.target_model.load_state_dict(self.model.state_dict())
174
+ self.target_update_counter = 0
175
+ # print(self.target_update_counter)
176
+
177
+ class Obstacle:
178
+ def __init__(self, image: List[pygame.Surface], type: int) -> None:
179
+ self.image = image
180
+ self.type = type
181
+ self.rect = self.image[self.type].get_rect()
182
+ self.rect.x = SCREEN_WIDTH
183
+
184
+ def update(self, obstacles: list, game_speed: int):
185
+ self.rect.x -= game_speed
186
+ if self.rect.x < -self.rect.width:
187
+ obstacles.pop()
188
+
189
+ def draw(self, SCREEN: pygame.Surface):
190
+ SCREEN.blit(self.image[self.type], self.rect)
191
+
192
+ class Dino(DQNAgent):
193
+ X_POS = 80
194
+ Y_POS = 310
195
+ Y_DUCK_POS = 340
196
+ JUMP_VEL = 8.5
197
+ #code here
198
+ def __init__(self) -> None:
199
+ #Initializing the images for the dino
200
+ self.duck_img = DUCKING
201
+ self.run_img = RUNNING
202
+ self.jump_img = JUMPING
203
+
204
+
205
+ #Initially the dino starts running
206
+ self.dino_duck = False
207
+ self.dino_run = True
208
+ self.dino_jump = False
209
+
210
+ self.step_index = 0
211
+ self.jump_vel = self.JUMP_VEL
212
+ self.image = self.run_img[0]
213
+ self.dino_rect = self.image.get_rect()
214
+
215
+ self.dino_rect.x = self.X_POS
216
+ self.dino_rect.y = self.Y_POS
217
+
218
+ self.score = 0
219
+
220
+ super().__init__()
221
+
222
+
223
+ # Update the Dino's state
224
+ def update(self, move: pygame.key.ScancodeWrapper):
225
+ if self.dino_duck:
226
+ self.duck()
227
+
228
+ if self.dino_jump:
229
+ self.jump()
230
+
231
+ if self.dino_run:
232
+ self.run()
233
+
234
+ if self.step_index >= 20:
235
+ self.step_index = 0
236
+
237
+
238
+ if move[pygame.K_UP] and not self.dino_jump:
239
+ self.dino_jump = True
240
+ self.dino_run = False
241
+ self.dino_duck = False
242
+
243
+ elif move[pygame.K_DOWN] and not self.dino_jump:
244
+ self.dino_duck = True
245
+ self.dino_run = False
246
+ self.dino_jump = False
247
+
248
+ elif not(self.dino_jump or move[pygame.K_DOWN]):
249
+ self.dino_run = True
250
+ self.dino_jump = False
251
+ self.dino_duck = False
252
+
253
+ def update_auto(self, move):
254
+ if self.dino_duck == True:
255
+ self.duck()
256
+
257
+ if self.dino_jump == True:
258
+ self.jump()
259
+
260
+ if self.dino_run == True:
261
+ self.run()
262
+
263
+ if self.step_index >= 20:
264
+ self.step_index = 0
265
+
266
+ if move == 0 and not self.dino_jump:
267
+ self.dino_jump = True
268
+ self.dino_run = False
269
+ self.dino_duck = False
270
+
271
+ elif move == 1 and not self.dino_jump:
272
+ self.dino_duck = True
273
+ self.dino_run = False
274
+ self.dino_jump = False
275
+
276
+ elif not(self.dino_jump or move == 1):
277
+ self.dino_run = True
278
+ self.dino_jump = False
279
+ self.dino_duck = False
280
+
281
+ def duck(self) -> None:
282
+ self.image = self.duck_img[self.step_index // 10]
283
+ self.dino_rect = self.image.get_rect()
284
+ self.dino_rect.x = self.X_POS
285
+ self.dino_rect.y = self.Y_DUCK_POS
286
+ self.step_index += 1
287
+
288
+ def run(self) -> None:
289
+ self.image = self.run_img[self.step_index // 10]
290
+ self.dino_rect = self.image.get_rect()
291
+ self.dino_rect.x = self.X_POS
292
+ self.dino_rect.y = self.Y_POS
293
+ self.step_index += 1
294
+
295
+
296
+ def jump(self) -> None:
297
+ self.image = self.jump_img
298
+ if self.dino_jump:
299
+ self.dino_rect.y -= self.jump_vel * 3
300
+ self.jump_vel -= 0.6
301
+
302
+ if self.jump_vel < -self.JUMP_VEL:
303
+ self.dino_jump = False
304
+ self.dino_run = True
305
+ self.jump_vel = self.JUMP_VEL
306
+
307
+ def draw(self, SCREEN: pygame.Surface):
308
+ SCREEN.blit(self.image, (self.dino_rect.x, self.dino_rect.y))
309
+
310
+ class LargeCactus(Obstacle):
311
+ def __init__(self, image: List[pygame.Surface]) -> None:
312
+ self.type = random.randint(0, 2)
313
+ super().__init__(image, self.type)
314
+ self.rect.y = 300
315
+
316
+
317
+ class SmallCactus(Obstacle):
318
+ def __init__(self, image: List[pygame.Surface]) -> None:
319
+ self.type = random.randint(0, 2)
320
+ super().__init__(image, self.type)
321
+ self.rect.y = 325
322
+
323
+ class Bird(Obstacle):
324
+ def __init__(self, image: List[pygame.Surface]) -> None:
325
+ self.type = 0
326
+ super().__init__(image, self.type)
327
+ self.rect.y = SCREEN_HEIGHT - 340
328
+ self.index = 0
329
+
330
+ def draw(self, SCREEN: pygame.Surface):
331
+ if self.index >= 19:
332
+ self.index = 0
333
+
334
+ SCREEN.blit(self.image[self.index // 10], self.rect)
335
+ self.index += 1
336
+
337
+ class Cloud:
338
+ def __init__(self) -> None:
339
+ self.x = SCREEN_WIDTH + random.randint(800, 1000)
340
+ self.y = random.randint(50, 100)
341
+ self.image = CLOUD
342
+ self.width = self.image.get_width()
343
+
344
+ def update(self, game_speed: int):
345
+ self.x -= game_speed
346
+ if self.x < -self.width:
347
+ self.x = SCREEN_WIDTH + random.randint(800, 1000)
348
+ self.y = random.randint(50, 100)
349
+
350
+
351
+ def draw(self, SCREEN: pygame.Surface):
352
+ SCREEN.blit(self.image, (self.x, self.y))
353
+
354
+ class Game:
355
+ def __init__(self, epsilon, load_model=False, model_path=None):
356
+ pygame.init()
357
+ self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
358
+
359
+ self.obstacles = []
360
+
361
+ self.run = True
362
+
363
+ self.clock = pygame.time.Clock()
364
+
365
+ self.cloud = Cloud()
366
+
367
+ self.game_speed = INIT_GAME_SPEED
368
+
369
+ self.font = pygame.font.Font("freesansbold.ttf", 20)
370
+
371
+ self.dino = Dino()
372
+
373
+ # Cargar el modelo si se solicita
374
+ if load_model and model_path:
375
+ self.dino.model.load_state_dict(torch.load(model_path, map_location=device))
376
+
377
+ self.x_pos_bg = X_POS_BG_INIT
378
+
379
+ self.points = 0
380
+
381
+ self.epsilon = epsilon
382
+
383
+ self.ep_rewards = [-200]
384
+
385
+ self.high_score = 0 # Inicializa el high score con 0 o carga el high score existente de un archivo si lo prefieres
386
+
387
+ self.best_score = 0
388
+
389
+ def reset(self):
390
+ self.game_speed = INIT_GAME_SPEED
391
+ old_dino = self.dino
392
+ self.dino = Dino()
393
+ self.dino.init_replay_memory = old_dino.init_replay_memory
394
+ self.dino.late_replay_memory = old_dino.late_replay_memory
395
+ self.dino.target_update_counter = old_dino.target_update_counter
396
+
397
+ self.dino.model.load_state_dict(old_dino.model.state_dict())
398
+ self.dino.target_model.load_state_dict(old_dino.target_model.state_dict())
399
+
400
+ self.x_pos_bg = X_POS_BG_INIT
401
+ self.points = 0
402
+ self.SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
403
+ self.clock = pygame.time.Clock()
404
+
405
+ def get_dist(self, pos_a: tuple, pos_b:tuple):
406
+ dx = pos_a[0] - pos_b[0]
407
+ dy = pos_a[1] - pos_b[1]
408
+
409
+ return math.sqrt(dx**2 + dy**2)
410
+
411
+ def update_background(self):
412
+ image_width = BACKGROUND.get_width()
413
+
414
+ self.SCREEN.blit(BACKGROUND, (self.x_pos_bg, Y_POS_BG))
415
+ self.SCREEN.blit(BACKGROUND, (self.x_pos_bg + image_width, Y_POS_BG))
416
+
417
+ if self.x_pos_bg <= -image_width:
418
+ self.SCREEN.blit(BACKGROUND, (self.x_pos_bg + image_width, Y_POS_BG))
419
+ self.x_pos_bg = 0
420
+
421
+ self.x_pos_bg -= self.game_speed
422
+ return self.x_pos_bg
423
+
424
+ def get_state(self):
425
+ state = []
426
+ state.append(self.dino.dino_rect.y / self.dino.Y_DUCK_POS + 10)
427
+ pos_a = (self.dino.dino_rect.x, self.dino.dino_rect.y)
428
+ bird = 0
429
+ cactus = 0
430
+ if len(self.obstacles) == 0:
431
+ dist = self.get_dist(pos_a, tuple([SCREEN_WIDTH + 10, self.dino.Y_POS])) / math.sqrt(SCREEN_HEIGHT**2 + SCREEN_WIDTH**2)
432
+ obs_height = 0
433
+ obj_width = 0
434
+ else:
435
+ dist = self.get_dist(pos_a, (self.obstacles[0].rect.midtop)) / math.sqrt(SCREEN_HEIGHT**2 + SCREEN_WIDTH**2)
436
+ obs_height = self.obstacles[0].rect.midtop[1] / self.dino.Y_DUCK_POS
437
+ obj_width = self.obstacles[0].rect.width / SMALL_CACTUS[2].get_rect().width
438
+ if self.obstacles[0].__class__ == SmallCactus(SMALL_CACTUS).__class__ or \
439
+ self.obstacles[0].__class__ == LargeCactus(LARGE_CACTUS).__class__:
440
+ cactus = 1
441
+ else:
442
+ bird = 1
443
+
444
+ state.append(dist)
445
+ state.append(obs_height)
446
+ state.append(self.game_speed / 24)
447
+ state.append(obj_width)
448
+ state.append(cactus)
449
+ state.append(bird)
450
+
451
+ return state
452
+
453
+
454
+ def update_score(self):
455
+ self.points += 1
456
+ if self.points % 200 == 0:
457
+ self.game_speed += 1
458
+
459
+ if self.points > self.high_score:
460
+ self.high_score = self.points
461
+
462
+ text = self.font.render(f"Points: {self.points} Highscore: {self.high_score}", True, (0, 0, 0))
463
+ textRect = text.get_rect()
464
+ textRect.center = (SCREEN_WIDTH - textRect.width // 2 - 10, 40)
465
+ self.SCREEN.blit(text, textRect)
466
+
467
+
468
+ def create_obstacle(self):
469
+ # bird_prob = random.randint(0, 15)
470
+ # cactus_prob = random.randint(0, 10)
471
+ # if bird_prob == 0:
472
+ # self.obstacles.append(Bird(BIRD))
473
+ # elif cactus_prob == 0:
474
+ # self.obstacles.append(SmallCactus(SMALL_CACTUS))
475
+ # elif cactus_prob == 1:
476
+ # self.obstacles.append(LargeCactus(LARGE_CACTUS))
477
+
478
+ obstacle_prob = random.randint(0, 50)
479
+ if obstacle_prob == 0:
480
+ self.obstacles.append(SmallCactus(SMALL_CACTUS))
481
+ elif obstacle_prob == 1:
482
+ self.obstacles.append(LargeCactus(LARGE_CACTUS))
483
+ elif obstacle_prob == 2 and self.points > 300:
484
+ self.obstacles.append(Bird(BIRD))
485
+
486
+ def update_game(self, moves, user_input=None):
487
+ self.dino.draw(self.SCREEN)
488
+ if user_input is not None:
489
+ self.dino.update(user_input)
490
+ else:
491
+ self.dino.update_auto(moves)
492
+
493
+ self.update_background()
494
+
495
+ self.cloud.draw(self.SCREEN)
496
+
497
+ self.cloud.update(self.game_speed)
498
+
499
+ self.update_score()
500
+
501
+ self.clock.tick(30)
502
+
503
+ # pygame.display.update()
504
+
505
+ def play_manual(self):
506
+
507
+ while self.run is True:
508
+ for event in pygame.event.get():
509
+ if event.type == pygame.QUIT:
510
+ sys.exit()
511
+
512
+ self.SCREEN.fill((255, 255, 255))
513
+ user_input = pygame.key.get_pressed()
514
+ # moves = []
515
+
516
+ if len(self.obstacles) == 0:
517
+ self.create_obstacle()
518
+
519
+ for obstacle in self.obstacles:
520
+ obstacle.draw(SCREEN=self.SCREEN)
521
+ obstacle.update(self.obstacles, self.game_speed)
522
+ if self.dino.dino_rect.colliderect(obstacle.rect):
523
+ self.dino.score = self.points
524
+ pygame.quit()
525
+ self.obstacles.pop()
526
+ print("Game over!")
527
+ return
528
+
529
+ self.update_game(user_input=user_input, moves=2)
530
+ pygame.display.update()
531
+
532
+
533
+ def play_auto(self):
534
+ try:
535
+ points_label = 0
536
+ for episode in tqdm(range(1, NUM_EPISODES + 1), ascii=True, unit='episodes'):
537
+ episode_reward = 0
538
+ step = 1
539
+ current_state = self.get_state()
540
+ self.run = True
541
+ while self.run is True:
542
+
543
+ for event in pygame.event.get():
544
+ if event.type == pygame.QUIT:
545
+ sys.exit()
546
+
547
+ self.SCREEN.fill((255, 255, 255))
548
+
549
+ if len(self.obstacles) == 0:
550
+ self.create_obstacle()
551
+
552
+ # if self.run == False:
553
+ # print(current_state)
554
+ # time.sleep(2)
555
+ # continue
556
+
557
+ if np.random.random() > self.epsilon:
558
+ action = self.dino.get_qs(torch.Tensor(current_state))
559
+ # print(action)
560
+ action = np.argmax(action)
561
+ # print(action)
562
+ else:
563
+ num = np.random.randint(0, 10)
564
+ if num == 0:
565
+ # print("yes")
566
+ action = num
567
+ elif num <= 3:
568
+ action = 1
569
+ else:
570
+ action = 2
571
+
572
+ self.update_game(moves=action)
573
+ # print(self.game_speed)
574
+ next_state = self.get_state()
575
+ reward = 0
576
+
577
+ for obstacle in self.obstacles:
578
+ obstacle.draw(SCREEN=self.SCREEN)
579
+ obstacle.update(self.obstacles, self.game_speed)
580
+ next_state = self.get_state()
581
+ if self.dino.dino_rect.x > obstacle.rect.x + obstacle.rect.width:
582
+ reward = 3
583
+
584
+ if action == 0 and obstacle.rect.x > SCREEN_WIDTH // 2:
585
+ reward = -1
586
+
587
+ if self.dino.dino_rect.colliderect(obstacle.rect):
588
+ self.dino.score = self.points
589
+ # pygame.quit()
590
+ self.obstacles.pop()
591
+ points_label = self.points
592
+ self.reset()
593
+ reward = -10
594
+ # print("Game over!")
595
+ self.run = False
596
+ break
597
+ # if reward != 0:
598
+ # print(reward > 0)
599
+
600
+ episode_reward += reward
601
+
602
+ self.dino.update_replay_memory(tuple([current_state, action, reward, next_state, self.run]))
603
+
604
+ self.dino.train( not self.run, step=step)
605
+
606
+ current_state = next_state
607
+
608
+ step += 1
609
+
610
+ # self.clock.tick(60)
611
+
612
+ #print(self.points)
613
+ #print(self.high_score)
614
+
615
+ # Al final de cada episodio, verifica si hay un nuevo mejor puntaje
616
+ if self.points > self.best_score:
617
+ self.best_score = self.points
618
+ # Este archivo se sobrescribirá con el último mejor modelo
619
+ self.best_model_filename = 'models/highscore/BestScore_model.pth'
620
+ torch.save(self.dino.model.state_dict(), self.best_model_filename)
621
+
622
+ pygame.display.update()
623
+
624
+
625
+ self.ep_rewards.append(episode_reward)
626
+
627
+ # Obtenemos la fecha y hora actual
628
+ current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
629
+
630
+ # Guardar el modelo cada 50 escenarios
631
+ if episode % 50 == 0:
632
+ filename = f'models/episodes/{points_label}_Points,Episode_{episode}_Date_{current_time}_model.pth'
633
+ torch.save(self.dino.model.state_dict(), filename)
634
+
635
+
636
+ if self.epsilon > MIN_EPSILON:
637
+ self.epsilon *= EPSILON_DECAY
638
+ if self.epsilon < MIN_EPSILON:
639
+ self.epsilon = 0
640
+ # print(self.epsilon)
641
+ else:
642
+ self.epsilon = max(MIN_EPSILON, self.epsilon)
643
+ # print(self.epsilon)
644
+ # print((self.dino.replay_memory))
645
+ finally:
646
+ # Este bloque se ejecutará incluso si se interrumpe el juego.
647
+ # Aquí duplicas el archivo del mejor puntaje alcanzado hasta ahora.
648
+ if hasattr(self, 'best_model_filename'):
649
+ current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
650
+ final_model_filename = f'models/highscore/{self.best_score}_BestScore_Final_{current_time}_model.pth'
651
+ import shutil
652
+ shutil.copy(self.best_model_filename, final_model_filename)
653
+ print(f"Modelo duplicado guardado como: {final_model_filename}")
654
+
655
+
656
+ # Streamlit UI
657
+ st.title('Juego del Dinosaurio con IA')
658
+
659
+ if st.button('Iniciar Juego con IA'):
660
+ model_path = 'models/highscore/4245_BestScore_Final_2023-12-10_18-43-53_model.pth' # Reemplaza con la ruta al modelo que quieras cargar
661
+ game = Game(EPSILON_INIT, load_model=True, model_path=model_path)
662
+ game.play_auto()