swap1411 committed on
Commit
d74ed39
·
verified ·
1 Parent(s): 5728a32

Upload 7 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,37 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ stockfish-windows-x86-64-avx2.exe filter=lfs diff=lfs merge=lfs -text
37
+ stockfish-ubuntu-x86-64-sse41-popcnt filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,12 @@
1
- ---
2
- title: ChessVision AI
3
- emoji:
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: streamlit
7
- sdk_version: 1.41.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Outplay your chess opponent-------Snap and get the best move
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: ChessVisionAi
3
+ emoji: 📉
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.41.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import cv2
3
+ from stockfish import Stockfish
4
+ import os
5
+ import numpy as np
6
+ import streamlit as st
7
+ import requests
8
+
9
+
10
# Constants
# Maps YOLO detection class names to single-letter FEN piece symbols
# (lowercase = black, uppercase = white).
FEN_MAPPING = {
    "black-pawn": "p", "black-rook": "r", "black-knight": "n", "black-bishop": "b", "black-queen": "q", "black-king": "k",
    "white-pawn": "P", "white-rook": "R", "white-knight": "N", "white-bishop": "B", "white-queen": "Q", "white-king": "K"
}
# Geometry of the 224x224 board crop produced in main():
# NOTE(review): these four constants are re-declared as locals inside
# get_grid_coordinate(); keep the two sets in sync if either changes.
GRID_BORDER = 10  # Border size in pixels
GRID_SIZE = 204  # Effective grid size (10px to 214px)
BLOCK_SIZE = GRID_SIZE // 8  # Each block is ~25px
X_LABELS = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']  # Labels for x-axis (a to h)
Y_LABELS = [8, 7, 6, 5, 4, 3, 2, 1]  # Reversed labels for y-axis (8 to 1)
20
+
21
+ # Functions
22
def get_grid_coordinate(pixel_x, pixel_y):
    """
    Map a pixel position on the 224x224 board crop to a chess square label.

    The crop has a 10px border on every side; the playable area spans
    pixels 10..213 and is divided into an 8x8 grid of ~25px squares.

    Args:
        pixel_x: Horizontal pixel coordinate (origin at the left edge).
        pixel_y: Vertical pixel coordinate (the caller flips the image
            Y-axis before calling, so larger values mean higher ranks).

    Returns:
        A square label such as "e4", or the literal string
        "Pixel outside grid bounds" when the pixel falls on the border
        or outside the playable area.
    """
    border = 10                   # width of the non-playable frame
    grid_size = 204               # playable area: pixels 10..213
    block_size = grid_size // 8   # ~25px per square

    # Shift into grid-local coordinates (0 .. grid_size-1).
    adjusted_x = pixel_x - border
    adjusted_y = pixel_y - border

    if adjusted_x < 0 or adjusted_y < 0 or adjusted_x >= grid_size or adjusted_y >= grid_size:
        return "Pixel outside grid bounds"

    # Column/row indices. Because grid_size (204) is not an exact multiple
    # of block_size (25), the last few pixels (200..203) yield index 8;
    # reject those explicitly, exactly as the original bound check did.
    x_index = adjusted_x // block_size
    y_index = adjusted_y // block_size
    if x_index >= 8 or y_index >= 8:
        return "Pixel outside grid bounds"

    # The original computed the rank via a reversed-label lookup followed by
    # a second mirror (8 - y_labels[y_index] + 1), which algebraically
    # reduces to y_index + 1. It also recomputed x_index/y_index twice;
    # both redundancies are removed here with identical results.
    x_label = "abcdefgh"[x_index]
    y_label = y_index + 1

    return f"{x_label}{y_label}"
60
+
61
def predict_next_move(fen, stockfish):
    """
    Ask Stockfish for the best move in the given position.

    Args:
        fen: FEN string for the position. The board here is mirrored
            relative to standard orientation, so the engine's move is
            mapped back through transform_string().
        stockfish: A configured ``stockfish.Stockfish`` engine wrapper.

    Returns:
        A human-readable message with the predicted move, or an error
        message for an invalid FEN / terminal (checkmate/stalemate)
        position.
    """
    if not stockfish.is_fen_valid(fen):
        return "Invalid FEN notation!"
    stockfish.set_fen_position(fen)

    best_move = stockfish.get_best_move()
    # BUG FIX: get_best_move() returns None on checkmate/stalemate. The
    # original code passed that None straight into transform_string(),
    # which crashed on None.strip() before the terminal-position message
    # could ever be returned. Check for None first.
    if not best_move:
        return "No valid move found (checkmate/stalemate)."
    return f"The predicted next move is: {transform_string(best_move)}"
