kz919 commited on
Commit
8ec51c0
1 Parent(s): ac19806

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -80
app.py CHANGED
@@ -1,96 +1,212 @@
1
  import streamlit as st
 
 
2
  import chess
3
- import chess.svg
 
 
 
 
 
4
  from gradio_client import Client
5
- from cairosvg import svg2png
6
- from PIL import Image
7
- import io
8
 
9
- # Initialize the Gradio client
10
- client = Client("xianbao/SambaNova-fast")
 
11
 
12
- # Function to get AI move
13
- def get_ai_move(board):
14
- fen = board.fen()
15
- prompt = f"You are playing as Black in a chess game. The current board position in FEN notation is: {fen}. What is your next move? Please respond with the move in UCI notation (e.g., 'e2e4')."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- result = client.predict(
18
- message=prompt,
19
- system_message="You are a chess engine based on LLaMA 405B. Provide only the next move in UCI notation.",
20
- max_tokens=1024,
21
- temperature=0.6,
22
- top_p=0.9,
23
- top_k=50,
24
- api_name="/chat"
25
- )
26
 
27
- return result.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Function to convert SVG to PNG
30
- def svg_to_png(svg_string):
31
- png_bytes = svg2png(bytestring=svg_string.encode('utf-8'))
32
- return Image.open(io.BytesIO(png_bytes))
 
 
 
 
33
 
34
- # Streamlit app
35
- st.set_page_config(layout="wide")
36
- st.title("Chess against LLaMA 405B")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Initialize session state
39
- if 'board' not in st.session_state:
40
- st.session_state.board = chess.Board()
 
 
 
 
 
41
 
42
- # Display the chessboard
43
- svg_board = chess.svg.board(board=st.session_state.board)
44
- st.image(svg_to_png(svg_board), width=400)
45
 
46
- # Input for user's move
47
- user_move = st.text_input("Enter your move (e.g., 'e2e4'):")
 
 
 
 
 
 
 
 
 
48
 
