File size: 2,218 Bytes
0d1a7d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import os
import re
import csv

def extract_paren(annotation):
    """Extract bracketed, id-annotated mentions from an annotated string.

    Every "[" starts a candidate mention that runs to its matching "]".
    Brackets nested inside a mention are dropped from the collected text,
    together with the nested mention's own trailing ": <id>" annotation, so
    "[[John: 12]'s house: 2]" yields "[John's house: 2]" for the outer
    mention (the inner one is collected separately when the scan reaches
    its own "[").

    A mention is kept only if its text contains a ": <1-3 digits>"
    annotation. Each kept mention is suffixed with "|<index>", where
    <index> is the number of spaces before the opening bracket minus the
    number of id annotations before it (each annotation contains one space
    that does not separate words) -- presumably the whitespace-token index
    at which the mention starts.

    Args:
        annotation: The annotated text to scan.

    Returns:
        A list of strings of the form "[text: id]|index", ordered by the
        position of their opening bracket. Mentions with no closing bracket
        or no id annotation are omitted.
    """
    id_pattern = r": [0-9]{1,3}"
    ents = []
    for i, char in enumerate(annotation):
        if char != "[":
            continue

        ent = "["
        depth = 0  # number of nested brackets currently open inside this mention
        for j in range(i + 1, len(annotation)):
            current = annotation[j]
            if current == "[":
                depth += 1
            elif current == "]":
                if depth > 0:
                    depth -= 1
                    # Remove the nested mention's id annotation whatever its
                    # width (1-3 digits); a fixed 3-character cut only
                    # handled single-digit ids and ate text otherwise.
                    ent = re.sub(id_pattern + r"\Z", "", ent)
                else:
                    ent += "]"
                    if re.search(id_pattern, ent):
                        prefix = annotation[:i]
                        id_count = len(re.findall(id_pattern, prefix))
                        str_index = prefix.count(" ") - id_count
                        ents.append(ent + "|" + str(str_index))
                    break
            else:
                ent += current
    return ents

def create_clusters(ents):
    """Group extracted mentions by the numeric id in their annotation.

    For each mention, the first ": <1-3 digits>" annotation found supplies
    the cluster id. The stored text is the mention with all "[" and "]"
    characters and every occurrence of that annotation removed; any other
    text (such as a "|<index>" suffix) is kept.

    Args:
        ents: Iterable of mention strings, e.g. "[John: 1]|0".

    Returns:
        A dict mapping each integer cluster id to the list of cleaned
        mentions carrying it, in input order. Mentions without an id
        annotation are reported on stdout and skipped.
    """
    clusters = {}

    for mention in ents:
        found = re.search(r": [0-9]{1,3}", mention)
        if found is None:
            print("OH NO:", mention)
            print()
            continue

        tag = found.group()
        stripped = mention.replace("[", "").replace("]", "")
        cleaned = stripped.replace(tag, "")

        # The tag is ": " followed by the digits, so the id is its first digit run.
        cluster_id = int(re.search(r"[0-9]{1,3}", tag).group())
        clusters.setdefault(cluster_id, []).append(cleaned)

    return clusters

# Column names of the output CSV.
headers = ["input", "model_output", "model_output_clusters"]

# Input must provide the "input" and "model_output" columns read below.
df = pd.read_csv("results.csv")

rows = []
for _, row in df.iterrows():
    annotation = row["model_output"]

    # Only string cells are parsed; any other value (e.g. the NaN pandas
    # uses for an empty cell) gets an empty cluster mapping.
    ann_clusters = {}
    if isinstance(annotation, str):
        ann_ents = extract_paren(annotation)
        if ann_ents:
            ann_clusters = create_clusters(ann_ents)

    # The cluster mapping is stored as its Python repr string.
    rows.append([row["input"], annotation, str(ann_clusters)])

# newline="" is what the csv module requires to avoid extra blank rows on
# some platforms; the context manager closes the file even if writing fails.
# UTF-8 matches the default encoding pandas used to read the input.
with open("cluster_results.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    writer.writerows(rows)