73
+
74
+
75
+
76
+ # def download_stockfish():
77
+ # url = "https://drive.google.com/file/d/18pkwBVc13fgKP3LzrTHE4yzhjyGJexlR/view?usp=sharing" # Replace with the actual link
78
+ # file_name = "stockfish-windows-x86-64-avx2.exe"
79
+
80
+ # if not os.path.exists(file_name):
81
+ # print(f"Downloading {file_name}...")
82
+ # response = requests.get(url, stream=True)
83
+ # with open(file_name, "wb") as file:
84
+ # for chunk in response.iter_content(chunk_size=1024):
85
+ # if chunk:
86
+ # file.write(chunk)
87
+ # print(f"{file_name} downloaded successfully.")
88
+
89
def process_image(image_path):
    """
    Segment the chessboard out of a photo and return the cropped board.

    Runs the YOLO segmentation model on the image, takes the first result
    whose top detection confidence is >= 0.8, crops the image to that
    detection's bounding box, saves the crop to output/cropped_segment.jpg,
    and displays it in the Streamlit UI.

    Args:
        image_path: Filesystem path of the photo to process.

    Returns:
        The cropped board image as a numpy array (BGR, as read by
        cv2.imread), or None when no confident detection was found.
    """
    # Ensure output directory exists.
    os.makedirs('output', exist_ok=True)

    # Load the segmentation model and run inference.
    segmentation_model = YOLO("segmentation.pt")
    results = segmentation_model.predict(
        source=image_path,
        conf=0.8  # Confidence threshold
    )

    segmentation_mask = None
    bbox = None
    for result in results:
        # BUG FIX: guard against results with zero boxes — the original
        # indexed result.boxes.conf[0] unconditionally, which raises
        # IndexError on an empty detection list.
        if len(result.boxes) > 0 and result.boxes.conf[0] >= 0.8:
            segmentation_mask = result.masks.data.cpu().numpy().astype(np.uint8)[0]
            bbox = result.boxes.xyxy[0].cpu().numpy()  # Bounding box coords
            break

    # BUG FIX: also treat a missing bbox as failure — the original fell
    # through and implicitly returned None without any message. (The dead
    # segmentation_mask_resized computation was removed; its result was
    # never used.)
    if segmentation_mask is None or bbox is None:
        print("No segmentation mask with confidence above 0.8 found.")
        return None

    # Crop the segmented region based on the bounding box.
    image = cv2.imread(image_path)
    x1, y1, x2, y2 = bbox
    cropped_segment = image[int(y1):int(y2), int(x1):int(x2)]

    # Save and display the cropped segmented image.
    cropped_image_path = 'output/cropped_segment.jpg'
    cv2.imwrite(cropped_image_path, cropped_segment)
    print(f"Cropped segmented image saved to {cropped_image_path}")
    st.image(cropped_segment, caption="Uploaded Image", use_column_width=True)

    return cropped_segment
137
+
138
def transform_string(input_str):
    """
    Mirror a 4-character UCI move (e.g. 'e2e4') across the board center.

    Files are mirrored a<->h, b<->g, ... and ranks 1<->8, 2<->7, ...,
    so 'e2e4' becomes 'd7d5'. Input is trimmed and lowercased first.

    Args:
        input_str: Candidate move string.

    Returns:
        The mirrored move, or the literal string "Invalid input" when the
        input is not letter-digit-letter-digit of length 4 or contains a
        file/rank outside a-h / 1-8.
    """
    move = input_str.strip().lower()

    # Shape check: exactly 4 chars, alternating letter/digit.
    well_formed = (
        len(move) == 4
        and move[0].isalpha() and move[1].isdigit()
        and move[2].isalpha() and move[3].isdigit()
    )
    if not well_formed:
        return "Invalid input"

    files = "abcdefgh"
    ranks = "12345678"

    # Mirror each character by position: even indices are files, odd are
    # ranks. A character outside the legal range marks the move invalid.
    mirrored = []
    for pos, ch in enumerate(move):
        axis = files if pos % 2 == 0 else ranks
        idx = axis.find(ch)
        mirrored.append(axis[7 - idx] if idx != -1 else "Invalid")

    result = "".join(mirrored)
    return "Invalid input" if "Invalid" in result else result
