Evanjaa commited on
Commit
bbb40f2
1 Parent(s): 9cc1d9b

Upload train_classifier.ipynb

Browse files
Files changed (1) hide show
  1. train_classifier.ipynb +20 -5
train_classifier.ipynb CHANGED
@@ -339,7 +339,7 @@
339
  " min_samples = counts.min()\n",
340
  " # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
341
  " # target_samples = int(2.0 * min_samples)\n",
342
- " target_samples = 5000\n",
343
  " \n",
344
  " indices_to_keep = np.hstack([\n",
345
  " np.random.choice(\n",
@@ -521,7 +521,7 @@
521
  "# Loss and optimizer\n",
522
  "criterion = nn.CrossEntropyLoss()\n",
523
  "optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
524
- "lambda_l1 = 1e-3 # L1 regularization strength"
525
  ]
526
  },
527
  {
@@ -539,7 +539,7 @@
539
  "metadata": {},
540
  "outputs": [],
541
  "source": [
542
- "epochs = 50\n",
543
  "train_losses, test_losses = [], []\n",
544
  "\n",
545
  "for epoch in range(epochs):\n",
@@ -577,7 +577,7 @@
577
  " \n",
578
  " precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
579
  " accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
580
- " if epoch % 10==0:\n",
581
  " print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
582
  ]
583
  },
@@ -620,6 +620,9 @@
620
  "metadata": {},
621
  "outputs": [],
622
  "source": [
 
 
 
623
  "conf_matrix = confusion_matrix(all_targets, all_preds)\n",
624
  "labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
625
  " # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
@@ -627,7 +630,19 @@
627
  "# plt.title('Confusion Matrix')\n",
628
  "plt.xlabel('Predicted Label')\n",
629
  "plt.ylabel('True Label')\n",
630
- "plt.show()"
 
 
 
 
 
 
 
 
 
 
 
 
631
  ]
632
  },
633
  {
 
339
  " min_samples = counts.min()\n",
340
  " # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
341
  " # target_samples = int(2.0 * min_samples)\n",
342
+ " target_samples = 7500\n",
343
  " \n",
344
  " indices_to_keep = np.hstack([\n",
345
  " np.random.choice(\n",
 
521
  "# Loss and optimizer\n",
522
  "criterion = nn.CrossEntropyLoss()\n",
523
  "optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
524
+ "lambda_l1 = 1e-5 # L1 regularization strength"
525
  ]
526
  },
527
  {
 
539
  "metadata": {},
540
  "outputs": [],
541
  "source": [
542
+ "epochs = 10\n",
543
  "train_losses, test_losses = [], []\n",
544
  "\n",
545
  "for epoch in range(epochs):\n",
 
577
  " \n",
578
  " precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
579
  " accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
580
+ " if epoch % 2==0:\n",
581
  " print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
582
  ]
583
  },
 
620
  "metadata": {},
621
  "outputs": [],
622
  "source": [
623
+ "print(np.unique(all_targets, return_counts=True))\n",
624
+ "print(np.unique(all_preds, return_counts=True))\n",
625
+ "\n",
626
  "conf_matrix = confusion_matrix(all_targets, all_preds)\n",
627
  "labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
628
  " # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
 
630
  "# plt.title('Confusion Matrix')\n",
631
  "plt.xlabel('Predicted Label')\n",
632
  "plt.ylabel('True Label')\n",
633
+ "plt.show()\n",
634
+ "\n",
635
+ "def showClassWiseAcc(conf_matrix):\n",
636
+ " # Calculate accuracy per class\n",
637
+ " class_accuracies = conf_matrix.diagonal() / conf_matrix.sum(axis=1)\n",
638
+ "\n",
639
+ " # Prepare accuracy data for writing to file\n",
640
+ " accuracy_data = \"\\n\".join([f\"Accuracy for class {i}: {class_accuracies[i]:.4f}\" for i in range(len(class_accuracies))])\n",
641
+ "\n",
642
+ " # Print accuracy per class and write to a file\n",
643
+ " print(accuracy_data) # Print to console\n",
644
+ "\n",
645
+ "showClassWiseAcc(conf_matrix)"
646
  ]
647
  },
648
  {