49
- if st.button("Make Move"):
50
- try:
51
- move = chess.Move.from_uci(user_move)
52
- if move in st.session_state.board.legal_moves:
53
- st.session_state.board.push(move)
54
-
55
- # Display updated board
56
- svg_board = chess.svg.board(board=st.session_state.board)
57
- st.image(svg_to_png(svg_board), width=400)
58
-
59
- if not st.session_state.board.is_game_over():
60
- with st.spinner("AI is thinking..."):
61
- ai_move = get_ai_move(st.session_state.board)
62
- ai_move_obj = chess.Move.from_uci(ai_move)
63
- st.session_state.board.push(ai_move_obj)
64
- st.write(f"AI's move: {ai_move}")
65
-
66
- # Display updated board after AI move
67
- svg_board = chess.svg.board(board=st.session_state.board)
68
- st.image(svg_to_png(svg_board), width=400)
69
-
70
- if st.session_state.board.is_game_over():
71
- st.write("Game Over!")
72
- st.write(f"Result: {st.session_state.board.result()}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  else:
74
- st.write("Invalid move. Please try again.")
75
- except ValueError:
76
- st.write("Invalid input. Please enter a move in UCI notation (e.g., 'e2e4').")
77
-
78
- # Reset button
79
- if st.button("Reset Game"):
80
- st.session_state.board = chess.Board()
81
- st.experimental_rerun()
82
 
83
- # Display game status
84
- st.write(f"Current turn: {'White' if st.session_state.board.turn else 'Black'}")
85
- st.write(f"Fullmove number: {st.session_state.board.fullmove_number}")
86
- st.write(f"Is check? {'Yes' if st.session_state.board.is_check() else 'No'}")
87
- st.write(f"Is game over? {'Yes' if st.session_state.board.is_game_over() else 'No'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Hide the default Streamlit watermark
90
- hide_streamlit_style = """
91
- <style>
92
- #MainMenu {visibility: hidden;}
93
- footer {visibility: hidden;}
94
- </style>
95
- """
96
- st.markdown(hide_streamlit_style, unsafe_allow_html=True)
 
 
 
1
  import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from streamlit.components.v1 import html
4
  import chess
5
+ import streamlit_scrollable_textbox as stx
6
+ from st_bridge import bridge
7
+ from modules.chess import Chess
8
+ from modules.utility import set_page
9
+ from modules.states import init_states
10
+ import datetime as dt
11
  from gradio_client import Client
12
+ import random
 
 
13
 
14
+ set_page(title='Chess vs LLaMA 3.1 405B', page_icon="♟️")
15
+ init_states()
16
+ st.session_state.board_width = 400
17
 
18
+ # Initialize the LLaMA 3.1 405B client
19
+ llama_client = Client("xianbao/SambaNova-fast")
20
+
21
+ # Initialize all session state variables
22
+ if 'player_color' not in st.session_state:
23
+ st.session_state.player_color = 'white'
24
+ if 'current_turn' not in st.session_state:
25
+ st.session_state.current_turn = 'white'
26
+ if 'game_started' not in st.session_state:
27
+ st.session_state.game_started = False
28
+ if 'curfen' not in st.session_state:
29
+ st.session_state.curfen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
30
+ if 'lastfen' not in st.session_state:
31
+ st.session_state.lastfen = None
32
+ if 'moves' not in st.session_state:
33
+ st.session_state.moves = {}
34
+ if 'curside' not in st.session_state:
35
+ st.session_state.curside = 'white'
36
+
37
+ def get_ai_move(fen):
38
+ board = chess.Board(fen)
39
+ legal_moves = list(board.legal_moves)
40
 
41
+ if not legal_moves:
42
+ return None
43
+
44
+ prompt = f"You are a chess engine. Given the following chess position in FEN notation: {fen}, suggest a good move. Respond with only the move in UCI notation (e.g., e2e4)."
 
 
 
 
 
45
 
46
+ for _ in range(3): # Try up to 3 times to get a valid move from the AI
47
+ try:
48
+ response = llama_client.predict(
49
+ message=prompt,
50
+ system_message="You are a chess engine assistant.",
51
+ max_tokens=10,
52
+ temperature=0.7, # Increased temperature for more varied moves
53
+ top_p=0.9,
54
+ top_k=50,
55
+ api_name="/chat"
56
+ )
57
+ move = chess.Move.from_uci(response.strip())
58
+ if move in legal_moves:
59
+ return move.uci()
60
+ except ValueError:
61
+ pass # If the AI produces an invalid move, we'll try again
62
+
63
+ # If the AI fails to produce a valid move after 3 attempts, choose a random legal move
64
+ return random.choice(legal_moves).uci()
65
 
66
+ def reset_game(player_color):
67
+ st.session_state.player_color = player_color
68
+ st.session_state.curfen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
69
+ st.session_state.moves = {}
70
+ st.session_state.current_turn = 'white'
71
+ st.session_state.game_started = True
72
+ st.session_state.lastfen = None
73
+ st.session_state.game_over = False
74
 
75
+ # If player chose black, make the first move for AI
76
+ if st.session_state.player_color == 'black':
77
+ ai_move = get_ai_move(st.session_state.curfen)
78
+ board = chess.Board(st.session_state.curfen)
79
+ if ai_move:
80
+ move = chess.Move.from_uci(ai_move)
81
+ board.push(move)
82
+ st.session_state.curfen = board.fen()
83
+ st.session_state.moves.update(
84
+ {
85
+ st.session_state.curfen : {
86
+ 'side': 'white',
87
+ 'curfen': st.session_state.curfen,
88
+ 'last_fen': "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
89
+ 'last_move': ai_move,
90
+ 'data': None,
91
+ 'timestamp': str(dt.datetime.now())
92
+ }
93
+ }
94
+ )
95
+ st.session_state.current_turn = 'black'
96
 
97
+ def check_game_end(board):
98
+ outcome = board.outcome()
99
+ if outcome:
100
+ st.session_state.game_over = True
101
+ if outcome.winner is None:
102
+ return "Draw"
103
+ return "White" if outcome.winner else "Black"
104
+ return None
105
 
106
+ st.title("Chess vs LLaMA 3.1 405B")
 
 
107
 
108
+ # Game controls
109
+ col1, col2, col3 = st.columns([1,1,1])
110
+ with col1:
111
+ player_color = st.selectbox("Choose your color", ['white', 'black'], key='color_select')
112
+ with col2:
113
+ if st.button('Start New Game', key='start_game'):
114
+ reset_game(player_color)
115
+ st.rerun()
116
+ with col3:
117
+ st.write(f"Current turn: {st.session_state.current_turn}")
118
+ st.write(f"Your color: {st.session_state.player_color}")
119
 
120
+ # Get the info from current board after the user made the move.
121
+ data = bridge("my-bridge")
122
+ if data is not None and st.session_state.game_started and not st.session_state.game_over:
123
+ st.session_state.lastfen = st.session_state.curfen
124
+ st.session_state.curfen = data['fen']
125
+ st.session_state.curside = data['move']['color'].replace('w','white').replace('b','black')
126
+ st.session_state.moves.update(
127
+ {
128
+ st.session_state.curfen : {
129
+ 'side':st.session_state.curside,
130
+ 'curfen':st.session_state.curfen,
131
+ 'last_fen':st.session_state.lastfen,
132
+ 'last_move':data['pgn'],
133
+ 'data': None,
134
+ 'timestamp': str(dt.datetime.now())
135
+ }
136
+ }
137
+ )
138
+ st.session_state.current_turn = 'white' if st.session_state.curside == 'black' else 'black'
139
+
140
+ board = chess.Board(st.session_state.curfen)
141
+ game_result = check_game_end(board)
142
+ if game_result:
143
+ st.success(f"Game Over! Winner: {game_result}")
144
+ elif st.session_state.current_turn != st.session_state.player_color:
145
+ # AI's turn
146
+ ai_move = get_ai_move(st.session_state.curfen)
147
+ if ai_move:
148
+ move = chess.Move.from_uci(ai_move)
149
+ board.push(move)
150
+ st.session_state.curfen = board.fen()
151
+ st.session_state.moves.update(
152
+ {
153
+ st.session_state.curfen : {
154
+ 'side': st.session_state.current_turn,
155
+ 'curfen': st.session_state.curfen,
156
+ 'last_fen': st.session_state.lastfen,
157
+ 'last_move': ai_move,
158
+ 'data': None,
159
+ 'timestamp': str(dt.datetime.now())
160
+ }
161
+ }
162
+ )
163
+ st.session_state.current_turn = st.session_state.player_color
164
+ game_result = check_game_end(board)
165
+ if game_result:
166
+ st.success(f"Game Over! Winner: {game_result}")
167
  else:
168
+ st.error("The AI couldn't make a move. The game may be over.")
 
 
 
 
 
 
 
169
 
170
+ # Main game display
171
+ cols = st.columns([3, 2])
172
+ with cols[0]:
173
+ if st.session_state.game_started:
174
+ puzzle = Chess(st.session_state.board_width, st.session_state.curfen)
175
+ components.html(
176
+ puzzle.puzzle_board(),
177
+ height=st.session_state.board_width + 75,
178
+ scrolling=False
179
+ )
180
+ board = chess.Board(st.session_state.curfen)
181
+
182
+ # Game status
183
+ status_col1, status_col2 = st.columns(2)
184
+ with status_col1:
185
+ st.write("Game Status:")
186
+ st.write(f"Check: {'Yes' if board.is_check() else 'No'}")
187
+ st.write(f"Checkmate: {'Yes' if board.is_checkmate() else 'No'}")
188
+ with status_col2:
189
+ st.write("\u200B") # Invisible character for alignment
190
+ st.write(f"Stalemate: {'Yes' if board.is_stalemate() else 'No'}")
191
+ st.write(f"Insufficient material: {'Yes' if board.is_insufficient_material() else 'No'}")
192
+
193
+ if st.session_state.game_over:
194
+ st.success(f"Game Over! Winner: {check_game_end(board)}")
195
+ else:
196
+ st.info("Welcome to Chess vs LLaMA 3.1 405B!")
197
+ st.write("To start a new game:")
198
+ st.write("1. Choose your color (white or black)")
199
+ st.write("2. Click 'Start New Game'")
200
+ st.write("3. Make your moves on the chess board")
201
+ st.write("Enjoy playing against the AI!")
202
 
203
+ with cols[1]:
204
+ if st.session_state.game_started:
205
+ st.subheader("Move History")
206
+ records = [
207
+ f"##### {value['timestamp'].split('.')[0]} \n {value['side']} - {value.get('last_move','')}"
208
+ for key, value in st.session_state['moves'].items()
209
+ ]
210
+ stx.scrollableTextbox('\n\n'.join(records), height = 400, border=True)
211
+ else:
212
+ st.image("https://upload.wikimedia.org/wikipedia/commons/6/6f/ChessSet.jpg", caption="Chess pieces", use_column_width=True)