170
+
171
+
172
+
173
# Streamlit app
def main():
    """Streamlit entry point: capture/upload a board photo, detect pieces,
    build a FEN string, and show Stockfish's recommended move."""
    # download_stockfish()
    st.title("Chessboard Position Detection and Move Prediction")

    # Make the bundled Stockfish binary executable.
    # NOTE(review): hard-coded HF Spaces path — this raises FileNotFoundError
    # anywhere else; confirm it matches the deployment layout.
    os.chmod("/home/user/app/stockfish-ubuntu-x86-64-sse41-popcnt", 0o755)


    st.write(os.getcwd())


    # User uploads an image or captures it from their camera
    image_file = st.camera_input("Capture a chessboard image") or st.file_uploader("Upload a chessboard image", type=["jpg", "jpeg", "png"])

    if image_file is not None:
        # Save the image to a temporary file so process_image can read a path.
        temp_dir = "temp_images"
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, "uploaded_image.jpg")
        with open(temp_file_path, "wb") as f:
            f.write(image_file.getbuffer())

        # Segment and crop the board out of the photo (None on failure).
        processed_image = process_image(temp_file_path)

        if processed_image is not None:
            # Resize to the fixed 224x224 geometry that get_grid_coordinate
            # assumes (10px border + 204px playable grid).
            processed_image = cv2.resize(processed_image, (224, 224))
            height, width, _ = processed_image.shape  # width is currently unused

            # Initialize the piece-detection YOLO model
            model = YOLO("standard.pt")  # Replace with your trained model weights file

            # Run detection
            results = model.predict(source=processed_image, save=False, save_txt=False, conf=0.6)

            # Board for FEN construction; "8" marks an empty square here and
            # is collapsed into run-length counts below.
            board = [["8"] * 8 for _ in range(8)]

            # Map each detected piece's box center to a board square.
            for result in results[0].boxes:
                x1, y1, x2, y2 = result.xyxy[0].tolist()
                class_id = int(result.cls[0])
                class_name = model.names[class_id]

                # Convert class_name to FEN notation; skip unknown classes.
                fen_piece = FEN_MAPPING.get(class_name, None)
                if not fen_piece:
                    continue

                # Calculate the center of the bounding box
                center_x = (x1 + x2) / 2
                center_y = (y1 + y2) / 2

                # Convert to integer pixel coordinates
                pixel_x = int(center_x)
                pixel_y = int(height - center_y)  # Flip Y-axis for generic coordinate system

                # Get grid coordinate (e.g. "e4") or an out-of-bounds marker.
                grid_position = get_grid_coordinate(pixel_x, pixel_y)

                if grid_position != "Pixel outside grid bounds":
                    file = ord(grid_position[0]) - ord('a')  # Column index (0-7)
                    rank = int(grid_position[1]) - 1  # Row index (0-7)

                    # Place the piece on the board (flip rank index for FEN).
                    board[7 - rank][file] = fen_piece

            # Generate the FEN piece-placement field with run-length encoded
            # empty squares, one row at a time.
            fen_rows = []
            for row in board:
                fen_row = ""
                empty_count = 0
                for cell in row:
                    if cell == "8":
                        empty_count += 1
                    else:
                        if empty_count > 0:
                            fen_row += str(empty_count)
                            empty_count = 0
                        fen_row += cell
                if empty_count > 0:
                    fen_row += str(empty_count)
                fen_rows.append(fen_row)

            position_fen = "/".join(fen_rows)

            # Ask the user for the next move side
            move_side = st.selectbox("Select the side to move:", ["w (White)", "b (Black)"])
            move_side = "w" if move_side.startswith("w") else "b"

            # Append the remaining FEN fields: no castling, no en passant.
            # NOTE(review): standard FEN uses fullmove number >= 1 ("0 1");
            # "0 0" is accepted here only if Stockfish's is_fen_valid allows
            # it — confirm.
            fen_notation = f"{position_fen} {move_side} - - 0 0"

            st.subheader("Generated FEN Notation:")
            st.code(fen_notation)

            # Initialize the Stockfish engine from the bundled binary.
            stockfish_path = os.path.join(os.getcwd(), "stockfish-ubuntu-x86-64-sse41-popcnt")
            stockfish = Stockfish(
                path=stockfish_path,
                depth=15,
                parameters={"Threads": 2, "Minimum Thinking Time": 30}
            )

            # Predict and display the next move.
            next_move = predict_next_move(fen_notation, stockfish)
            st.subheader("Stockfish Recommended Move:")
            st.write(next_move)

        else:
            st.error("Failed to process the image. Please try again.")

if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ultralytics==8.3.40
2
+ stockfish==3.28.0
3
+ requests==2.32.3
segmentation.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:912bbbde63f435106d57c7416c11a49eb3e9cb93dfe71cb6f9bfaafc1a4e3683
3
+ size 6781485
standard.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2c19a7f75312af21e9e514f008a05da5ff5624590cc5a8997c977a16d2ac459
3
+ size 114375506
stockfish-ubuntu-x86-64-sse41-popcnt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee777d5045c40d4f59c85b61bd666fcc6d2533ccada53da08f8e86290e156a30
3
+ size 78735648