Sckathach commited on
Commit
d4d76e3
1 Parent(s): c889936
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .ipynb_checkpoints
3
+
4
+ .env/
5
+ deployment_files/.*
6
+ deployment_files/client_dir/
7
+ deployment_files/server_dir/
8
+
9
+ TODO.md
10
+ .venv
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/Zamark.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="2">
8
+ <item index="0" class="java.lang.String" itemvalue="concrete-ml" />
9
+ <item index="1" class="java.lang.String" itemvalue="streamlit" />
10
+ </list>
11
+ </value>
12
+ </option>
13
+ </inspection_tool>
14
+ </profile>
15
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (Team8)" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Zamark.iml" filepath="$PROJECT_DIR$/.idea/Zamark.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py CHANGED
@@ -1,4 +1,254 @@
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
1
+
2
+
3
  import streamlit as st
4
+ import hashlib
5
+ import uuid
6
+ from streamlit_card import card
7
+ import streamlit.components.v1 as components
8
+ import time
9
+ import json
10
+
11
+ def generate_mock_hash():
12
+ return hashlib.sha256(str(time.time()).encode()).hexdigest()
13
+
14
+
15
+ from utils import (
16
+ CLIENT_DIR,
17
+ CURRENT_DIR,
18
+ DEPLOYMENT_DIR,
19
+ KEYS_DIR,
20
+ INPUT_BROWSER_LIMIT,
21
+ clean_directory,
22
+ SERVER_DIR,
23
+ )
24
+
25
+ from concrete.ml.deployment import FHEModelClient
26
+
27
+ st.set_page_config(layout="wide")
28
+
29
+ st.sidebar.title("Contact")
30
+ st.sidebar.info(
31
+ """
32
+ - Reda Bellafqira
33
+ - Mehdi Ben Ghali
34
+ - Pierre-Elisée Flory
35
+ - Mohammed Lansari
36
+ - Thomas Winninger
37
+ """
38
+ )
39
+
40
+ st.title("Secure Watermarking Service")
41
+
42
+ # st.image(
43
+ # "llm_watermarking.png",
44
+ # caption="A Watermark for Large Language Models (https://doi.org/10.48550/arXiv.2301.10226)",
45
+ # )
46
+
47
+
48
+ def todo():
49
+ st.warning("Not implemented yet", icon="⚠️")
50
+
51
+
52
+ def key_gen_fn(client_id):
53
+ """
54
+ Generate keys for a given user. The keys are saved in KEYS_DIR
55
+
56
+ !!! needs a model in DEPLOYMENT_DIR as "client.zip" !!!
57
+ Args:
58
+ client_id (str): The client_id, retrieved from streamlit
59
+ """
60
+ clean_directory()
61
+
62
+ client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{client_id}")
63
+ client.load()
64
+
65
+ # Creates the private and evaluation keys on the client side
66
+ client.generate_private_and_evaluation_keys()
67
+
68
+ # Get the serialized evaluation keys
69
+ serialized_evaluation_keys = client.get_serialized_evaluation_keys()
70
+ assert isinstance(serialized_evaluation_keys, bytes)
71
+
72
+ # Save the evaluation key
73
+ evaluation_key_path = KEYS_DIR / f"{client_id}/evaluation_key"
74
+ with evaluation_key_path.open("wb") as f:
75
+ f.write(serialized_evaluation_keys)
76
+
77
+ # show bit of key
78
+ serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[
79
+ :INPUT_BROWSER_LIMIT
80
+ ]
81
+ # shpw len of key
82
+ # f"{len(serialized_evaluation_keys) / (10**6):.2f} MB"
83
+ with st.expander("Generated keys"):
84
+ st.write(f"{len(serialized_evaluation_keys) / (10**6):.2f} MB")
85
+ st.code(serialized_evaluation_keys_shorten_hex)
86
+
87
+ st.success("Keys have been generated!", icon="✅")
88
+
89
+
90
+ def gen_trigger_set(client_id, hf_id):
91
+ # input : random images seeded by client_id
92
+ # labels : binary array of the id
93
+ watermark_uuid = uuid.uuid1()
94
+ hash = hashlib.sha256()
95
+ hash.update(client_id + str(watermark_uuid))
96
+ client_seed = hash.digest()
97
+ hash = hashlib.sha256()
98
+ hash.update(hf_id + str(watermark_uuid))
99
+ hf_seed = hash.digest()
100
+
101
+ trigger_set_size = 128
102
+
103
+ trigger_set_client = [
104
+ {"input": 1, "label": digit} for digit in encode_id(client_id, trigger_set_size)
105
+ ]
106
+
107
+ todo()
108
+
109
+
110
+ def encode_id(ascii_rep, size=128):
111
+ """Encode a string id to a string of bits
112
+
113
+ Args:
114
+ ascii_rep (_type_): The id string
115
+ size (_type_): The size of the output bit string
116
+
117
+ Returns:
118
+ _type_: a string of bits
119
+ """
120
+ return "".join([format(ord(x), "b").zfill(8) for x in client_id])[:size]
121
+
122
+
123
+ def decode_id(binary_rep):
124
+ """Decode a string of bits to an ascii string
125
+
126
+ Args:
127
+ binary_rep (_type_): the binary string
128
+
129
+ Returns:
130
+ _type_: an ascii string
131
+ """
132
+ # Initializing a binary string in the form of
133
+ # 0 and 1, with base of 2
134
+ binary_int = int(binary_rep, 2)
135
+ # Getting the byte number
136
+ byte_number = binary_int.bit_length() + 7 // 8
137
+ # Getting an array of bytes
138
+ binary_array = binary_int.to_bytes(byte_number, "big")
139
+ # Converting the array into ASCII text
140
+ ascii_text = binary_array.decode()
141
+ # Getting the ASCII value
142
+ return ascii_text
143
+
144
+
145
+ def compare_id(client_id, binary_triggert_set_result):
146
+ """Compares the string id with the labels of the trigger set on the tested API
147
+
148
+ Args:
149
+ client_id (_type_): the ascii string
150
+ binary_triggert_set_result (_type_): the binary string
151
+
152
+ Returns:
153
+ _type_: _description_
154
+ """
155
+ ground_truth = encode_id(client_id, 128)
156
+
157
+ correct_bit = 0
158
+ for true_bit, real_bit in zip(ground_truth, binary_triggert_set_result):
159
+ if true_bit != real_bit:
160
+ correct_bit += 1
161
+
162
+ return correct_bit / len(binary_triggert_set_result)
163
+
164
+
165
+ def watermark(model, trigger_set):
166
+ """Watermarking function
167
+
168
+ Args:
169
+ model (_type_): The model to watermark
170
+ trigger_set (_type_): the trigger set
171
+ """
172
+ todo()
173
+
174
+ model_file_path = SERVER_DIR / "watermarked_model"
175
+ trigger_set_file_path = SERVER_DIR / "trigger_set"
176
+
177
+ # TODO: remove once model correctly watermarked
178
+ model_file_path.touch()
179
+ trigger_set_file_path.touch()
180
+
181
+ # Once the model is watermarked and dumped to files (model + trigger set), the user can download them
182
+ with open(model_file_path, "rb") as model_file:
183
+ st.download_button(
184
+ label="Download the watermarked file",
185
+ data=model_file,
186
+ mime="application/octet-stream",
187
+ )
188
+ with open(trigger_set_file_path, "rb") as trigger_set_file:
189
+ st.download_button(
190
+ label="Download the triggert set",
191
+ data=trigger_set_file,
192
+ mime="application/octet-stream",
193
+ )
194
+
195
+
196
+ st.header("Client Configuration", divider=True)
197
+
198
+ client_id = st.text_input("Identification string", "team-8-uuid")
199
+
200
+ if st.button("Generate keys"):
201
+ key_gen_fn(client_id)
202
+
203
+ st.header("Model Watermarking", divider=True)
204
+
205
+ encrypted_model = st.file_uploader("Upload your encrypted model")
206
+
207
+ if st.button("Start Watermarking"):
208
+ watermark(None, None)
209
+
210
+ st.header("Watermarking Verification", divider=True)
211
+
212
+
213
+ st.header("Update Blockchain", divider=True)
214
+
215
+ # Initialize session state to store the block data
216
+ if 'block_data' not in st.session_state:
217
+ st.session_state.block_data = None
218
+
219
+ # Button to update the blockchain
220
+ if st.button("Update Blockchain"):
221
+ previous_hash = generate_mock_hash()
222
+ timestamp = int(time.time() * 1000) # Current timestamp in milliseconds
223
+ watermarked_model_hash = generate_mock_hash()
224
+ trigger_set_hash = generate_mock_hash()
225
+
226
+ # Create the block data structure
227
+ st.session_state.block_data = {
228
+ "blockNumber": 42,
229
+ "previousHash": previous_hash,
230
+ "timestamp": timestamp,
231
+ "transactions": [
232
+ {
233
+ "type": "Watermarked Model Hash",
234
+ "hash": watermarked_model_hash
235
+ },
236
+ {
237
+ "type": "Trigger Set Hash",
238
+ "hash": trigger_set_hash
239
+ }
240
+ ]
241
+ }
242
+
243
+ st.success("Blockchain updated successfully!")
244
+
245
+ # Display the JSON if block_data exists
246
+ if st.session_state.block_data:
247
+ st.subheader("Latest Block Data (JSON)")
248
+
249
+ # Convert the data to a formatted JSON string
250
+ block_json = json.dumps(st.session_state.block_data, indent=2)
251
+
252
+ # Display the JSON
253
+ st.code(block_json, language='json')
254
 
 
 
llm_watermarking.png ADDED
utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ INPUT_BROWSER_LIMIT = 380
6
+
7
+ CURRENT_DIR = Path(__file__).parent
8
+ DEPLOYMENT_DIR = CURRENT_DIR / "deployment_files"
9
+ KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
10
+ CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
11
+ SERVER_DIR = DEPLOYMENT_DIR / "server_dir"
12
+
13
+ ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]
14
+
15
+
16
+ def clean_directory() -> None:
17
+ """
18
+ Clear direcgtories
19
+ """
20
+ print("Cleaning...\n")
21
+ for target_dir in ALL_DIRS:
22
+ if os.path.exists(target_dir) and os.path.isdir(target_dir):
23
+ shutil.rmtree(target_dir)
24
+ target_dir.mkdir(exist_ok=True, parents=True)