mahnerak committed
Commit • ce00289
Parent(s): none
Initial Commit
Files changed:
- .dockerignore +3 -0
- .flake8 +2 -0
- .gitignore +7 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- Dockerfile +42 -0
- LICENSE +399 -0
- README.md +88 -0
- config/docker_hosting.json +13 -0
- config/docker_local.json +25 -0
- config/local.json +47 -0
- env.yaml +27 -0
- llm_transparency_tool/__init__.py +5 -0
- llm_transparency_tool/components/__init__.py +111 -0
- llm_transparency_tool/components/frontend/.env +6 -0
- llm_transparency_tool/components/frontend/.prettierrc +5 -0
- llm_transparency_tool/components/frontend/package.json +39 -0
- llm_transparency_tool/components/frontend/public/index.html +15 -0
- llm_transparency_tool/components/frontend/src/ContributionGraph.tsx +517 -0
- llm_transparency_tool/components/frontend/src/LlmViewer.css +77 -0
- llm_transparency_tool/components/frontend/src/Selector.tsx +154 -0
- llm_transparency_tool/components/frontend/src/common.tsx +17 -0
- llm_transparency_tool/components/frontend/src/index.tsx +39 -0
- llm_transparency_tool/components/frontend/src/react-app-env.d.ts +1 -0
- llm_transparency_tool/components/frontend/tsconfig.json +19 -0
- llm_transparency_tool/models/__init__.py +5 -0
- llm_transparency_tool/models/test_tlens_model.py +162 -0
- llm_transparency_tool/models/tlens_model.py +303 -0
- llm_transparency_tool/models/transparent_llm.py +199 -0
- llm_transparency_tool/routes/__init__.py +5 -0
- llm_transparency_tool/routes/contributions.py +201 -0
- llm_transparency_tool/routes/graph.py +163 -0
- llm_transparency_tool/routes/graph_node.py +90 -0
- llm_transparency_tool/routes/test_contributions.py +148 -0
- llm_transparency_tool/server/app.py +659 -0
- llm_transparency_tool/server/graph_selection.py +56 -0
- llm_transparency_tool/server/monitor.py +99 -0
- llm_transparency_tool/server/styles.py +107 -0
- llm_transparency_tool/server/utils.py +133 -0
- pyproject.toml +2 -0
- sample_input.txt +3 -0
- setup.py +13 -0
.dockerignore
ADDED
@@ -0,0 +1,3 @@
+**/.git
+**/node_modules
+**/.mypy_cache
.flake8
ADDED
@@ -0,0 +1,2 @@
+[flake8]
+max-line-length = 120
.gitignore
ADDED
@@ -0,0 +1,7 @@
+**/frontend/node_modules*
+**/frontend/build/
+**/frontend/.yarn*
+.vscode/
+.mypy_cache/
+__pycache__/
+.DS_Store
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <[email protected]>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
+# Contributing to llm-transparency-tool
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+By contributing to llm-transparency-tool, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
Dockerfile
ADDED
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
+
+RUN apt-get update && apt-get install -y \
+    wget \
+    git \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+USER user
+
+ENV HOME=/home/user
+
+RUN wget -P /tmp \
+    "https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh" \
+    && bash /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh -b -p $HOME/mambaforge3 \
+    && rm /tmp/Mambaforge-23.11.0-0-Linux-x86_64.sh
+ENV PATH $HOME/mambaforge3/bin:$PATH
+
+WORKDIR $HOME
+
+ENV REPO=$HOME/llm-transparency-tool
+COPY --chown=user . $REPO
+
+WORKDIR $REPO
+
+RUN mamba env create --name llmtt -f env.yaml -y
+ENV PATH $HOME/mambaforge3/envs/llmtt/bin:$PATH
+RUN pip install -e .
+
+RUN cd llm_transparency_tool/components/frontend \
+    && yarn install \
+    && yarn build
+
+EXPOSE 7860
+CMD ["streamlit", "run", "llm_transparency_tool/server/app.py", "--server.port=7860", "--server.address=0.0.0.0", "--theme.font=Inconsolata", "--", "config/docker_hosting.json"]
LICENSE
ADDED
@@ -0,0 +1,399 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+     Considerations for licensors: Our public licenses are
+     intended for use by those authorized to give the public
+     permission to use material in ways otherwise restricted by
+     copyright and certain other rights. Our licenses are
+     irrevocable. Licensors should read and understand the terms
+     and conditions of the license they choose before applying it.
+     Licensors should also secure all rights necessary before
+     applying our licenses so that the public can reuse the
+     material as expected. Licensors should clearly mark any
+     material not subject to the license. This includes other CC-
+     licensed material, or material used under an exception or
+     limitation to copyright. More considerations for licensors:
+     wiki.creativecommons.org/Considerations_for_licensors
+
+     Considerations for the public: By using one of our public
+     licenses, a licensor grants the public permission to use the
+     licensed material under specified terms and conditions. If
+     the licensor's permission is not necessary for any reason--for
+     example, because of any applicable exception or limitation to
+     copyright--then that use is not regulated by the license. Our
+     licenses grant only permissions under copyright and certain
+     other rights that a licensor has authority to grant. Use of
+     the licensed material may still be restricted for other
+     reasons, including because others have copyright or other
+     rights in the material. A licensor may make special requests,
+     such as asking that all changes be marked or described.
+     Although not required by our licenses, you are encouraged to
+     respect those requests where reasonable. More_considerations
+     for the public:
+     wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+  a. Adapted Material means material subject to Copyright and Similar
+     Rights that is derived from or based upon the Licensed Material
+     and in which the Licensed Material is translated, altered,
+     arranged, transformed, or otherwise modified in a manner requiring
+     permission under the Copyright and Similar Rights held by the
+     Licensor. For purposes of this Public License, where the Licensed
+     Material is a musical work, performance, or sound recording,
+     Adapted Material is always produced where the Licensed Material is
+     synched in timed relation with a moving image.
+
+  b. Adapter's License means the license You apply to Your Copyright
+     and Similar Rights in Your contributions to Adapted Material in
+     accordance with the terms and conditions of this Public License.
+
+  c. Copyright and Similar Rights means copyright and/or similar rights
+     closely related to copyright including, without limitation,
+     performance, broadcast, sound recording, and Sui Generis Database
+     Rights, without regard to how the rights are labeled or
+     categorized. For purposes of this Public License, the rights
+     specified in Section 2(b)(1)-(2) are not Copyright and Similar
+     Rights.
+  d. Effective Technological Measures means those measures that, in the
+     absence of proper authority, may not be circumvented under laws
+     fulfilling obligations under Article 11 of the WIPO Copyright
+     Treaty adopted on December 20, 1996, and/or similar international
+     agreements.
+
+  e. Exceptions and Limitations means fair use, fair dealing, and/or
+     any other exception or limitation to Copyright and Similar Rights
+     that applies to Your use of the Licensed Material.
+
+  f. Licensed Material means the artistic or literary work, database,
+     or other material to which the Licensor applied this Public
+     License.
+
+  g. Licensed Rights means the rights granted to You subject to the
+     terms and conditions of this Public License, which are limited to
+     all Copyright and Similar Rights that apply to Your use of the
+     Licensed Material and that the Licensor has authority to license.
+
+  h. Licensor means the individual(s) or entity(ies) granting rights
+     under this Public License.
+
+  i. NonCommercial means not primarily intended for or directed towards
+     commercial advantage or monetary compensation. For purposes of
+     this Public License, the exchange of the Licensed Material for
+     other material subject to Copyright and Similar Rights by digital
+     file-sharing or similar means is NonCommercial provided there is
+     no payment of monetary compensation in connection with the
+     exchange.
+
+  j. Share means to provide material to the public by any means or
+     process that requires permission under the Licensed Rights, such
+     as reproduction, public display, public performance, distribution,
+     dissemination, communication, or importation, and to make material
+     available to the public including in ways that members of the
+     public may access the material from a place and at a time
+     individually chosen by them.
+
+  k. Sui Generis Database Rights means rights other than copyright
+     resulting from Directive 96/9/EC of the European Parliament and of
+     the Council of 11 March 1996 on the legal protection of databases,
+     as amended and/or succeeded, as well as other essentially
+     equivalent rights anywhere in the world.
+
+  l. You means the individual or entity exercising the Licensed Rights
+     under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+  a. License grant.
+
+       1. Subject to the terms and conditions of this Public License,
+          the Licensor hereby grants You a worldwide, royalty-free,
+          non-sublicensable, non-exclusive, irrevocable license to
+          exercise the Licensed Rights in the Licensed Material to:
+
+            a. reproduce and Share the Licensed Material, in whole or
+               in part, for NonCommercial purposes only; and
+
+            b. produce, reproduce, and Share Adapted Material for
+               NonCommercial purposes only.
+
+       2. Exceptions and Limitations. For the avoidance of doubt, where
+          Exceptions and Limitations apply to Your use, this Public
+          License does not apply, and You do not need to comply with
+          its terms and conditions.
+
+       3. Term. The term of this Public License is specified in Section
+          6(a).
+
+       4. Media and formats; technical modifications allowed. The
+          Licensor authorizes You to exercise the Licensed Rights in
+          all media and formats whether now known or hereafter created,
+          and to make technical modifications necessary to do so. The
+          Licensor waives and/or agrees not to assert any right or
+          authority to forbid You from making technical modifications
+          necessary to exercise the Licensed Rights, including
+          technical modifications necessary to circumvent Effective
+          Technological Measures. For purposes of this Public License,
+          simply making modifications authorized by this Section 2(a)
+          (4) never produces Adapted Material.
+
+       5. Downstream recipients.
+
+            a. Offer from the Licensor -- Licensed Material. Every
+               recipient of the Licensed Material automatically
+               receives an offer from the Licensor to exercise the
+               Licensed Rights under the terms and conditions of this
+               Public License.
+
+            b. No downstream restrictions. You may not offer or impose
+               any additional or different terms or conditions on, or
+               apply any Effective Technological Measures to, the
+               Licensed Material if doing so restricts exercise of the
+               Licensed Rights by any recipient of the Licensed
+               Material.
+
+       6. No endorsement. Nothing in this Public License constitutes or
+          may be construed as permission to assert or imply that You
+          are, or that Your use of the Licensed Material is, connected
+          with, or sponsored, endorsed, or granted official status by,
+          the Licensor or others designated to receive attribution as
+          provided in Section 3(a)(1)(A)(i).
+
+  b. Other rights.
+
+       1. Moral rights, such as the right of integrity, are not
+          licensed under this Public License, nor are publicity,
+          privacy, and/or other similar personality rights; however, to
+          the extent possible, the Licensor waives and/or agrees not to
+          assert any such rights held by the Licensor to the limited
+          extent necessary to allow You to exercise the Licensed
+          Rights, but not otherwise.
+
+       2. Patent and trademark rights are not licensed under this
+          Public License.
+
+       3. To the extent possible, the Licensor waives any right to
+          collect royalties from You for the exercise of the Licensed
+          Rights, whether directly or through a collecting society
+          under any voluntary or waivable statutory or compulsory
+          licensing scheme. In all other cases the Licensor expressly
+          reserves any right to collect such royalties, including when
+          the Licensed Material is used other than for NonCommercial
+          purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+  a. Attribution.
+
+       1. If You Share the Licensed Material (including in modified
+          form), You must:
+
+            a. retain the following if it is supplied by the Licensor
+               with the Licensed Material:
+
+                 i. identification of the creator(s) of the Licensed
+                    Material and any others designated to receive
+                    attribution, in any reasonable manner requested by
+                    the Licensor (including by pseudonym if
+                    designated);
+
+                ii. a copyright notice;
+
+               iii. a notice that refers to this Public License;
+
+                iv. a notice that refers to the disclaimer of
+                    warranties;
+
+                 v. a URI or hyperlink to the Licensed Material to the
+                    extent reasonably practicable;
+
+            b. indicate if You modified the Licensed Material and
+               retain an indication of any previous modifications; and
+
+            c. indicate the Licensed Material is licensed under this
+               Public License, and include the text of, or the URI or
+               hyperlink to, this Public License.
+
+       2. You may satisfy the conditions in Section 3(a)(1) in any
+          reasonable manner based on the medium, means, and context in
+          which You Share the Licensed Material. For example, it may be
+          reasonable to satisfy the conditions by providing a URI or
+          hyperlink to a resource that includes the required
+          information.
+
+       3. If requested by the Licensor, You must remove any of the
+          information required by Section 3(a)(1)(A) to the extent
+          reasonably practicable.
+
+       4. If You Share Adapted Material You produce, the Adapter's
+          License You apply must not prevent recipients of the Adapted
+          Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+     to extract, reuse, reproduce, and Share all or a substantial
+     portion of the contents of the database for NonCommercial purposes
+     only;
+
+  b. if You include all or a substantial portion of the database
+     contents in a database in which You have Sui Generis Database
+     Rights, then the database in which You have Sui Generis Database
+     Rights (but not its individual contents) is Adapted Material; and
+
+  c. You must comply with the conditions in Section 3(a) if You Share
+     all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+  c. The disclaimer of warranties and limitation of liability provided
+     above shall be interpreted in a manner that, to the extent
+     possible, most closely approximates an absolute disclaimer and
+     waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+  a. This Public License applies for the term of the Copyright and
+     Similar Rights licensed here. However, if You fail to comply with
+     this Public License, then Your rights under this Public License
+     terminate automatically.
+
+  b. Where Your right to use the Licensed Material has terminated under
+     Section 6(a), it reinstates:
+
+       1. automatically as of the date the violation is cured, provided
+          it is cured within 30 days of Your discovery of the
+          violation; or
+
+       2. upon express reinstatement by the Licensor.
+
+     For the avoidance of doubt, this Section 6(b) does not affect any
+     right the Licensor may have to seek remedies for Your violations
+     of this Public License.
+
+  c. For the avoidance of doubt, the Licensor may also offer the
+     Licensed Material under separate terms or conditions or stop
+     distributing the Licensed Material at any time; however, doing so
+     will not terminate this Public License.
+
+  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+     License.
+
+Section 7 -- Other Terms and Conditions.
+
+  a. The Licensor shall not be bound by any additional or different
+     terms or conditions communicated by You unless expressly agreed.
+
+  b. Any arrangements, understandings, or agreements regarding the
+     Licensed Material not stated herein are separate from and
+     independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+  a. For the avoidance of doubt, this Public License does not, and
+     shall not be interpreted to, reduce, limit, restrict, or impose
+     conditions on any use of the Licensed Material that could lawfully
+     be made without permission under this Public License.
+
+  b. To the extent possible, if any provision of this Public License is
+     deemed unenforceable, it shall be automatically reformed to the
+     minimum extent necessary to make it enforceable. If the provision
+     cannot be reformed, it shall be severed from this Public License
+     without affecting the enforceability of the remaining terms and
+     conditions.
+
+  c. No term or condition of this Public License will be waived and no
+     failure to comply consented to unless expressly agreed to by the
+     Licensor.
+
+  d. Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the "Licensor." The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
README.md
ADDED
@@ -0,0 +1,88 @@
+<h1>
+  <img width="500" alt="LLM Transparency Tool" src="https://github.com/facebookresearch/llm-transparency-tool/assets/1367529/4bbf2544-88de-4576-9622-6047a056c5c8">
+</h1>
+
+<img width="832" alt="screenshot" src="https://github.com/facebookresearch/llm-transparency-tool/assets/1367529/78f6f9e2-fe76-4ded-bb78-a57f64f4ac3a">
+
+
+## Key functionality
+
+* Choose your model, choose or add your prompt, run the inference.
+* Browse the contribution graph.
+* Select the token to build the graph from.
+* Tune the contribution threshold.
+* Select the representation of any token after any block.
+* For the representation, see its projection to the output vocabulary, and see which tokens
+  were promoted/suppressed by the previous block.
+* The following things are clickable:
+  * Edges. That shows more info about the contributing attention head.
+  * Heads when an edge is selected. You can see what this head is promoting/suppressing.
+  * FFN blocks (little squares on the graph).
+  * Neurons when an FFN block is selected.
+
+
+## Installation
+
+### Dockerized running
+```bash
+# From the repository root directory
+docker build -t llm_transparency_tool .
+docker run --rm -p 7860:7860 llm_transparency_tool
+```
+
+### Local Installation
+
+
+```bash
+# download
+git clone git@github.com:facebookresearch/llm-transparency-tool.git
+cd llm-transparency-tool
+
+# install the necessary packages
+conda env create --name llmtt -f env.yaml
+# install the `llm_transparency_tool` package
+pip install -e .
+
+# now, we need to build the frontend
+# don't worry, even `yarn` comes preinstalled by `env.yaml`
+cd llm_transparency_tool/components/frontend
+yarn install
+yarn build
+```
+
+### Launch
+
+```bash
+streamlit run llm_transparency_tool/server/app.py -- config/local.json
+```
+
+
+## Adding support for your LLM
+
+Initially, the tool allows you to select from just a handful of models. Here are the
+options you can try for using your model in the tool, from least to most
+effort.
+
+
+### The model is already supported by TransformerLens
+
+The full list of models is [here](https://github.com/neelnanda-io/TransformerLens/blob/0825c5eb4196e7ad72d28bcf4e615306b3897490/transformer_lens/loading_from_pretrained.py#L18).
+In this case, the model can be added to the configuration json file.
+
+
+### Tuned version of a model supported by TransformerLens
+
+Add the official name of the model to the config along with the location to read the
+weights from.
+
+
+### The model is not supported by TransformerLens
+
+In this case, the UI wouldn't know how to create proper hooks for the model. You'd need
+to implement your own version of the [TransparentLlm](./llm_transparency_tool/models/transparent_llm.py#L28) class and alter the
+Streamlit app to use your implementation.
+
+
+## License
+This code is made available under a [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license, as found in the LICENSE file.
+However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models.
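To make that last option more concrete, here is a minimal sketch of such an adapter. It assumes only what the README states: a `TransparentLlm` base class exists in `llm_transparency_tool/models/transparent_llm.py` (and `ModelInfo` is importable from the same module, as `components/__init__.py` above shows). The method names below are illustrative placeholders, not the real abstract interface, which is defined in that file.

```python
# Sketch only: the actual abstract methods of TransparentLlm are defined in
# llm_transparency_tool/models/transparent_llm.py and will differ from these
# illustrative placeholders.
from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm


class MyWrappedLlm(TransparentLlm):  # hypothetical adapter class
    """Expose a custom model to the tool by implementing the base interface."""

    def model_info(self) -> ModelInfo:  # placeholder method name
        raise NotImplementedError

    def run(self, sentences):  # placeholder: run inference and cache activations
        raise NotImplementedError
```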
config/docker_hosting.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "allow_loading_dataset_files": false,
+    "max_user_string_length": 100,
+    "preloaded_dataset_filename": "sample_input.txt",
+    "debug": false,
+    "demo_mode": true,
+    "models": {
+        "facebook/opt-125m": null,
+        "gpt2": null,
+        "distilgpt2": null
+    },
+    "default_model": "gpt2"
+}
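Compared with the local configs below, this hosting config turns on `demo_mode` and adds `max_user_string_length`, which suggests the hosted app caps user-entered prompts. A hedged sketch of such a check; the real gating lives in `llm_transparency_tool/server/app.py`, which this commit view does not show:

```python
def accept_user_string(s: str, config: dict) -> str:
    """Illustrative guess at how the server might enforce max_user_string_length."""
    limit = config.get("max_user_string_length")  # only set in docker_hosting.json
    if limit is not None and len(s) > limit:
        raise ValueError(f"Input exceeds the {limit}-character demo limit")
    return s
```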
config/docker_local.json
ADDED
@@ -0,0 +1,25 @@
+{
+    "allow_loading_dataset_files": true,
+    "preloaded_dataset_filename": "sample_input.txt",
+    "debug": true,
+    "models": {
+        "": null,
+        "facebook/opt-125m": null,
+        "facebook/opt-1.3b": null,
+        "facebook/opt-2.7b": null,
+        "facebook/opt-6.7b": null,
+        "facebook/opt-13b": null,
+        "facebook/opt-30b": null,
+        "meta-llama/Llama-2-7b-hf": null,
+        "meta-llama/Llama-2-7b-chat-hf": null,
+        "meta-llama/Llama-2-13b-hf": null,
+        "meta-llama/Llama-2-13b-chat-hf": null,
+        "gpt2": null,
+        "gpt2-medium": null,
+        "gpt2-large": null,
+        "gpt2-xl": null,
+        "distilgpt2": null
+    },
+    "default_model": "distilgpt2",
+    "demo_mode": false
+}
config/local.json
ADDED
@@ -0,0 +1,47 @@
+{
+    "allow_loading_dataset_files": true,
+    "preloaded_dataset_filename": "sample_input.txt",
+    "debug": true,
+    "models": {
+        "": null,
+
+        "gpt2": null,
+        "distilgpt2": null,
+        "facebook/opt-125m": null,
+        "facebook/opt-1.3b": null,
+        "EleutherAI/gpt-neo-125M": null,
+        "Qwen/Qwen-1_8B": null,
+        "Qwen/Qwen1.5-0.5B": null,
+        "Qwen/Qwen1.5-0.5B-Chat": null,
+        "Qwen/Qwen1.5-1.8B": null,
+        "Qwen/Qwen1.5-1.8B-Chat": null,
+        "microsoft/phi-1": null,
+        "microsoft/phi-1_5": null,
+        "microsoft/phi-2": null,
+
+        "meta-llama/Llama-2-7b-hf": null,
+        "meta-llama/Llama-2-7b-chat-hf": null,
+
+        "meta-llama/Llama-2-13b-hf": null,
+        "meta-llama/Llama-2-13b-chat-hf": null,
+
+
+        "gpt2-medium": null,
+        "gpt2-large": null,
+        "gpt2-xl": null,
+
+        "mistralai/Mistral-7B-v0.1": null,
+        "mistralai/Mistral-7B-Instruct-v0.1": null,
+        "mistralai/Mistral-7B-Instruct-v0.2": null,
+
+        "google/gemma-7b": null,
+        "google/gemma-2b": null,
+
+        "facebook/opt-2.7b": null,
+        "facebook/opt-6.7b": null,
+        "facebook/opt-13b": null,
+        "facebook/opt-30b": null
+    },
+    "default_model": "",
+    "demo_mode": false
+}
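All three configs share the same shape, so loading them is straightforward with the standard `json` module. A minimal sketch, assuming plain JSON parsing; the dataclass name and loader function are made up for illustration (the app's real config handling lives in `llm_transparency_tool/server/app.py`, not shown here):

```python
import json
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ToolConfig:  # hypothetical name, for illustration
    allow_loading_dataset_files: bool = False
    preloaded_dataset_filename: Optional[str] = None
    debug: bool = False
    models: Dict[str, Optional[str]] = field(default_factory=dict)
    default_model: str = ""
    demo_mode: bool = False
    max_user_string_length: Optional[int] = None  # present only in docker_hosting.json


def load_config(path: str) -> ToolConfig:
    # Keys in "models" are TransformerLens model names; a non-null value would
    # point at alternative weights (see the README's "tuned version" section).
    with open(path) as f:
        return ToolConfig(**json.load(f))


config = load_config("config/local.json")
assert config.default_model in config.models
```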
env.yaml
ADDED
@@ -0,0 +1,27 @@
+name: llmtt
+channels:
+  - pytorch
+  - nvidia
+  - conda-forge
+dependencies:
+  - python
+  - pytorch
+  - pytorch-cuda=11.8
+  - nodejs
+  - yarn
+  - pip
+  - pip:
+    - datasets
+    - einops
+    - fancy_einsum
+    - jaxtyping
+    - networkx
+    - plotly
+    - pyinstrument
+    - setuptools
+    - streamlit
+    - streamlit_extras
+    - tokenizers
+    - transformer_lens
+    - transformers
+    - pytest  # fixes wrong dependencies of transformer_lens
llm_transparency_tool/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
llm_transparency_tool/components/__init__.py
ADDED
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from typing import List, Optional
+
+import networkx as nx
+import streamlit.components.v1 as components
+
+from llm_transparency_tool.models.transparent_llm import ModelInfo
+from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode
+
+_RELEASE = True
+
+if _RELEASE:
+    parent_dir = os.path.dirname(os.path.abspath(__file__))
+    config = {
+        "path": os.path.join(parent_dir, "frontend/build"),
+    }
+else:
+    config = {
+        "url": "http://localhost:3001",
+    }
+
+_component_func = components.declare_component("contribution_graph", **config)
+
+
+def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int):
+    return node.layer < n_layers and node.token < n_tokens
+
+
+def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int):
+    if not s:
+        return True
+    if s.node:
+        if not is_node_valid(s.node, n_layers, n_tokens):
+            return False
+    if s.edge:
+        for node in [s.edge.source, s.edge.target]:
+            if not is_node_valid(node, n_layers, n_tokens):
+                return False
+    return True
+
+
+def contribution_graph(
+    model_info: ModelInfo,
+    tokens: List[str],
+    graphs: List[nx.Graph],
+    key: str,
+) -> Optional[GraphSelection]:
+    """Create a new instance of contribution graph.
+
+    Returns selected graph node or None if nothing was selected.
+    """
+    assert len(tokens) == len(graphs)
+
+    result = _component_func(
+        component="graph",
+        model_info=model_info.__dict__,
+        tokens=tokens,
+        edges_per_token=[nx.node_link_data(g)["links"] for g in graphs],
+        default=None,
+        key=key,
+    )
+
+    selection = GraphSelection.from_json(result)
+
+    n_tokens = len(tokens)
+    n_layers = model_info.n_layers
+    # We need this extra protection because even though the component has to check for
+    # the validity of the selection, sometimes it allows invalid output. It's some
+    # unexpected effect that has something to do with React and how the output value is
+    # set for the component.
+    if not is_selection_valid(selection, n_layers, n_tokens):
+        selection = None
+
+    return selection
+
+
+def selector(
+    items: List[str],
+    indices: List[int],
+    temperatures: Optional[List[float]],
+    preselected_index: Optional[int],
+    key: str,
+) -> Optional[int]:
+    """Create a new instance of selector.
+
+    Returns selected item index.
+    """
+    n = len(items)
+    assert n == len(indices)
+    items = [{"index": i, "text": s} for s, i in zip(items, indices)]
+
+    if temperatures is not None:
+        assert n == len(temperatures)
+        for i, t in enumerate(temperatures):
+            items[i]["temperature"] = t
+
+    result = _component_func(
+        component="selector",
+        items=items,
+        preselected_index=preselected_index,
+        default=None,
+        key=key,
+    )
+
+    return None if result is None else int(result)
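For context, this is roughly how the Streamlit app might invoke the `contribution_graph` wrapper above. A sketch under stated assumptions: the real call sites are in `llm_transparency_tool/server/app.py` (not part of this excerpt), the `SimpleNamespace` merely stands in for a real `ModelInfo`, and the toy graph is hand-built using the `[AIMX]{layer}_{token}` node-naming scheme that the frontend's `nodeFromString` parser expects (see `ContributionGraph.tsx` below).

```python
from types import SimpleNamespace

import networkx as nx
import streamlit as st

from llm_transparency_tool.components import contribution_graph

# Stand-in for a real ModelInfo; the wrapper only reads .n_layers and .__dict__.
model_info = SimpleNamespace(n_layers=1)

# Toy per-token graph: X = original embedding, A = state after attention.
g = nx.DiGraph()
g.add_edge("X0_0", "A0_1", weight=0.7)

selection = contribution_graph(
    model_info=model_info,
    tokens=["The", " cat"],
    graphs=[g, g],  # one contribution graph per token
    key="contribution_graph",
)
if selection is not None:
    st.write(selection)
```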
llm_transparency_tool/components/frontend/.env
ADDED
@@ -0,0 +1,6 @@
+# Run the component's dev server on :3001
+# (The Streamlit dev server already runs on :3000)
+PORT=3001
+
+# Don't automatically open the web browser on `npm run start`.
+BROWSER=none
llm_transparency_tool/components/frontend/.prettierrc
ADDED
@@ -0,0 +1,5 @@
+{
+  "endOfLine": "lf",
+  "semi": false,
+  "trailingComma": "es5"
+}
llm_transparency_tool/components/frontend/package.json
ADDED
@@ -0,0 +1,39 @@
+{
+  "name": "contribution_graph",
+  "version": "0.1.0",
+  "private": true,
+  "dependencies": {
+    "@types/d3": "^7.4.0",
+    "d3": "^7.8.5",
+    "react": "^18.2.0",
+    "react-dom": "^18.2.0",
+    "streamlit-component-lib": "^2.0.0"
+  },
+  "scripts": {
+    "start": "react-scripts start",
+    "build": "react-scripts build",
+    "test": "react-scripts test",
+    "eject": "react-scripts eject"
+  },
+  "browserslist": {
+    "production": [
+      ">0.2%",
+      "not dead",
+      "not op_mini all"
+    ],
+    "development": [
+      "last 1 chrome version",
+      "last 1 firefox version",
+      "last 1 safari version"
+    ]
+  },
+  "homepage": ".",
+  "devDependencies": {
+    "@types/node": "^20.11.17",
+    "@types/react": "^18.2.55",
+    "@types/react-dom": "^18.2.19",
+    "eslint-config-react-app": "^7.0.1",
+    "react-scripts": "^5.0.1",
+    "typescript": "^5.3.3"
+  }
+}
llm_transparency_tool/components/frontend/public/index.html
ADDED
@@ -0,0 +1,15 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <title>Contribution Graph for Streamlit</title>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1" />
+    <meta name="theme-color" content="#000000" />
+    <meta name="description" content="Contribution Graph for Streamlit" />
+    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" />
+  </head>
+  <body>
+    <noscript>You need to enable JavaScript to run this app.</noscript>
+    <div id="root"></div>
+  </body>
+</html>
llm_transparency_tool/components/frontend/src/ContributionGraph.tsx
ADDED
@@ -0,0 +1,517 @@
+/**
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+import {
+  ComponentProps,
+  Streamlit,
+  withStreamlitConnection,
+} from 'streamlit-component-lib'
+import React, { useEffect, useMemo, useRef, useState } from 'react';
+import * as d3 from 'd3';
+
+import {
+  Label,
+  Point,
+} from './common';
+import './LlmViewer.css';
+
+export const renderParams = {
+  cellH: 32,
+  cellW: 32,
+  attnSize: 8,
+  afterFfnSize: 8,
+  ffnSize: 6,
+  tokenSelectorSize: 16,
+  layerCornerRadius: 6,
+}
+
+interface Cell {
+  layer: number
+  token: number
+}
+
+enum CellItem {
+  AfterAttn = 'after_attn',
+  AfterFfn = 'after_ffn',
+  Ffn = 'ffn',
+  Original = 'original', // They will only be at level = 0
+}
+
+interface Node {
+  cell: Cell | null
+  item: CellItem | null
+}
+
+interface NodeProps {
+  node: Node
+  pos: Point
+  isActive: boolean
+}
+
+interface EdgeRaw {
+  weight: number
+  source: string
+  target: string
+}
+
+interface Edge {
+  weight: number
+  from: Node
+  to: Node
+  fromPos: Point
+  toPos: Point
+  isSelectable: boolean
+  isFfn: boolean
+}
+
+interface Selection {
+  node: Node | null
+  edge: Edge | null
+}
+
+function tokenPointerPolygon(origin: Point) {
+  const r = renderParams.tokenSelectorSize / 2
+  const dy = r / 2
+  const dx = r * Math.sqrt(3.0) / 2
+  // Draw an arrow looking down
+  return [
+    [origin.x, origin.y + r],
+    [origin.x + dx, origin.y - dy],
+    [origin.x - dx, origin.y - dy],
+  ].toString()
+}
+
+function isSameCell(cell1: Cell | null, cell2: Cell | null) {
+  if (cell1 == null || cell2 == null) {
+    return false
+  }
+  return cell1.layer === cell2.layer && cell1.token === cell2.token
+}
+
+function isSameNode(node1: Node | null, node2: Node | null) {
+  if (node1 === null || node2 === null) {
+    return false
+  }
+  return isSameCell(node1.cell, node2.cell)
+    && node1.item === node2.item;
+}
+
+function isSameEdge(edge1: Edge | null, edge2: Edge | null) {
+  if (edge1 === null || edge2 === null) {
+    return false
+  }
+  return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to);
+}
+
+function nodeFromString(name: string) {
+  const match = name.match(/([AIMX])(\d+)_(\d+)/)
+  if (match == null) {
+    return {
+      cell: null,
+      item: null,
+    }
+  }
+  const [, type, layerStr, tokenStr] = match
+  const layer = +layerStr
+  const token = +tokenStr
+
+  const typeToCellItem = new Map<string, CellItem>([
+    ['A', CellItem.AfterAttn],
+    ['I', CellItem.AfterFfn],
+    ['M', CellItem.Ffn],
+    ['X', CellItem.Original],
+  ])
+  return {
+    cell: {
+      layer: layer,
+      token: token,
+    },
+    item: typeToCellItem.get(type) ?? null,
+  }
+}
+
+function isValidNode(node: Node, nLayers: number, nTokens: number) {
+  if (node.cell === null) {
+    return true
+  }
+  return node.cell.layer < nLayers && node.cell.token < nTokens
+}
+
+function isValidSelection(selection: Selection, nLayers: number, nTokens: number) {
+  if (selection.node !== null) {
+    return isValidNode(selection.node, nLayers, nTokens)
+  }
+  if (selection.edge !== null) {
+    return isValidNode(selection.edge.from, nLayers, nTokens) &&
+      isValidNode(selection.edge.to, nLayers, nTokens)
+  }
+  return true
+}
+
+const ContributionGraph = ({ args }: ComponentProps) => {
+  const modelInfo = args['model_info']
+  const tokens = args['tokens']
+  const edgesRaw: EdgeRaw[][] = args['edges_per_token']
+
+  const nLayers = modelInfo === null ? 0 : modelInfo.n_layers
+  const nTokens = tokens === null ? 0 : tokens.length
+
+  const [selection, setSelection] = useState<Selection>({
+    node: null,
+    edge: null,
+  })
+  var curSelection = selection
+  if (!isValidSelection(selection, nLayers, nTokens)) {
+    curSelection = {
+      node: null,
+      edge: null,
+    }
+    setSelection(curSelection)
+    Streamlit.setComponentValue(curSelection)
+  }
+
+  const [startToken, setStartToken] = useState<number>(nTokens - 1)
+  // We have startToken state var, but it won't be updated till next render, so use
+  // this var in the current render.
+  var curStartToken = startToken
+  if (startToken >= nTokens) {
+    curStartToken = nTokens - 1
+    setStartToken(curStartToken)
+  }
+
+  const handleRepresentationClick = (node: Node) => {
+    const newSelection: Selection = {
+      node: node,
+      edge: null,
+    }
+    setSelection(newSelection)
+    Streamlit.setComponentValue(newSelection)
+  }
+
+  const handleEdgeClick = (edge: Edge) => {
+    if (!edge.isSelectable) {
+      return
+    }
+    const newSelection: Selection = {
+      node: edge.to,
+      edge: edge,
+    }
+    setSelection(newSelection)
+    Streamlit.setComponentValue(newSelection)
+  }
+
+  const handleTokenClick = (t: number) => {
+    setStartToken(t)
+  }
+
+  const [xScale, yScale] = useMemo(() => {
+    const x = d3.scaleLinear()
+      .domain([-2, nTokens - 1])
+      .range([0, renderParams.cellW * (nTokens + 2)])
+    const y = d3.scaleLinear()
+      .domain([-1, nLayers])
+      .range([renderParams.cellH * (nLayers + 2), 0])
+    return [x, y]
+  }, [nLayers, nTokens])
+
+  const cells = useMemo(() => {
+    let result: Cell[] = []
+    for (let l = 0; l < nLayers; l++) {
+      for (let t = 0; t < nTokens; t++) {
+        result.push({
+          layer: l,
+          token: t,
+        })
+      }
+    }
+    return result
+  }, [nLayers, nTokens])
+
+  const nodeCoords = useMemo(() => {
+    let result = new Map<string, Point>()
+    const w = renderParams.cellW
+    const h = renderParams.cellH
+    for (var cell of cells) {
+      const cx = xScale(cell.token + 0.5)
+      const cy = yScale(cell.layer - 0.5)
+      result.set(
+        JSON.stringify({ cell: cell, item: CellItem.AfterAttn }),
+        { x: cx, y: cy + h / 4 },
+      )
+      result.set(
+        JSON.stringify({ cell: cell, item: CellItem.AfterFfn }),
+        { x: cx, y: cy - h / 4 },
+      )
+      result.set(
+        JSON.stringify({ cell: cell, item: CellItem.Ffn }),
+        { x: cx + 5 * w / 16, y: cy },
+      )
+    }
+    for (let t = 0; t < nTokens; t++) {
+      cell = {
+        layer: 0,
+        token: t,
+      }
+      const cx = xScale(cell.token + 0.5)
+      const cy = yScale(cell.layer - 1.0)
+      result.set(
+        JSON.stringify({ cell: cell, item: CellItem.Original }),
+        { x: cx, y: cy + h / 4 },
+      )
+    }
+    return result
+  }, [cells, nTokens, xScale, yScale])
+
+  const edges: Edge[][] = useMemo(() => {
+    let result = []
+    for (var edgeList of edgesRaw) {
+      let edgesPerStartToken = []
+      for (var edge of edgeList) {
+        const u = nodeFromString(edge.source)
+        const v = nodeFromString(edge.target)
+        var isSelectable = (
+          u.cell !== null && v.cell !== null && v.item === CellItem.AfterAttn
+        )
+        var isFfn = (
+          u.cell !== null && v.cell !== null && (
+            u.item === CellItem.Ffn || v.item === CellItem.Ffn
+          )
+        )
+        edgesPerStartToken.push({
+          weight: edge.weight,
+          from: u,
+          to: v,
+          fromPos: nodeCoords.get(JSON.stringify(u)) ?? { 'x': 0, 'y': 0 },
+          toPos: nodeCoords.get(JSON.stringify(v)) ?? { 'x': 0, 'y': 0 },
+          isSelectable: isSelectable,
+          isFfn: isFfn,
+        })
+      }
+      result.push(edgesPerStartToken)
+    }
+    return result
+  }, [edgesRaw, nodeCoords])
+
+  const activeNodes = useMemo(() => {
+    let result = new Set<string>()
+    for (var edge of edges[curStartToken]) {
+      const u = JSON.stringify(edge.from)
+      const v = JSON.stringify(edge.to)
+      result.add(u)
+      result.add(v)
+    }
+    return result
+  }, [edges, curStartToken])
+
+  const nodeProps = useMemo(() => {
+    let result: Array<NodeProps> = []
|
313 |
+
nodeCoords.forEach((p: Point, node: string) => {
|
314 |
+
result.push({
|
315 |
+
node: JSON.parse(node),
|
316 |
+
pos: p,
|
317 |
+
isActive: activeNodes.has(node),
|
318 |
+
})
|
319 |
+
})
|
320 |
+
return result
|
321 |
+
}, [nodeCoords, activeNodes])
|
322 |
+
|
323 |
+
const tokenLabels: Label[] = useMemo(() => {
|
324 |
+
if (!tokens) {
|
325 |
+
return []
|
326 |
+
}
|
327 |
+
return tokens.map((s: string, i: number) => ({
|
328 |
+
text: s.replace(/ /g, 'Β·'),
|
329 |
+
pos: {
|
330 |
+
x: xScale(i + 0.5),
|
331 |
+
y: yScale(-1.5),
|
332 |
+
},
|
333 |
+
}))
|
334 |
+
}, [tokens, xScale, yScale])
|
335 |
+
|
336 |
+
const layerLabels: Label[] = useMemo(() => {
|
337 |
+
return Array.from(Array(nLayers).keys()).map(i => ({
|
338 |
+
text: 'L' + i,
|
339 |
+
pos: {
|
340 |
+
x: xScale(-0.25),
|
341 |
+
y: yScale(i - 0.5),
|
342 |
+
},
|
343 |
+
}))
|
344 |
+
}, [nLayers, xScale, yScale])
|
345 |
+
|
346 |
+
const tokenSelectors: Array<[number, Point]> = useMemo(() => {
|
347 |
+
return Array.from(Array(nTokens).keys()).map(i => ([
|
348 |
+
i,
|
349 |
+
{
|
350 |
+
x: xScale(i + 0.5),
|
351 |
+
y: yScale(nLayers - 0.5),
|
352 |
+
}
|
353 |
+
]))
|
354 |
+
}, [nTokens, nLayers, xScale, yScale])
|
355 |
+
|
356 |
+
const totalW = xScale(nTokens + 2)
|
357 |
+
const totalH = yScale(-4)
|
358 |
+
useEffect(() => {
|
359 |
+
Streamlit.setFrameHeight(totalH)
|
360 |
+
}, [totalH])
|
361 |
+
|
362 |
+
const colorScale = d3.scaleLinear(
|
363 |
+
[0.0, 0.5, 1.0],
|
364 |
+
['#9eba66', 'darkolivegreen', 'darkolivegreen']
|
365 |
+
)
|
366 |
+
const ffnEdgeColorScale = d3.scaleLinear(
|
367 |
+
[0.0, 0.5, 1.0],
|
368 |
+
['orchid', 'purple', 'purple']
|
369 |
+
)
|
370 |
+
const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])
|
371 |
+
|
372 |
+
const svgRef = useRef(null);
|
373 |
+
|
374 |
+
useEffect(() => {
|
375 |
+
const getNodeStyle = (p: NodeProps, type: string) => {
|
376 |
+
if (isSameNode(p.node, curSelection.node)) {
|
377 |
+
return 'selectable-item selection'
|
378 |
+
}
|
379 |
+
if (p.isActive) {
|
380 |
+
return 'selectable-item active-' + type + '-node'
|
381 |
+
}
|
382 |
+
return 'selectable-item inactive-node'
|
383 |
+
}
|
384 |
+
|
385 |
+
const svg = d3.select(svgRef.current)
|
386 |
+
svg.selectAll('*').remove()
|
387 |
+
|
388 |
+
svg
|
389 |
+
.selectAll('layers')
|
390 |
+
.data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1))
|
391 |
+
.enter()
|
392 |
+
.append('rect')
|
393 |
+
.attr('class', 'layer-highlight')
|
394 |
+
.attr('x', xScale(-1.0))
|
395 |
+
.attr('y', (layer) => yScale(layer))
|
396 |
+
.attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
|
397 |
+
.attr('height', (layer) => yScale(layer) - yScale(layer + 1))
|
398 |
+
.attr('rx', renderParams.layerCornerRadius)
|
399 |
+
|
400 |
+
svg
|
401 |
+
.selectAll('edges')
|
402 |
+
.data(edges[curStartToken])
|
403 |
+
.enter()
|
404 |
+
.append('line')
|
405 |
+
.style('stroke', (edge: Edge) => {
|
406 |
+
if (isSameEdge(edge, curSelection.edge)) {
|
407 |
+
return 'orange'
|
408 |
+
}
|
409 |
+
if (edge.isFfn) {
|
410 |
+
return ffnEdgeColorScale(edge.weight)
|
411 |
+
}
|
412 |
+
return colorScale(edge.weight)
|
413 |
+
})
|
414 |
+
.attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
|
415 |
+
.style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
|
416 |
+
.attr('x1', (edge: Edge) => edge.fromPos.x)
|
417 |
+
.attr('y1', (edge: Edge) => edge.fromPos.y)
|
418 |
+
.attr('x2', (edge: Edge) => edge.toPos.x)
|
419 |
+
.attr('y2', (edge: Edge) => edge.toPos.y)
|
420 |
+
.on('click', (event: PointerEvent, edge) => {
|
421 |
+
handleEdgeClick(edge)
|
422 |
+
})
|
423 |
+
|
424 |
+
svg
|
425 |
+
.selectAll('residual')
|
426 |
+
.data(nodeProps)
|
427 |
+
.enter()
|
428 |
+
.filter((p) => {
|
429 |
+
return p.node.item === CellItem.AfterAttn
|
430 |
+
|| p.node.item === CellItem.AfterFfn
|
431 |
+
})
|
432 |
+
.append('circle')
|
433 |
+
.attr('class', (p) => getNodeStyle(p, 'residual'))
|
434 |
+
.attr('cx', (p) => p.pos.x)
|
435 |
+
.attr('cy', (p) => p.pos.y)
|
436 |
+
.attr('r', renderParams.attnSize / 2)
|
437 |
+
.on('click', (event: PointerEvent, p) => {
|
438 |
+
handleRepresentationClick(p.node)
|
439 |
+
})
|
440 |
+
|
441 |
+
svg
|
442 |
+
.selectAll('ffn')
|
443 |
+
.data(nodeProps)
|
444 |
+
.enter()
|
445 |
+
.filter((p) => p.node.item === CellItem.Ffn && p.isActive)
|
446 |
+
.append('rect')
|
447 |
+
.attr('class', (p) => getNodeStyle(p, 'ffn'))
|
448 |
+
.attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
|
449 |
+
.attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
|
450 |
+
.attr('width', renderParams.ffnSize)
|
451 |
+
.attr('height', renderParams.ffnSize)
|
452 |
+
.on('click', (event: PointerEvent, p) => {
|
453 |
+
handleRepresentationClick(p.node)
|
454 |
+
})
|
455 |
+
|
456 |
+
svg
|
457 |
+
.selectAll('token_labels')
|
458 |
+
.data(tokenLabels)
|
459 |
+
.enter()
|
460 |
+
.append('text')
|
461 |
+
.attr('x', (label: Label) => label.pos.x)
|
462 |
+
.attr('y', (label: Label) => label.pos.y)
|
463 |
+
.attr('text-anchor', 'end')
|
464 |
+
.attr('dominant-baseline', 'middle')
|
465 |
+
.attr('alignment-baseline', 'top')
|
466 |
+
.attr('transform', (label: Label) =>
|
467 |
+
'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
|
468 |
+
.text((label: Label) => label.text)
|
469 |
+
|
470 |
+
svg
|
471 |
+
.selectAll('layer_labels')
|
472 |
+
.data(layerLabels)
|
473 |
+
.enter()
|
474 |
+
.append('text')
|
475 |
+
.attr('x', (label: Label) => label.pos.x)
|
476 |
+
.attr('y', (label: Label) => label.pos.y)
|
477 |
+
.attr('text-anchor', 'middle')
|
478 |
+
.attr('alignment-baseline', 'middle')
|
479 |
+
.text((label: Label) => label.text)
|
480 |
+
|
481 |
+
svg
|
482 |
+
.selectAll('token_selectors')
|
483 |
+
.data(tokenSelectors)
|
484 |
+
.enter()
|
485 |
+
.append('polygon')
|
486 |
+
.attr('class', ([i,]) => (
|
487 |
+
curStartToken === i
|
488 |
+
? 'selectable-item selection'
|
489 |
+
: 'selectable-item token-selector'
|
490 |
+
))
|
491 |
+
.attr('points', ([, p]) => tokenPointerPolygon(p))
|
492 |
+
.attr('r', renderParams.tokenSelectorSize / 2)
|
493 |
+
.on('click', (event: PointerEvent, [i,]) => {
|
494 |
+
handleTokenClick(i)
|
495 |
+
})
|
496 |
+
}, [
|
497 |
+
cells,
|
498 |
+
edges,
|
499 |
+
nodeProps,
|
500 |
+
tokenLabels,
|
501 |
+
layerLabels,
|
502 |
+
tokenSelectors,
|
503 |
+
curStartToken,
|
504 |
+
curSelection,
|
505 |
+
colorScale,
|
506 |
+
ffnEdgeColorScale,
|
507 |
+
edgeWidthScale,
|
508 |
+
nLayers,
|
509 |
+
nTokens,
|
510 |
+
xScale,
|
511 |
+
yScale
|
512 |
+
])
|
513 |
+
|
514 |
+
return <svg ref={svgRef} width={totalW} height={totalH}></svg>
|
515 |
+
}
|
516 |
+
|
517 |
+
export default withStreamlitConnection(ContributionGraph)
|
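The component reports its state back to Python through Streamlit.setComponentValue as a {node, edge} object, where a node carries a {layer, token} cell plus an item tag. A minimal Python-side sketch of unpacking that payload (not part of the commit; the helper name is illustrative):

def selected_cell(selection):
    # Extract (layer, token) from the value the graph component sends back, if any.
    if not selection:
        return None
    node = selection.get("node")
    if not node or not node.get("cell"):
        return None
    cell = node["cell"]
    return cell["layer"], cell["token"]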
llm_transparency_tool/components/frontend/src/LlmViewer.css
ADDED
@@ -0,0 +1,77 @@
/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

.graph-container {
  display: flex;
  justify-content: center;
  align-items: center;
  height: 100vh;
}

.svg {
  border: 1px solid #ccc;
}

.layer-highlight {
  fill: #f0f5f0;
}

.selectable-item {
  stroke: black;
  cursor: pointer;
}

.selection,
.selection:hover {
  fill: orange;
}

.active-residual-node {
  fill: yellowgreen;
}

.active-residual-node:hover {
  fill: olivedrab;
}

.active-ffn-node {
  fill: orchid;
}

.active-ffn-node:hover {
  fill: purple;
}

.inactive-node {
  fill: lightgray;
  stroke-width: 0.5px;
}

.inactive-node:hover {
  fill: gray;
}

.selectable-edge {
  cursor: pointer;
}

.token-selector {
  fill: lightblue;
}

.token-selector:hover {
  fill: cornflowerblue;
}

.selector-item {
  fill: lightblue;
}

.selector-item:hover {
  fill: cornflowerblue;
}
llm_transparency_tool/components/frontend/src/Selector.tsx
ADDED
@@ -0,0 +1,154 @@
/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

import {
  ComponentProps,
  Streamlit,
  withStreamlitConnection,
} from "streamlit-component-lib"
import React, { useEffect, useMemo, useRef, useState } from 'react';
import * as d3 from 'd3';

import {
  Point,
} from './common';
import './LlmViewer.css';

export const renderParams = {
  verticalGap: 24,
  horizontalGap: 24,
  itemSize: 8,
}

interface Item {
  index: number
  text: string
  temperature: number
}

const Selector = ({ args }: ComponentProps) => {
  const items: Item[] = args["items"]
  const preselected_index: number | null = args["preselected_index"]
  const n = items.length

  const [selection, setSelection] = useState<number | null>(null)

  // Ensure the preselected element takes effect only when it comes with new data.
  var args_json = JSON.stringify(args)
  useEffect(() => {
    setSelection(preselected_index)
    Streamlit.setComponentValue(preselected_index)
  }, [args_json, preselected_index]);

  const handleItemClick = (index: number) => {
    setSelection(index)
    Streamlit.setComponentValue(index)
  }

  const [xScale, yScale] = useMemo(() => {
    const x = d3.scaleLinear()
      .domain([0, 1])
      .range([0, renderParams.horizontalGap])
    const y = d3.scaleLinear()
      .domain([0, n - 1])
      .range([0, renderParams.verticalGap * (n - 1)])
    return [x, y]
  }, [n])

  const itemCoords: Point[] = useMemo(() => {
    return Array.from(Array(n).keys()).map(i => ({
      x: xScale(0.5),
      y: yScale(i + 0.5),
    }))
  }, [n, xScale, yScale])

  var hasTemperature = false
  if (n > 0) {
    var t = items[0].temperature
    hasTemperature = (t !== null && t !== undefined)
  }
  const colorScale = useMemo(() => {
    var min_t = 0.0
    var max_t = 1.0
    if (hasTemperature) {
      min_t = items[0].temperature
      max_t = items[0].temperature
      for (var i = 0; i < n; i++) {
        const t = items[i].temperature
        min_t = Math.min(min_t, t)
        max_t = Math.max(max_t, t)
      }
    }
    const norm = d3.scaleLinear([min_t, max_t], [0.0, 1.0])
    const colorScale = d3.scaleSequential(d3.interpolateYlGn);
    return d3.scaleSequential(value => colorScale(norm(value)))
  }, [items, hasTemperature, n])

  const totalW = 100
  const totalH = yScale(n)
  useEffect(() => {
    Streamlit.setFrameHeight(totalH)
  }, [totalH])

  const svgRef = useRef(null);

  useEffect(() => {
    const svg = d3.select(svgRef.current)
    svg.selectAll('*').remove()

    const getItemClass = (index: number) => {
      var style = 'selectable-item '
      style += index === selection ? 'selection' : 'selector-item'
      return style
    }

    const getItemColor = (item: Item) => {
      var t = item.temperature ?? 0.0
      return item.index === selection ? 'orange' : colorScale(t)
    }

    var icons = svg
      .selectAll('items')
      .data(Array.from(Array(n).keys()))
      .enter()
      .append('circle')
      .attr('cx', (i) => itemCoords[i].x)
      .attr('cy', (i) => itemCoords[i].y)
      .attr('r', renderParams.itemSize / 2)
      .on('click', (event: PointerEvent, i) => {
        handleItemClick(items[i].index)
      })
      .attr('class', (i) => getItemClass(items[i].index))
    if (hasTemperature) {
      icons.style('fill', (i) => getItemColor(items[i]))
    }

    svg
      .selectAll('labels')
      .data(Array.from(Array(n).keys()))
      .enter()
      .append('text')
      .attr('x', (i) => itemCoords[i].x + renderParams.horizontalGap / 2)
      .attr('y', (i) => itemCoords[i].y)
      .attr('text-anchor', 'left')
      .attr('alignment-baseline', 'middle')
      .text((i) => items[i].text)

  }, [
    items,
    n,
    itemCoords,
    selection,
    colorScale,
    hasTemperature,
  ])

  return <svg ref={svgRef} width={totalW} height={totalH}></svg>
}

export default withStreamlitConnection(Selector)
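When items carry a temperature, the Selector rescales the observed temperatures onto [0, 1] before feeding them to d3's YlGn color ramp. The same normalization in a standalone Python sketch (illustrative only, not the component's actual code path):

def normalize_temperatures(temps):
    # Map observed temperatures linearly onto [0, 1], mirroring the d3 norm scale.
    lo, hi = min(temps), max(temps)
    if hi == lo:
        return [0.0 for _ in temps]  # degenerate case: all temperatures equal
    return [(t - lo) / (hi - lo) for t in temps]

print(normalize_temperatures([0.2, 0.5, 0.8]))  # [0.0, 0.5, 1.0]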
llm_transparency_tool/components/frontend/src/common.tsx
ADDED
@@ -0,0 +1,17 @@
/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

export interface Point {
  x: number
  y: number
}

export interface Label {
  text: string
  pos: Point
}
llm_transparency_tool/components/frontend/src/index.tsx
ADDED
@@ -0,0 +1,39 @@
/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

import React from "react"
import ReactDOM from "react-dom"

import {
  ComponentProps,
  withStreamlitConnection,
} from "streamlit-component-lib"


import ContributionGraph from "./ContributionGraph"
import Selector from "./Selector"

const LlmViewerComponent = (props: ComponentProps) => {
  switch (props.args['component']) {
    case 'graph':
      return <ContributionGraph />
    case 'selector':
      return <Selector />
    default:
      return <></>
  }
};

const StreamlitLlmViewerComponent = withStreamlitConnection(LlmViewerComponent)

ReactDOM.render(
  <React.StrictMode>
    <StreamlitLlmViewerComponent />
  </React.StrictMode>,
  document.getElementById("root")
)
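The 'component' arg selects which widget the frontend renders, so a single declared Streamlit component can back both widgets. A hedged sketch of how the Python wrapper might invoke this dispatcher (the component name and build path below are assumptions, not taken from the commit):

import streamlit.components.v1 as components

# Assumed declaration; the real name/path live in llm_transparency_tool/components.
_component = components.declare_component(
    "llm_viewer",
    path="llm_transparency_tool/components/frontend/build",
)

def contribution_graph(model_info, tokens, edges_per_token, key=None):
    # 'component' picks the branch in LlmViewerComponent's switch above.
    return _component(
        component="graph",
        model_info=model_info,
        tokens=tokens,
        edges_per_token=edges_per_token,
        key=key,
    )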
llm_transparency_tool/components/frontend/src/react-app-env.d.ts
ADDED
@@ -0,0 +1 @@
/// <reference types="react-scripts" />
llm_transparency_tool/components/frontend/tsconfig.json
ADDED
@@ -0,0 +1,19 @@
{
  "compilerOptions": {
    "target": "es5",
    "lib": ["dom", "dom.iterable", "esnext"],
    "allowJs": true,
    "skipLibCheck": true,
    "esModuleInterop": true,
    "allowSyntheticDefaultImports": true,
    "strict": true,
    "forceConsistentCasingInFileNames": true,
    "module": "esnext",
    "moduleResolution": "node",
    "resolveJsonModule": true,
    "isolatedModules": true,
    "noEmit": true,
    "jsx": "react"
  },
  "include": ["src"]
}
llm_transparency_tool/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
llm_transparency_tool/models/test_tlens_model.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import ModelInfo


class TransparentLlmTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Picking the smallest model possible so that the test runs faster. It's ok to
        # change this model, but you'll need to update tokenization specifics in some
        # tests.
        cls._llm = TransformerLensTransparentLlm(
            model_name="facebook/opt-125m",
            device="cpu",
        )

    def setUp(self):
        self._llm.run(["test", "test 1"])
        self._eps = 1e-5

    def test_model_info(self):
        info = self._llm.model_info()
        self.assertEqual(
            info,
            ModelInfo(
                name="facebook/opt-125m",
                n_params_estimate=84934656,
                n_layers=12,
                n_heads=12,
                d_model=768,
                d_vocab=50272,
            ),
        )

    def test_tokens(self):
        tokens = self._llm.tokens()

        pad = 1
        bos = 2
        test = 21959
        one = 112

        self.assertEqual(tokens.tolist(), [[bos, test, pad], [bos, test, one]])

    def test_tokens_to_strings(self):
        s = self._llm.tokens_to_strings(torch.Tensor([2, 21959, 112]).to(torch.int))
        self.assertEqual(s, ["</s>", "test", " 1"])

    def test_manage_state(self):
        # One llm.run was called at the setup. Call one more and make sure the object
        # returns values for the new state.
        self._llm.run(["one", "two", "three", "four"])
        self.assertEqual(self._llm.tokens().shape[0], 4)

    def test_residual_in_and_out(self):
        """
        Test that residual_in is a residual_out for the previous layer.
        """
        for layer in range(1, 12):
            prev_residual_out = self._llm.residual_out(layer - 1)
            residual_in = self._llm.residual_in(layer)
            diff = torch.max(torch.abs(residual_in - prev_residual_out)).item()
            self.assertLess(diff, self._eps, f"layer {layer}")

    def test_residual_plus_block(self):
        """
        Make sure that new residual = old residual + block output. Here, block is an
        ffn or attention. It's not that obvious because it could be that layer norm is
        applied after the block output, but before saving the result to residual.
        Luckily, this is not the case in TransformerLens, and we're relying on that.
        """
        layer = 3
        batch = 0
        pos = 0

        residual_in = self._llm.residual_in(layer)[batch][pos]
        residual_mid = self._llm.residual_after_attn(layer)[batch][pos]
        residual_out = self._llm.residual_out(layer)[batch][pos]
        ffn_out = self._llm.ffn_out(layer)[batch][pos]
        attn_out = self._llm.attention_output(batch, layer, pos)

        a = residual_mid
        b = residual_in + attn_out
        diff = torch.max(torch.abs(a - b)).item()
        self.assertLess(diff, self._eps, "attn")

        a = residual_out
        b = residual_mid + ffn_out
        diff = torch.max(torch.abs(a - b)).item()
        self.assertLess(diff, self._eps, "ffn")

    def test_tensor_shapes(self):
        # Not much we can do about the tensors, but at least check their shapes and
        # that they don't contain NaNs.
        vocab_size = 50272
        n_batch = 2
        n_tokens = 3
        d_model = 768
        d_hidden = d_model * 4
        n_heads = 12
        layer = 5

        device = self._llm.residual_in(0).device

        for name, tensor, expected_shape in [
            ("r_in", self._llm.residual_in(layer), [n_batch, n_tokens, d_model]),
            (
                "r_mid",
                self._llm.residual_after_attn(layer),
                [n_batch, n_tokens, d_model],
            ),
            ("r_out", self._llm.residual_out(layer), [n_batch, n_tokens, d_model]),
            ("logits", self._llm.logits(), [n_batch, n_tokens, vocab_size]),
            ("ffn_out", self._llm.ffn_out(layer), [n_batch, n_tokens, d_model]),
            (
                "decomposed_ffn_out",
                self._llm.decomposed_ffn_out(0, 0, 0),
                [d_hidden, d_model],
            ),
            ("neuron_activations", self._llm.neuron_activations(0, 0, 0), [d_hidden]),
            ("neuron_output", self._llm.neuron_output(0, 0), [d_model]),
            (
                "attention_matrix",
                self._llm.attention_matrix(0, 0, 0),
                [n_tokens, n_tokens],
            ),
            (
                "attention_output_per_head",
                self._llm.attention_output_per_head(0, 0, 0, 0),
                [d_model],
            ),
            (
                "attention_output",
                self._llm.attention_output(0, 0, 0),
                [d_model],
            ),
            (
                "decomposed_attn",
                self._llm.decomposed_attn(0, layer),
                [n_tokens, n_tokens, n_heads, d_model],
            ),
            (
                "unembed",
                self._llm.unembed(torch.zeros([d_model]).to(device), normalize=True),
                [vocab_size],
            ),
        ]:
            self.assertEqual(list(tensor.shape), expected_shape, name)
            self.assertFalse(torch.any(tensor.isnan()), name)


if __name__ == "__main__":
    unittest.main()
llm_transparency_tool/models/tlens_model.py
ADDED
@@ -0,0 +1,303 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

import torch
import transformer_lens
import transformers
from fancy_einsum import einsum
from jaxtyping import Float, Int
from typeguard import typechecked
import streamlit as st

from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm


@dataclass
class _RunInfo:
    tokens: Int[torch.Tensor, "batch pos"]
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    cache: transformer_lens.ActivationCache


@st.cache_resource(
    max_entries=1,
    show_spinner=True,
    hash_funcs={
        transformers.PreTrainedModel: id,
        transformers.PreTrainedTokenizer: id
    }
)
def load_hooked_transformer(
    model_name: str,
    hf_model: Optional[transformers.PreTrainedModel] = None,
    tlens_device: str = "cuda",
    dtype: torch.dtype = torch.float32,
):
    # if tlens_device == "cuda":
    #     n_devices = torch.cuda.device_count()
    # else:
    #     n_devices = 1
    tlens_model = transformer_lens.HookedTransformer.from_pretrained(
        model_name,
        hf_model=hf_model,
        fold_ln=False,  # Keep layer norm where it is.
        center_writing_weights=False,
        center_unembed=False,
        device=tlens_device,
        # n_devices=n_devices,
        dtype=dtype,
    )
    tlens_model.eval()
    return tlens_model


# TODO(igortufanov): If we want to scale the app to multiple users, we need a more
# careful thread-safe implementation. The simplest option could be to wrap the
# existing methods in mutexes.
class TransformerLensTransparentLlm(TransparentLlm):
    """
    Implementation of Transparent LLM based on transformer lens.

    Args:
    - model_name: The official name of the model from HuggingFace. Even if the model
        was patched or loaded locally, the name should still be official because
        that's how transformer_lens treats the model.
    - hf_model: The language model as a HuggingFace class.
    - tokenizer: The tokenizer to use instead of the model's default one.
    - device: "gpu" or "cpu"
    """

    def __init__(
        self,
        model_name: str,
        hf_model: Optional[transformers.PreTrainedModel] = None,
        tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
        device: str = "gpu",
        dtype: torch.dtype = torch.float32,
    ):
        if device == "gpu":
            self.device = "cuda"
            if not torch.cuda.is_available():
                raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
        elif device == "cpu":
            self.device = "cpu"
        else:
            raise RuntimeError(f"Specified device {device} is not a valid option")

        self.dtype = dtype
        self.hf_tokenizer = tokenizer
        self.hf_model = hf_model

        # self._model = tlens_model
        self._model_name = model_name
        self._prepend_bos = True
        self._last_run = None
        self._run_exception = RuntimeError(
            "Tried to use the model output before calling the `run` method"
        )

    def copy(self):
        import copy
        return copy.copy(self)

    @property
    def _model(self):
        tlens_model = load_hooked_transformer(
            self._model_name,
            hf_model=self.hf_model,
            tlens_device=self.device,
            dtype=self.dtype,
        )

        if self.hf_tokenizer is not None:
            tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")

        tlens_model.set_use_attn_result(True)
        tlens_model.set_use_attn_in(False)
        tlens_model.set_use_split_qkv_input(False)

        return tlens_model

    def model_info(self) -> ModelInfo:
        cfg = self._model.cfg
        return ModelInfo(
            name=self._model_name,
            n_params_estimate=cfg.n_params,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_model=cfg.d_model,
            d_vocab=cfg.d_vocab,
        )

    @torch.no_grad()
    def run(self, sentences: List[str]) -> None:
        tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
        logits, cache = self._model.run_with_cache(tokens)

        self._last_run = _RunInfo(
            tokens=tokens,
            logits=logits,
            cache=cache,
        )

    def batch_size(self) -> int:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.logits.shape[0]

    @typechecked
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.tokens

    @typechecked
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        return self._model.to_str_tokens(tokens)

    @typechecked
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.logits

    @torch.no_grad()
    @typechecked
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        # t: [d_model] -> [batch, pos, d_model]
        tdim = t.unsqueeze(0).unsqueeze(0)
        if normalize:
            normalized = self._model.ln_final(tdim)
            result = self._model.unembed(normalized)
        else:
            result = self._model.unembed(tdim)
        return result[0][0]

    def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.cache[f"blocks.{layer}.{block_name}"]

    # ================= Methods related to the residual stream =================

    @typechecked
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_pre")

    @typechecked
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_mid")

    @typechecked
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_post")

    # ================ Methods related to the feed-forward layer ===============

    @typechecked
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_mlp_out")

    @torch.no_grad()
    @typechecked
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        # Take activations right before they're multiplied by W_out, i.e. non-linearity
        # and layer norm are already applied.
        processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
        return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])

    @typechecked
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden"]:
        return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]

    @typechecked
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._model.W_out[layer][neuron]

    # ==================== Methods related to the attention ====================

    @typechecked
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        return self._get_block(layer, "attn.hook_pattern")[batch_i][head]

    @typechecked
    def attention_output_per_head(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]

    @typechecked
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._get_block(layer, "hook_attn_out")[batch_i][pos]

    @torch.no_grad()
    @typechecked
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "pos key_pos head d_model"]:
        if not self._last_run:
            raise self._run_exception
        hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
        b_v = self._model.b_V[layer]
        v = hook_v + b_v
        pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
        z = einsum(
            "key_pos head d_head, "
            "head query_pos key_pos -> "
            "query_pos key_pos head d_head",
            v,
            pattern,
        )
        decomposed_attn = einsum(
            "pos key_pos head d_head, "
            "head d_head d_model -> "
            "pos key_pos head d_model",
            z,
            self._model.W_O[layer],
        )
        return decomposed_attn
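decomposed_attn folds the value vectors (with b_V added) and the attention pattern into one write per (key position, head), but leaves out the output bias b_O; summing over key positions and heads and adding b_O should therefore recover hook_attn_out. A hedged sanity-check sketch (assumes `llm` is a TransformerLensTransparentLlm after `run`; it reaches into private members for brevity, and the tolerance is illustrative):

import torch

def check_attn_recomposition(llm, batch_i=0, layer=0, atol=1e-4):
    decomposed = llm.decomposed_attn(batch_i, layer)  # [pos, key_pos, head, d_model]
    recomposed = decomposed.sum(dim=(1, 2)) + llm._model.b_O[layer]
    full = llm._get_block(layer, "hook_attn_out")[batch_i]  # [pos, d_model]
    assert torch.allclose(recomposed, full, atol=atol)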
llm_transparency_tool/models/transparent_llm.py
ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import torch
from jaxtyping import Float, Int


@dataclass
class ModelInfo:
    name: str

    # Not the actual number of parameters, but rather the order of magnitude
    n_params_estimate: int

    n_layers: int
    n_heads: int
    d_model: int
    d_vocab: int


class TransparentLlm(ABC):
    """
    An abstract stateful interface for a language model. The model is supposed to be
    loaded at the class initialization.

    The internal state is the resulting tensors from the last call of the `run` method.
    Most of the methods could return values based on the state, but some may do cheap
    computations based on them.
    """

    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available before any
        calls of the `run`.
        """
        pass

    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run the inference on the given sentences in a single batch and store all
        necessary info in the internal state.
        """
        pass

    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch that was used for the last call of `run`.
        """
        pass

    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        pass

    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        pass

    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        pass

    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """
        Project the given vector (for example, the state of the residual stream for a
        layer and token) into the output vocabulary.

        normalize: whether to apply the final normalization before the unembedding.
        Setting it to True and applying to the output of the last layer gives the
        output of the model.
        """
        pass

    # ================= Methods related to the residual stream =================

    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For example, when
        layer == 0 these must be the embedded tokens (including positional embedding).
        """
        pass

    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in the
        given layer.
        """
        pass

    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is equivalent to
        the next layer's input.
        """
        pass

    # ================ Methods related to the feed-forward layer ===============

    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual stream.
        """
        pass

    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """
        A collection of vectors added to the residual stream by each neuron. It should
        be the same as neuron activations multiplied by neuron outputs.
        """
        pass

    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The content of the hidden layer right after the activation function was
        applied.
        """
        pass

    @abstractmethod
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the value that the given neuron adds to the residual stream. It's a raw
        vector from the model parameters, no activation involved.
        """
        pass

    # ==================== Methods related to the attention ====================

    @abstractmethod
    def attention_matrix(
        self, batch_i, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """
        Return a lower-diagonal attention matrix.
        """
        pass

    @abstractmethod
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return what the attention block at the given layer added to the residual
        stream at the given pos. (The `head` parameter was dropped here to match the
        concrete implementation and the tests; per-head outputs are available from
        the implementation's attention_output_per_head.)
        """
        pass

    @abstractmethod
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "source target head d_model"]:
        """
        Here
        - source: index of token from the previous layer
        - target: index of token on the current layer
        The decomposed attention tells what vector from the source representation was
        used in order to contribute to the target representation.
        """
        pass
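Since normalizing and unembedding the final residual state reproduces the model's output, any intermediate residual vector can be promoted to vocabulary space the same way. A sketch of ranking the top-k tokens a residual state points to (assumes `llm` is a concrete TransparentLlm on which `run` was already called; names are illustrative):

import torch

def top_k_promoted_tokens(llm, layer, batch_i, pos, k=5):
    vec = llm.residual_out(layer)[batch_i][pos]  # [d_model]
    logits = llm.unembed(vec, normalize=True)    # [vocab]
    _, indices = torch.topk(logits, k)
    return llm.tokens_to_strings(indices.to(torch.int))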
llm_transparency_tool/routes/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
llm_transparency_tool/routes/contributions.py
ADDED
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import einops
import torch
from jaxtyping import Float
from typeguard import typechecked


@torch.no_grad()
@typechecked
def get_contributions(
    parts: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> torch.Tensor:
    """
    Compute contributions of the `parts` vectors into the `whole` vector.

    Shapes of the tensors are as follows:
    parts:  p_1 ... p_k, v_1 ... v_n, d
    whole:  v_1 ... v_n, d
    result: p_1 ... p_k, v_1 ... v_n

    Here
    * `p_1 ... p_k`: dimensions for enumerating the parts
    * `v_1 ... v_n`: dimensions listing the independent cases (batching),
    * `d` is the dimension to compute the distances on.

    The resulting contributions will be normalized so that
    for each v_: sum(over p_ of result(p_, v_)) = 1.
    """
    EPS = 1e-5

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0
    assert parts.shape[k:] == whole.shape
    bc_whole = whole.expand(parts.shape)  # new dims p_1 ... p_k are added to the front

    distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)

    whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
    distance = (whole_norm - distance).clip(min=EPS)

    # `total` (rather than the built-in name `sum`) holds the normalizer.
    total = distance.sum(dim=tuple(range(k)), keepdim=True)

    return distance / total
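In other words, each part is scored by how much of the whole's norm it accounts for, score_i = clip(||whole|| - ||whole - part_i||, min=EPS), and the scores are normalized to sum to 1. A quick worked example (not part of the committed file; values are illustrative):

import torch

parts = torch.tensor([[3.0, 0.0], [0.0, 1.0]])  # two parts, d = 2
whole = parts.sum(dim=0)                        # [3.0, 1.0], L1 norm is 4
c = get_contributions(parts, whole, distance_norm=1)
# ||whole - part_0||_1 = 1 -> score 3;  ||whole - part_1||_1 = 3 -> score 1
print(c)  # tensor([0.7500, 0.2500])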
@torch.no_grad()
@typechecked
def get_contributions_with_one_off_part(
    parts: torch.Tensor,
    one_off: torch.Tensor,
    whole: torch.Tensor,
    distance_norm: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Same as computing the contributions, but there is one additional part. That's
    useful because we always have the residual stream as one of the parts.

    See `get_contributions` documentation about `parts` and `whole` dimensions. The
    `one_off` should have the same dimensions as `whole`.

    Returns a pair consisting of
    1. contributions tensor for the `parts`
    2. contributions tensor for the `one_off` vector
    """
    assert one_off.shape == whole.shape

    k = len(parts.shape) - len(whole.shape)
    assert k >= 0

    # Flatten the p_ dimensions, get contributions for the list, unflatten.
    flat = parts.flatten(start_dim=0, end_dim=k - 1)
    flat = torch.cat([flat, one_off.unsqueeze(0)])
    contributions = get_contributions(flat, whole, distance_norm)
    parts_contributions, one_off_contributions = torch.split(
        contributions, flat.shape[0] - 1
    )
    return (
        parts_contributions.unflatten(0, parts.shape[0:k]),
        one_off_contributions[0],
    )


@torch.no_grad()
@typechecked
def get_attention_contributions(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
    distance_norm: int = 1,
) -> Tuple[
    Float[torch.Tensor, "batch pos key_pos head"],
    Float[torch.Tensor, "batch pos"],
]:
    """
    Returns a pair of
    - a tensor of contributions of each token via each head
    - the contribution of the residual stream.
    """

    # part dimensions | batch dimensions | vector dimension
    # ----------------+------------------+-----------------
    # key_pos, head   | batch, pos       | d_model
    parts = einops.rearrange(
        decomposed_attn,
        "batch pos key_pos head d_model -> key_pos head batch pos d_model",
    )
    attn_contribution, residual_contribution = get_contributions_with_one_off_part(
        parts, resid_pre, resid_mid, distance_norm
    )
    return (
        einops.rearrange(
            attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
        ),
        residual_contribution,
    )


@torch.no_grad()
@typechecked
def get_mlp_contributions(
    resid_mid: Float[torch.Tensor, "batch pos d_model"],
    resid_post: Float[torch.Tensor, "batch pos d_model"],
    mlp_out: Float[torch.Tensor, "batch pos d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
    """
    Returns a pair of (mlp, residual) contributions for each sentence and token.
    """

    contributions = get_contributions(
        torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
    )
    return contributions[0], contributions[1]


@torch.no_grad()
@typechecked
def get_decomposed_mlp_contributions(
    resid_mid: Float[torch.Tensor, "d_model"],
    resid_post: Float[torch.Tensor, "d_model"],
    decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
    distance_norm: int = 1,
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
    """
    Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of
    the hidden layer and thus computes a contribution per neuron.

    Doesn't contain batch and token dimensions for the sake of saving memory. But we
    may consider adding them.
    """

    neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
        decomposed_mlp_out, resid_mid, resid_post, distance_norm
    )
    return neuron_contributions, residual_contribution.item()


@torch.no_grad()
def apply_threshold_and_renormalize(
    threshold: float,
    c_blocks: torch.Tensor,
    c_residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Thresholding mechanism used in the original graphs paper. After the threshold is
    applied, the remaining contributions are renormalized in order to sum up to 1 for
    each representation.

    threshold: The threshold.
    c_residual: Contribution of the residual stream for each representation. This
        tensor should contain 1 element per representation, i.e., its dimensions are
        all batch dimensions.
    c_blocks: Contributions of the blocks. Could be 1 block per representation, like
        ffn, or heads*tokens blocks in case of attention. The shape of `c_residual`
        must be a prefix of the shape of this tensor. The remaining dimensions are for
        listing the blocks.
    """

    block_dims = len(c_blocks.shape)
    resid_dims = len(c_residual.shape)
    bound_dims = block_dims - resid_dims
    assert bound_dims >= 0
    assert c_blocks.shape[0:resid_dims] == c_residual.shape

    c_blocks = c_blocks * (c_blocks > threshold)
    c_residual = c_residual * (c_residual > threshold)

    denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
    return (
        c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
        c_residual / denom,
    )
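The thresholding above zeroes any contribution at or below the threshold and renormalizes the survivors per representation. A small worked example (illustrative values, not part of the commit):

import torch

c_blocks = torch.tensor([[0.05, 0.55, 0.20]])  # one representation, three blocks
c_residual = torch.tensor([0.20])
blocks, resid = apply_threshold_and_renormalize(0.1, c_blocks, c_residual)
# 0.05 falls below the threshold; the rest renormalize to sum to 1:
print(blocks)  # tensor([[0.0000, 0.5789, 0.2105]])
print(resid)   # tensor([0.2105])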
llm_transparency_tool/routes/graph.py
ADDED
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import networkx as nx
import torch

import llm_transparency_tool.routes.contributions as contributions
from llm_transparency_tool.models.transparent_llm import TransparentLlm


class GraphBuilder:
    """
    Constructs the contributions graph with edges given one by one. The resulting graph
    is a networkx graph that can be accessed via the `graph` field. It contains the
    following types of nodes:

    - X0_<token>: the original token.
    - A<layer>_<token>: the residual stream after attention at the given layer for the
      given token.
    - M<layer>_<token>: the ffn block.
    - I<layer>_<token>: the residual stream after the ffn block.
    """

    def __init__(self, n_layers: int, n_tokens: int):
        self._n_layers = n_layers
        self._n_tokens = n_tokens

        self.graph = nx.DiGraph()
        for layer in range(n_layers):
            for token in range(n_tokens):
                self.graph.add_node(f"A{layer}_{token}")
                self.graph.add_node(f"I{layer}_{token}")
                self.graph.add_node(f"M{layer}_{token}")
        for token in range(n_tokens):
            self.graph.add_node(f"X0_{token}")

    def get_output_node(self, token: int):
        return f"I{self._n_layers - 1}_{token}"

    def _add_edge(self, u: str, v: str, weight: float):
        # TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
        # attention from the current token and the residual edge. Ideally these need to
        # be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
        # but when we try to traverse it, we face some NetworkX issue with EDGE_OK
        # receiving 3 arguments instead of 2.
        if self.graph.has_edge(u, v):
            self.graph[u][v]["weight"] += weight
        else:
            self.graph.add_edge(u, v, weight=weight)

    def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
        self._add_edge(
            f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}",
            f"A{layer}_{token_to}",
            w,
        )

    def add_residual_to_attn(self, layer: int, token: int, w: float):
        self._add_edge(
            f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}",
            f"A{layer}_{token}",
            w,
        )

    def add_ffn_edge(self, layer: int, token: int, w: float):
        self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
        self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)

    def add_residual_to_ffn(self, layer: int, token: int, w: float):
        self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)


@torch.no_grad()
def build_full_graph(
    model: TransparentLlm,
    batch_i: int = 0,
    renormalizing_threshold: Optional[float] = None,
) -> nx.Graph:
    """
    Build the contribution graph for all blocks of the model and all tokens.

    model: The transparent llm which already did the inference.
    batch_i: Which sentence to use from the batch that was given to the model.
    renormalizing_threshold: If specified, will apply renormalizing thresholding to the
        contributions. All contributions below the threshold will be erased and the
        rest will be renormalized.
    """
    n_layers = model.model_info().n_layers
    n_tokens = model.tokens()[batch_i].shape[0]

    builder = GraphBuilder(n_layers, n_tokens)

    for layer in range(n_layers):
        c_attn, c_resid_attn = contributions.get_attention_contributions(
            resid_pre=model.residual_in(layer)[batch_i].unsqueeze(0),
            resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
            decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_attn, c_resid_attn
            )
        for token_from in range(n_tokens):
            for token_to in range(n_tokens):
                # Sum attention contributions over heads.
                c = c_attn[batch_i, token_to, token_from].sum().item()
                builder.add_attention_edge(layer, token_from, token_to, c)
        for token in range(n_tokens):
            builder.add_residual_to_attn(
                layer, token, c_resid_attn[batch_i, token].item()
            )

        c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
            resid_mid=model.residual_after_attn(layer)[batch_i].unsqueeze(0),
            resid_post=model.residual_out(layer)[batch_i].unsqueeze(0),
            mlp_out=model.ffn_out(layer)[batch_i].unsqueeze(0),
        )
        if renormalizing_threshold is not None:
            c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
                renormalizing_threshold, c_ffn, c_resid_ffn
            )
        for token in range(n_tokens):
            builder.add_ffn_edge(layer, token, c_ffn[batch_i, token].item())
            builder.add_residual_to_ffn(
                layer, token, c_resid_ffn[batch_i, token].item()
            )

    return builder.graph


def build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
) -> List[nx.Graph]:
    """
    Given the full graph, this function returns only the trees leading to the specified
    tokens. Edges with weight below `threshold` will be ignored.
    """
    builder = GraphBuilder(n_layers, n_tokens)

    rgraph = graph.reverse()
    search_graph = nx.subgraph_view(
        rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
    )

    result = []
    for start in starting_tokens:
        assert start < n_tokens
        assert start >= 0
        edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start))
        tree = search_graph.edge_subgraph(edges)
        # Reverse the edges because the dfs was going from upper layer downwards.
        result.append(tree.reverse())

    return result
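A minimal sketch (not a file in this commit) of the node naming scheme that `GraphBuilder` maintains, using a hypothetical 2-layer, 2-token graph:

from llm_transparency_tool.routes.graph import GraphBuilder

builder = GraphBuilder(n_layers=2, n_tokens=2)
# Token 1 at layer 0 attends to token 0: edge X0_0 -> A0_1.
builder.add_attention_edge(layer=0, token_from=0, token_to=1, w=0.3)
# The residual stream carries the remaining mass: edge X0_1 -> A0_1.
builder.add_residual_to_attn(layer=0, token=1, w=0.7)
# The FFN block for token 1: edges A0_1 -> M0_1 -> I0_1, plus the residual A0_1 -> I0_1.
builder.add_ffn_edge(layer=0, token=1, w=0.5)
builder.add_residual_to_ffn(layer=0, token=1, w=0.5)

print(builder.get_output_node(1))               # I1_1
print(builder.graph["X0_0"]["A0_1"]["weight"])  # 0.3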
llm_transparency_tool/routes/graph_node.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class NodeType(Enum):
    AFTER_ATTN = "after_attn"
    AFTER_FFN = "after_ffn"
    FFN = "ffn"
    ORIGINAL = "original"  # The original tokens


def _format_block_hierachy_string(blocks: List[str]) -> str:
    return " ▸ ".join(blocks)


@dataclass
class GraphNode:
    layer: int
    token: int
    type: NodeType

    def is_in_residual_stream(self) -> bool:
        return self.type in [NodeType.AFTER_ATTN, NodeType.AFTER_FFN]

    def get_residual_predecessor(self) -> Optional["GraphNode"]:
        """
        Get another graph node which points to the state of the residual stream before
        this node.

        Return None if the current representation is the first one in the residual
        stream.
        """
        scheme = {
            NodeType.AFTER_ATTN: GraphNode(
                layer=max(self.layer - 1, 0),
                token=self.token,
                type=NodeType.AFTER_FFN if self.layer > 0 else NodeType.ORIGINAL,
            ),
            NodeType.AFTER_FFN: GraphNode(
                layer=self.layer,
                token=self.token,
                type=NodeType.AFTER_ATTN,
            ),
            NodeType.FFN: GraphNode(
                layer=self.layer,
                token=self.token,
                type=NodeType.AFTER_ATTN,
            ),
            NodeType.ORIGINAL: None,
        }
        node = scheme[self.type]
        # ORIGINAL nodes map to None: they have no predecessor.
        if node is None or node.layer < 0:
            return None
        return node

    def get_name(self) -> str:
        return _format_block_hierachy_string(
            [f"L{self.layer}", f"T{self.token}", str(self.type.value)]
        )

    def get_predecessor_block_name(self) -> str:
        """
        Return the name of the block standing between the current node and its
        predecessor in the residual stream.
        """
        scheme = {
            NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
            NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
            NodeType.FFN: [f"L{self.layer}", "ffn"],
            NodeType.ORIGINAL: ["Nothing"],
        }
        return _format_block_hierachy_string(scheme[self.type])

    def get_head_name(self, head: Optional[int]) -> str:
        path = [f"L{self.layer}", "attn"]
        if head is not None:
            path.append(f"H{head}")
        return _format_block_hierachy_string(path)

    def get_neuron_name(self, neuron: Optional[int]) -> str:
        path = [f"L{self.layer}", "ffn"]
        if neuron is not None:
            path.append(f"N{neuron}")
        return _format_block_hierachy_string(path)
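A quick sketch (not a file in this commit) of the naming helpers; the expected outputs follow directly from the scheme above:

from llm_transparency_tool.routes.graph_node import GraphNode, NodeType

node = GraphNode(layer=3, token=5, type=NodeType.AFTER_ATTN)
print(node.get_name())                    # L3 ▸ T5 ▸ after_attn
print(node.get_predecessor_block_name())  # L3 ▸ attn
print(node.get_head_name(2))              # L3 ▸ attn ▸ H2

pre = node.get_residual_predecessor()
print(pre.get_name())                     # L2 ▸ T5 ▸ after_ffn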
llm_transparency_tool/routes/test_contributions.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, List

import torch

import llm_transparency_tool.routes.contributions as contributions


class TestContributions(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)

        self.eps = 1e-4

        # It may be useful to run the test on GPU in case there are any issues with
        # creating temporary tensors on another device. But turn this off by default.
        self.test_on_gpu = False

        self.device = "cuda" if self.test_on_gpu else "cpu"

        self.batch = 4
        self.tokens = 5
        self.heads = 6
        self.d_model = 10

        self.decomposed_attn = torch.rand(
            self.batch,
            self.tokens,
            self.tokens,
            self.heads,
            self.d_model,
            device=self.device,
        )
        self.mlp_out = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_pre = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_mid = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )
        self.resid_post = torch.rand(
            self.batch, self.tokens, self.d_model, device=self.device
        )

    def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]):
        self.assertTrue(
            torch.isclose(t, torch.Tensor(expected), atol=self.eps).all(),
            t,
        )

    def test_mlp_contributions(self):
        mlp_out = torch.tensor([[[1.0, 1.0]]])
        resid_mid = torch.tensor([[[0.0, 0.0]]])
        resid_post = torch.tensor([[[1.0, 1.0]]])

        c_mlp, c_residual = contributions.get_mlp_contributions(
            resid_mid, resid_post, mlp_out
        )
        self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps)
        self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps)

    def test_decomposed_attn_contributions(self):
        resid_pre = torch.tensor([[[2.0, 1.0]]])
        resid_mid = torch.tensor([[[2.0, 2.0]]])
        decomposed_attn = torch.tensor(
            [
                [
                    [
                        [
                            [1.0, 1.0],
                            [-1.0, 0.0],
                        ]
                    ]
                ]
            ]
        )

        c_attn, c_residual = contributions.get_attention_contributions(
            resid_pre, resid_mid, decomposed_attn, distance_norm=2
        )
        self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]])
        self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps)

    def test_decomposed_mlp_contributions(self):
        pre = torch.tensor([10.0, 10.0])
        post = torch.tensor([-10.0, 10.0])
        neuron_impacts = torch.tensor(
            [
                [0.0, 1.0],
                [1.0, 0.0],
                [-21.0, -1.0],
            ]
        )
        c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
            pre, post, neuron_impacts, distance_norm=2
        )
        # A bit counter-intuitive, but the only vector pointing from 0 towards the
        # output is the first one.
        self._assert_tensor_eq(c_mlp, [1, 0, 0])
        self.assertAlmostEqual(c_residual, 0, delta=self.eps)

    def test_decomposed_mlp_contributions_single_direction(self):
        pre = torch.tensor([1.0, 1.0])
        post = torch.tensor([4.0, 4.0])
        neuron_impacts = torch.tensor(
            [
                [1.0, 1.0],
                [2.0, 2.0],
            ]
        )
        c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
            pre, post, neuron_impacts, distance_norm=2
        )
        self._assert_tensor_eq(c_mlp, [0.25, 0.5])
        self.assertAlmostEqual(c_residual, 0.25, delta=self.eps)

    def test_attention_contributions_shape(self):
        c_attn, c_residual = contributions.get_attention_contributions(
            self.resid_pre, self.resid_mid, self.decomposed_attn
        )
        self.assertEqual(
            list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads]
        )
        self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])

    def test_mlp_contributions_shape(self):
        c_mlp, c_residual = contributions.get_mlp_contributions(
            self.resid_mid, self.resid_post, self.mlp_out
        )
        self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens])
        self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])

    def test_renormalizing_threshold(self):
        c_blocks = torch.Tensor([[0.05, 0.15], [0.05, 0.05]])
        c_residual = torch.Tensor([0.8, 0.9])
        norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize(
            0.1, c_blocks, c_residual
        )
        self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]])
        self._assert_tensor_eq(norm_residual, [0.842105, 1.0])
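The suite is plain `unittest`, so (assuming the package root is on `PYTHONPATH`) it should run with something like `python -m unittest llm_transparency_tool.routes.test_contributions`.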
llm_transparency_tool/server/app.py
ADDED
@@ -0,0 +1,659 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import networkx as nx
import pandas as pd
import plotly.express
import plotly.graph_objects as go
import streamlit as st
import streamlit_extras.row as st_row
import torch
from jaxtyping import Float
from torch.amp import autocast
from transformers import HfArgumentParser

import llm_transparency_tool.components
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
import llm_transparency_tool.routes.contributions as contributions
import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.transparent_llm import TransparentLlm
from llm_transparency_tool.routes.graph_node import NodeType
from llm_transparency_tool.server.graph_selection import (
    GraphSelection,
    UiGraphEdge,
    UiGraphNode,
)
from llm_transparency_tool.server.styles import (
    RenderSettings,
    logits_color_map,
    margins_css,
    string_to_display,
)
from llm_transparency_tool.server.utils import (
    B0,
    get_contribution_graph,
    load_dataset,
    load_model,
    possible_devices,
    run_model_with_session_caching,
    st_placeholder,
)
from llm_transparency_tool.server.monitor import SystemMonitor

from networkx.classes.digraph import DiGraph


@st.cache_resource(
    hash_funcs={
        nx.Graph: id,
        DiGraph: id,
    }
)
def cached_build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
):
    return llm_transparency_tool.routes.graph.build_paths_to_predictions(
        graph, n_layers, n_tokens, starting_tokens, threshold
    )


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id,
    }
)
def cached_run_inference_and_populate_state(
    stateless_model,
    sentences,
):
    stateful_model = stateless_model.copy()
    stateful_model.run(sentences)
    return stateful_model


@dataclass
class LlmViewerConfig:
    debug: bool = field(
        default=False,
        metadata={"help": "Show debugging information, like the time profile."},
    )

    preloaded_dataset_filename: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the text file to load the lines from."},
    )

    demo_mode: bool = field(
        default=False,
        metadata={"help": "Whether the app should be in the demo mode."},
    )

    allow_loading_dataset_files: bool = field(
        default=True,
        metadata={"help": "Whether the app should be able to load the dataset files on the server side."},
    )

    max_user_string_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Limit for the length of user-provided sentences (in characters), or None if there is no limit."
        },
    )

    models: Dict[str, str] = field(
        default_factory=dict,
        metadata={
            "help": "Locations of models which are stored locally. Dictionary: official "
            "HuggingFace name -> path to dir. If None is specified, the model will be "
            "downloaded from HuggingFace."
        },
    )

    default_model: str = field(
        default="",
        metadata={"help": "The model to load once the UI is started."},
    )


class App:
    _stateful_model: Optional[TransparentLlm] = None
    render_settings = RenderSettings()
    _graph: Optional[nx.Graph] = None
    _contribution_threshold: float = 0.0
    _renormalize_after_threshold: bool = False
    _normalize_before_unembedding: bool = True

    @property
    def stateful_model(self) -> TransparentLlm:
        return self._stateful_model

    def __init__(self, config: LlmViewerConfig):
        self._config = config
        st.set_page_config(layout="wide")
        st.markdown(margins_css, unsafe_allow_html=True)

    def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]:
        if node is None:
            return None
        fn = {
            NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
            NodeType.AFTER_FFN: self.stateful_model.residual_out,
            NodeType.FFN: None,
            NodeType.ORIGINAL: self.stateful_model.residual_in,
        }
        return fn[node.type](node.layer)[B0][node.token]

    def draw_model_info(self):
        info = self.stateful_model.model_info().__dict__
        df = pd.DataFrame(
            data=[str(x) for x in info.values()],
            index=info.keys(),
            columns=["Model parameter"],
        )
        st.dataframe(df, use_container_width=False)

    def draw_dataset_selection(self) -> Optional[str]:
        def update_dataset(filename: Optional[str]):
            dataset = load_dataset(filename) if filename is not None else []
            st.session_state["dataset"] = dataset
            st.session_state["dataset_file"] = filename

        if "dataset" not in st.session_state:
            update_dataset(self._config.preloaded_dataset_filename)

        if not self._config.demo_mode:
            if self._config.allow_loading_dataset_files:
                row_f = st_row.row([2, 1], vertical_align="bottom")
                filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "")
                if row_f.button("Load"):
                    update_dataset(filename)
            row_s = st_row.row([2, 1], vertical_align="bottom")
            new_sentence = row_s.text_input("New sentence")
            new_sentence_added = False

            if row_s.button("Add"):
                max_len = self._config.max_user_string_length
                n = len(new_sentence)
                if max_len is None or n <= max_len:
                    st.session_state.dataset.append(new_sentence)
                    new_sentence_added = True
                    st.session_state.sentence_selector = new_sentence
                else:
                    st.warning(f"Sentence length {n} is larger than the configured limit of {max_len}")

        sentences = st.session_state.dataset
        selection = st.selectbox(
            "Sentence",
            sentences,
            index=len(sentences) - 1,
            key="sentence_selector",
        )
        return selection

    def _unembed(
        self,
        representation: torch.Tensor,
    ) -> torch.Tensor:
        return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding)

    def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
        tokens = self.stateful_model.tokens()[B0]
        n_tokens = tokens.shape[0]
        model_info = self.stateful_model.model_info()

        graphs = cached_build_paths_to_predictions(
            self._graph,
            model_info.n_layers,
            n_tokens,
            range(n_tokens),
            contribution_threshold,
        )

        return llm_transparency_tool.components.contribution_graph(
            model_info,
            self.stateful_model.tokens_to_strings(tokens),
            graphs,
            key=f"graph_{hash(self.sentence)}",
        )

    def draw_token_matrix(
        self,
        values: Float[torch.Tensor, "t t"],
        tokens: List[str],
        value_name: str,
        title: str,
    ):
        assert values.shape[0] == len(tokens)
        labels = {
            "x": "<b>src</b>",
            "y": "<b>tgt</b>",
            "color": value_name,
        }

        captions = [f"({i}){t}" for i, t in enumerate(tokens)]

        fig = plotly.express.imshow(
            values.cpu(),
            title=f"<b>{title}</b>",
            labels=labels,
            x=captions,
            y=captions,
            color_continuous_scale=self.render_settings.attention_color_map,
            aspect="equal",
        )
        fig.update_layout(
            autosize=True,
            margin=go.layout.Margin(
                l=50,  # left margin
                r=0,  # right margin
                b=100,  # bottom margin
                t=100,  # top margin
            ),
        )
        fig.update_xaxes(tickmode="linear")
        fig.update_yaxes(tickmode="linear")
        fig.update_coloraxes(showscale=False)

        st.plotly_chart(fig, use_container_width=True, theme=None)

    def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]:
        """
        Returns: the index of the selected head.
        """
        n_heads = self.stateful_model.model_info().n_heads

        layer = edge.target.layer

        head_contrib, _ = contributions.get_attention_contributions(
            resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0),
            resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0),
            decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
        )

        # [batch pos key_pos head] -> [head]
        flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
        assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"

        selected_head = llm_transparency_tool.components.selector(
            items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
            indices=range(-1, n_heads),
            temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
            preselected_index=flat_contrib.argmax().item(),
            key=f"head_selector_layer_{layer}",  # _from_tok_{edge.source.token}_to_tok_{edge.target.token}
        )
        print(f"head_selector_layer_{layer}_from_tok_{edge.source.token}_to_tok_{edge.target.token}")
        if selected_head == -1 or selected_head is None:
            selected_head = flat_contrib.argmax().item()
        print("****\n" * 3 + f"selected_head: {selected_head}" + "\n****\n" * 3)

        # Draw attention matrix and contributions for the selected head.
        if selected_head is not None:
            tokens = [
                string_to_display(s) for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0])
            ]

            with container_attention_map:
                attn_container, contrib_container = st.columns([1, 1])
                with attn_container:
                    attn = self.stateful_model.attention_matrix(B0, layer, selected_head)
                    self.draw_token_matrix(
                        attn,
                        tokens,
                        "attention",
                        f"Attention map L{layer} H{selected_head}",
                    )
                with contrib_container:
                    contrib = head_contrib[B0, :, :, selected_head]
                    self.draw_token_matrix(
                        contrib,
                        tokens,
                        "contribution",
                        f"Contribution map L{layer} H{selected_head}",
                    )

        return selected_head

    def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
        """
        Returns: the index of the selected neuron.
        """
        resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token]
        resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token]
        decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token)
        c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn)

        top_values, top_i = c_ffn.sort(descending=True)
        n = min(self.render_settings.n_top_neurons, c_ffn.shape[0])
        top_neurons = top_i[0:n].tolist()

        selected_neuron = llm_transparency_tool.components.selector(
            items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)],
            indices=range(-1, n),
            temperatures=[0.0] + top_values[0:n].tolist(),
            preselected_index=-1,
            key="neuron_selector",
        )
        if selected_neuron is None:
            selected_neuron = -1
        selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron]

        return selected_neuron

    def _draw_token_table(
        self,
        n_top: int,
        n_bottom: int,
        representation: torch.Tensor,
        predecessor: Optional[torch.Tensor] = None,
    ):
        n_total = n_top + n_bottom

        logits = self._unembed(representation)
        n_vocab = logits.shape[0]
        scores, indices = torch.topk(logits, n_top, largest=True)
        positions = list(range(n_top))

        if n_bottom > 0:
            low_scores, low_indices = torch.topk(logits, n_bottom, largest=False)
            indices = torch.cat((indices, low_indices.flip(0)))
            scores = torch.cat((scores, low_scores.flip(0)))
            positions += range(n_vocab - n_bottom, n_vocab)

        tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)]

        if predecessor is not None:
            pre_logits = self._unembed(predecessor)
            _, sorted_pre_indices = pre_logits.sort(descending=True)
            pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())}
            old_positions = [pre_indices_dict[i] for i in indices.tolist()]

            def pos_gain_string(pos, old_pos):
                if pos == old_pos:
                    return ""
                sign = "↓" if pos > old_pos else "↑"
                return f"({sign}{abs(pos - old_pos)})"

            position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)]
        else:
            position_strings = [str(pos) for pos in positions]

        def pos_gain_color(s):
            color = "black"
            if isinstance(s, str):
                if "↓" in s:
                    color = "red"
                if "↑" in s:
                    color = "green"
            return f"color: {color}"

        top_df = pd.DataFrame(
            data=zip(position_strings, tokens, scores.tolist()),
            columns=["Pos", "Token", "Score"],
        )

        st.dataframe(
            top_df.style.map(pos_gain_color)
            .background_gradient(
                axis=0,
                cmap=logits_color_map(positive_and_negative=n_bottom > 0),
            )
            .format(precision=3),
            hide_index=True,
            height=self.render_settings.table_cell_height * (n_total + 1),
            use_container_width=True,
        )

    def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None:
        st.caption(block_name)
        self._draw_token_table(
            self.render_settings.n_promoted_tokens,
            self.render_settings.n_suppressed_tokens,
            representation,
            None,
        )

    def draw_top_tokens(
        self,
        node: UiGraphNode,
        container_top_tokens,
        container_token_dynamics,
    ) -> None:
        pre_node = node.get_residual_predecessor()
        if pre_node is None:
            return

        representation = self._get_representation(node)
        predecessor = self._get_representation(pre_node)

        with container_top_tokens:
            st.caption(node.get_name())
            self._draw_token_table(
                self.render_settings.n_top_tokens,
                0,
                representation,
                predecessor,
            )
        if container_token_dynamics is not None:
            with container_token_dynamics:
                self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name())

    def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
        block_name = node.get_head_name(head)
        block_output = (
            self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head)
            if head is not None
            else self.stateful_model.attention_output(B0, node.layer, node.token)
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
        block_name = node.get_neuron_name(neuron)
        block_output = (
            self.stateful_model.neuron_output(node.layer, neuron)
            if neuron is not None
            else self.stateful_model.ffn_out(node.layer)[B0][node.token]
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
        """
        Draw fp16/fp32 switch and AMP control.

        return: The selected precision and whether AMP should be enabled.
        """
        if device == "cpu":
            dtype = torch.float32
        else:
            dtype = st.selectbox(
                "Precision",
                [torch.float16, torch.bfloat16, torch.float32],
                index=0,
            )

        amp_enabled = dtype != torch.float32

        return dtype, amp_enabled

    def draw_controls(self):
        with st.sidebar.expander("Model", expanded=True):
            list_of_devices = possible_devices()
            if len(list_of_devices) > 1:
                self.device = st.selectbox(
                    "Device",
                    possible_devices(),
                    index=0,
                )
            else:
                self.device = list_of_devices[0]

            self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)

            model_list = list(self._config.models)
            default_choice = model_list.index(self._config.default_model)

            self.model_name = st.selectbox(
                "Model",
                model_list,
                index=default_choice,
            )

            if self.model_name:
                self._stateful_model = load_model(
                    model_name=self.model_name,
                    _model_path=self._config.models[self.model_name],
                    _device=self.device,
                    _dtype=self.dtype,
                )
                self.model_key = self.model_name  # TODO maybe something else?
                self.draw_model_info()

            self.sentence = self.draw_dataset_selection()

        with st.sidebar.expander("Graph", expanded=True):
            self._contribution_threshold = st.slider(
                min_value=0.01,
                max_value=0.1,
                step=0.01,
                value=0.04,
                format=r"%.3f",
                label="Contribution threshold",
            )
            self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True)
            self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True)

    def run_inference(self):
        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence])

        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._graph = get_contribution_graph(
                self.stateful_model,
                self.model_key,
                self.stateful_model.tokens()[B0].tolist(),
                (self._contribution_threshold if self._renormalize_after_threshold else 0.0),
            )

    def draw_graph_and_selection(
        self,
    ) -> None:
        (
            container_graph,
            container_tokens,
        ) = st.columns(self.render_settings.column_proportions)

        container_graph_left, container_graph_right = container_graph.columns([5, 1])

        container_graph_left.write("##### Graph")
        heads_placeholder = container_graph_right.empty()
        heads_placeholder.write("##### Blocks")
        container_graph_right_used = False

        container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1])
        container_top_tokens.write("##### Top Tokens")
        container_top_tokens_used = False
        container_token_dynamics.write("##### Promoted Tokens")
        container_token_dynamics_used = False

        try:
            if self.sentence is None:
                return

            with container_graph_left:
                selection = self.draw_graph(
                    self._contribution_threshold if not self._renormalize_after_threshold else 0.0
                )

            if selection is None:
                return

            node = selection.node
            edge = selection.edge

            if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write("##### Heads")
                    head = self.draw_attn_info(edge, container_graph)
                with container_token_dynamics:
                    self.draw_attention_dynamics(edge.target, head)
                    container_token_dynamics_used = True
            elif node is not None and node.type == NodeType.FFN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write("##### Neurons")
                    neuron = self.draw_ffn_info(node)
                with container_token_dynamics:
                    self.draw_ffn_dynamics(node, neuron)
                    container_token_dynamics_used = True

            if node is not None and node.is_in_residual_stream():
                self.draw_top_tokens(
                    node,
                    container_top_tokens,
                    container_token_dynamics if not container_token_dynamics_used else None,
                )
                container_top_tokens_used = True
                container_token_dynamics_used = True
        finally:
            if not container_graph_right_used:
                st_placeholder(
                    "Click on an edge to see head contributions. \n\n"
                    "Or click on FFN to see individual neuron contributions.",
                    container_graph_right,
                    height=1100,
                )
            if not container_top_tokens_used:
                st_placeholder("Select a node from residual stream to see its top tokens.", container_top_tokens, height=1100)
            if not container_token_dynamics_used:
                st_placeholder("Select a node to see its promoted tokens.", container_token_dynamics, height=1100)

    def run(self):
        with st.sidebar.expander("About", expanded=True):
            if self._config.demo_mode:
                st.caption("""
                    The app is deployed in Demo Mode, thus only predefined models and inputs are available.\n
                    You can still install the app locally and use your own models and inputs.\n
                    See https://github.com/facebookresearch/llm-transparency-tool for more information.
                """)

        self.draw_controls()

        if not self.model_name:
            st.warning("No model selected")
            st.stop()

        if self.sentence is None:
            st.warning("No sentence selected")
        else:
            with torch.inference_mode():
                self.run_inference()

            self.draw_graph_and_selection()


if __name__ == "__main__":
    top_parser = argparse.ArgumentParser()
    top_parser.add_argument("config_file")
    args = top_parser.parse_args()

    parser = HfArgumentParser([LlmViewerConfig])
    config = parser.parse_json_file(args.config_file)[0]

    with SystemMonitor(config.debug) as prof:
        app = App(config)
        app.run()
llm_transparency_tool/server/graph_selection.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, Optional

from llm_transparency_tool.routes.graph_node import GraphNode, NodeType


class UiGraphNode(GraphNode):
    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["UiGraphNode"]:
        try:
            layer = json["cell"]["layer"]
            token = json["cell"]["token"]
            type = NodeType(json["item"])
            return UiGraphNode(layer, token, type)
        except (TypeError, KeyError):
            return None


@dataclass
class UiGraphEdge:
    source: UiGraphNode
    target: UiGraphNode
    weight: float

    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["UiGraphEdge"]:
        try:
            source = UiGraphNode.from_json(json["from"])
            target = UiGraphNode.from_json(json["to"])
            if source is None or target is None:
                return None
            weight = float(json["weight"])
            return UiGraphEdge(source, target, weight)
        except (TypeError, KeyError):
            return None


@dataclass
class GraphSelection:
    node: Optional[UiGraphNode]
    edge: Optional[UiGraphEdge]

    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["GraphSelection"]:
        try:
            node = UiGraphNode.from_json(json["node"])
            edge = UiGraphEdge.from_json(json["edge"])
            return GraphSelection(node, edge)
        except (TypeError, KeyError):
            return None
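A quick sketch (not a file in this commit) of the payload shape these parsers expect from the frontend, inferred from the key accesses above:

from llm_transparency_tool.server.graph_selection import GraphSelection

payload = {
    "node": {"cell": {"layer": 0, "token": 1}, "item": "after_attn"},
    "edge": {
        "from": {"cell": {"layer": 0, "token": 0}, "item": "original"},
        "to": {"cell": {"layer": 0, "token": 1}, "item": "after_attn"},
        "weight": 0.3,
    },
}

selection = GraphSelection.from_json(payload)
print(selection.node.get_name())  # L0 ▸ T1 ▸ after_attn
print(selection.edge.weight)      # 0.3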
llm_transparency_tool/server/monitor.py
ADDED
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import streamlit as st
from pyinstrument import Profiler
from typing import Dict
import pandas as pd


@st.cache_resource(max_entries=1, show_spinner=False)
def init_gpu_memory():
    """
    When CUDA is initialized, it occupies some memory on the GPU, and this overhead
    can sometimes make it difficult to understand how much memory is actually used by
    the model.

    This function is used to initialize CUDA and measure the overhead.
    """
    if not torch.cuda.is_available():
        return {}

    # Touch each GPU once so that the CUDA context creation is accounted for.
    gpu_memory_overhead = {}
    for i in range(torch.cuda.device_count()):
        torch.ones(1).cuda(i)
        free, total = torch.cuda.mem_get_info(i)
        occupied = total - free
        gpu_memory_overhead[i] = occupied

    return gpu_memory_overhead


class SystemMonitor:
    """
    This class is used to monitor system resources such as GPU memory and CPU
    usage. It uses the pyinstrument library to profile the code and measure the
    execution time of different parts of the code.
    """

    def __init__(
        self,
        enabled: bool = False,
    ):
        self.enabled = enabled
        self.profiler = Profiler()
        self.overhead: Dict[int, int]

    def __enter__(self):
        if not self.enabled:
            return

        self.overhead = init_gpu_memory()

        self.profiler.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        if not self.enabled:
            return

        self.profiler.__exit__(exc_type, exc_value, traceback)

        self.report_gpu_usage()
        self.report_profiler()

        with st.expander("Session state"):
            st.write(st.session_state)

        return None

    def report_gpu_usage(self):
        if not torch.cuda.is_available():
            return

        data = []

        for i in range(torch.cuda.device_count()):
            free, total = torch.cuda.mem_get_info(i)
            occupied = total - free
            data.append({
                "overhead": self.overhead[i],
                "occupied": occupied - self.overhead[i],
                "free": free,
            })
        df = pd.DataFrame(data, columns=["overhead", "occupied", "free"])

        with st.sidebar.expander("System"):
            st.write("GPU memory on server")
            df /= 1024 ** 3  # Convert to GB
            st.bar_chart(df, width=200, height=200, color=["#fefefe", "#84c9ff", "#fe2b2b"])

    def report_profiler(self):
        html_code = self.profiler.output_html()
        with st.expander("Profiler", expanded=False):
            st.components.v1.html(html_code, height=1000, scrolling=True)
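A usage sketch (not a file in this commit) mirroring how app.py wraps its main loop; when `enabled` is False both context hooks are no-ops, so the wrapper is safe to leave in place:

from llm_transparency_tool.server.monitor import SystemMonitor

def main():
    ...  # build and run the app

with SystemMonitor(enabled=True):
    main()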
llm_transparency_tool/server/styles.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import matplotlib

# Unofficial way to make the padding a bit smaller.
margins_css = """
<style>
    .main > div {
        padding: 1rem;
        padding-top: 2rem;  /* Still need this gap for the top bar */
        gap: 0rem;
    }

    section[data-testid="stSidebar"] {
        width: 300px !important;  /* Set the width to your desired value */
    }
</style>
"""


@dataclass
class RenderSettings:
    column_proportions = [50, 30]

    # We don't know the actual height. This will be used in order to compute the table
    # viewport height when needed.
    table_cell_height = 36

    n_top_tokens = 30
    n_promoted_tokens = 15
    n_suppressed_tokens = 15

    n_top_neurons = 20

    attention_color_map = "Blues"

    no_model_alt_text = "<no model selected>"


def string_to_display(s: str) -> str:
    return s.replace(" ", "·")


def logits_color_map(positive_and_negative: bool) -> matplotlib.colors.Colormap:
    background_colors = {
        "red": [
            [0.0, 0.40, 0.40],
            [0.1, 0.69, 0.69],
            [0.2, 0.83, 0.83],
            [0.3, 0.95, 0.95],
            [0.4, 0.99, 0.99],
            [0.5, 1.0, 1.0],
            [0.6, 0.90, 0.90],
            [0.7, 0.72, 0.72],
            [0.8, 0.49, 0.49],
            [0.9, 0.30, 0.30],
            [1.0, 0.15, 0.15],
        ],
        "green": [
            [0.0, 0.0, 0.0],
            [0.1, 0.09, 0.09],
            [0.2, 0.37, 0.37],
            [0.3, 0.64, 0.64],
            [0.4, 0.85, 0.85],
            [0.5, 1.0, 1.0],
            [0.6, 0.96, 0.96],
            [0.7, 0.88, 0.88],
            [0.8, 0.73, 0.73],
            [0.9, 0.57, 0.57],
            [1.0, 0.39, 0.39],
        ],
        "blue": [
            [0.0, 0.12, 0.12],
            [0.1, 0.16, 0.16],
            [0.2, 0.30, 0.30],
            [0.3, 0.50, 0.50],
            [0.4, 0.78, 0.78],
            [0.5, 1.0, 1.0],
            [0.6, 0.81, 0.81],
            [0.7, 0.52, 0.52],
            [0.8, 0.25, 0.25],
            [0.9, 0.12, 0.12],
            [1.0, 0.09, 0.09],
        ],
    }

    if not positive_and_negative:
        # Stretch the top part to the whole range.
        new_colors = {}
        for channel, colors in background_colors.items():
            new_colors[channel] = [
                [(value - 0.5) * 2, color, color]
                for value, color, _ in colors
                if value >= 0.5
            ]
        background_colors = new_colors

    return matplotlib.colors.LinearSegmentedColormap(
        f"RdYG-{positive_and_negative}",
        background_colors,
    )
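A small sketch (not a file in this commit) showing the two colormap variants; a `LinearSegmentedColormap` instance is callable and maps values in [0, 1] to RGBA:

from llm_transparency_tool.server.styles import logits_color_map, string_to_display

print(string_to_display("a b"))  # a·b

cmap = logits_color_map(positive_and_negative=True)  # diverging: red -> white -> green
print(cmap(0.0)[:3])  # reddish end
print(cmap(1.0)[:3])  # greenish end

one_sided = logits_color_map(positive_and_negative=False)  # only the white -> green half
print(one_sided(0.0)[:3])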
llm_transparency_tool/server/utils.py
ADDED
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple

import networkx as nx
import streamlit as st
import torch

import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm

GPU = "gpu"
CPU = "cpu"

# This variable is for expressing the idea that batch_id = 0, but make it more
# readable than just 0.
B0 = 0


def possible_devices() -> List[str]:
    devices = []
    if torch.cuda.is_available():
        devices.append(GPU)
    devices.append(CPU)
    return devices


def load_dataset(filename) -> List[str]:
    with open(filename) as f:
        dataset = [s.strip("\n") for s in f.readlines()]
    print(f"Loaded {len(dataset)} sentences from {filename}")
    return dataset


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id,
    }
)
def load_model(
    model_name: str,
    _device: str,
    _model_path: Optional[str] = None,
    _dtype: torch.dtype = torch.float32,
) -> TransparentLlm:
    """
    Returns the loaded model. Arguments prefixed with `_` are excluded from
    Streamlit's cache hashing.
    """
    assert _device in possible_devices()

    causal_lm = None
    tokenizer = None

    tl_lm = TransformerLensTransparentLlm(
        model_name=model_name,
        hf_model=causal_lm,
        tokenizer=tokenizer,
        device=_device,
        dtype=_dtype,
    )

    return tl_lm


def run_model(model: TransparentLlm, sentence: str) -> None:
    print(f"Running inference for '{sentence}'")
    model.run([sentence])


def load_model_with_session_caching(
    **kwargs,
) -> Tuple[TransparentLlm, str]:
    return load_model(**kwargs)


def run_model_with_session_caching(
    _model: TransparentLlm,
    model_key: str,
    sentence: str,
):
    LAST_RUN_MODEL_KEY = "last_run_model_key"
    LAST_RUN_SENTENCE = "last_run_sentence"
    state = st.session_state

    if (
        state.get(LAST_RUN_MODEL_KEY, None) == model_key
        and state.get(LAST_RUN_SENTENCE, None) == sentence
    ):
        return

    run_model(_model, sentence)
    state[LAST_RUN_MODEL_KEY] = model_key
    state[LAST_RUN_SENTENCE] = sentence


@st.cache_resource(
    hash_funcs={
        TransformerLensTransparentLlm: id,
    }
)
def get_contribution_graph(
    model: TransparentLlm,  # TODO bug here
    model_key: str,
    tokens: List[str],
    threshold: float,
) -> nx.Graph:
    """
    The `model_key` and `tokens` are used only for caching. The model itself is not
    hashed (the `hash_funcs` above map it to `id`); by Streamlit's convention the
    argument should be named `_model`, hence the TODO above.
    """
    return llm_transparency_tool.routes.graph.build_full_graph(
        model,
        B0,
        threshold,
    )


def st_placeholder(
    text: str,
    container=st,
    border: bool = True,
    height: Optional[int] = 500,
):
    empty = container.empty()
    empty.container(border=border, height=height).write(f"<small>{text}</small>", unsafe_allow_html=True)
    return empty
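A short sketch (not a file in this commit) of the non-Streamlit helpers, assuming it is run from the repository root so that sample_input.txt resolves:

from llm_transparency_tool.server.utils import B0, load_dataset, possible_devices

print(possible_devices())  # ["gpu", "cpu"] on a CUDA machine, otherwise ["cpu"]

sentences = load_dataset("sample_input.txt")
print(sentences[B0])  # "The war lasted from the year 1732 to the year 17"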
pyproject.toml
ADDED
@@ -0,0 +1,2 @@
[tool.black]
line-length = 120
sample_input.txt
ADDED
@@ -0,0 +1,3 @@
The war lasted from the year 1732 to the year 17
5 + 4 = 9, 2 + 3 =
When Mary and John went to the store, John gave a drink to
setup.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from setuptools import setup

setup(
    name="llm_transparency_tool",
    version="0.1",
    packages=["llm_transparency_tool"],
)