{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "070a4097-7a17-409f-af5d-3d0cf43926ca", "metadata": {}, "outputs": [], "source": [ "from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM\n", "from huggingface_hub import list_repo_refs\n", "from transformers import AutoTokenizer, AutoModelForCausalLM" ] }, { "cell_type": "code", "execution_count": 4, "id": "100ec138-f7c1-4d8f-b7e0-eb715f320fdc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"mnoukhov/pythia410m-tldr-sft\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "dbc9a2db-2c16-4e8f-bd2a-213ddc5d139d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.add_special_tokens({\"pad_token\": \"<|padding|>\"}) " ] }, { "cell_type": "code", "execution_count": 16, "id": "03788af8-6733-492f-84e3-fd358bb88ffd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.pad_token_id" ] }, { "cell_type": "code", "execution_count": 12, "id": "576d3fda-7902-43d7-b4b1-3054f6192b11", "metadata": {}, "outputs": [], "source": [ "example_text = \"hello my name is mr hello\"" ] }, { "cell_type": "code", "execution_count": 24, "id": "c73ddb0c-1551-4b12-82d8-26d3742d6f57", "metadata": {}, "outputs": [], "source": [ "toks = tokenizer(example_text + tokenizer.eos_token, padding=\"max_length\", max_length=7, truncation=True)" ] }, { "cell_type": "code", "execution_count": 25, "id": "8904af15-4d27-4718-b53a-060ae65173a9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': [25521, 619, 1416, 310, 278, 83, 23120], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "toks" ] }, { "cell_type": "code", "execution_count": 26, "id": "8fcf7c83-e8df-457b-9eab-1b1ed2145a76", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(toks['attention_mask'])" ] }, { "cell_type": "code", "execution_count": 2, "id": "ef1dddf6-1d26-4950-910a-c40b2cc394c6", "metadata": {}, "outputs": [], "source": [ "base_model_name = \"vwxyzjn/EleutherAI_pythia-1b-deduped__sft__tldr\"\n", "base_model_revision = \"sft__55513__1706646024\"" ] }, { "cell_type": "code", "execution_count": 35, "id": "bb0df32c-9d90-4ab0-a87d-0ff6ecab03b6", "metadata": {}, "outputs": [], "source": [ "model_path = \"/home/toolkit/trl_results/mnoukhov/EleutherAI_pythia-1b-deduped__sft__tldr_dpo_costa_1b_fp16.yml_3d94f50_b9ff2_merged/main\"" ] }, { "cell_type": "code", "execution_count": 36, "id": "3ae77b2a-3132-4dd1-903b-35f28b7e7e5f", "metadata": {}, "outputs": [], "source": [ "base_model = AutoModelForCausalLM.from_pretrained(model_path)" ] }, { "cell_type": "code", "execution_count": 37, "id": "08c1d05d-44a4-4859-9d54-48e7a3cd1da7", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "12749e76749a40469d7732dc23e0f1dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/4.05G [00:00')" ] }, { "cell_type": "code", "execution_count": 5, "id": "42b8260f-19a7-42e1-b809-a24deff3699c", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "032ad7febe1b4eb9899d22e5d44d23a0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/456 [00:00" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHpCAYAAACmzsSXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAz8UlEQVR4nO3deXhV1b3/8U+mE3IgI5gEymAAi6CgBRXTwSKkBEqrXrnPNZYKrYDFG6hAq5Rbq5b2Xvg5gAgRbCvEai3qfRRbsCCEQSwBMSUVUHkcsHgrCZWckzBlXr8/6NnNycAQTrJXct6v5zmPOXuv7Hz3Jp5P9t5rrxVhjDECAADWiXS7AAAA0DxCGgAASxHSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQjp82CMUUVFhXikHADQngjp83D8+HElJibq+PHjbpcCAAgjhDQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCOkOxhijsrIyGWPcLgUA0MYI6Q7G5/Mp59G18vl8bpcCAGhjhHQHFOONd7sEAEA7IKRdxuVrAEBLCGmXcfkaANASQtoCZ7t8zZk2AIQvQtpynGkDQPgipDsAOooBQHgipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQ7uAY7AQAOi9CuoNjsBMA6LwI6U6AwU4AoHMipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdKWYX5oAECANSG9aNEiRUREaPbs2c6yyspK5ebmqnv37urWrZsmTpyo0tLSoO87fPiwJkyYIK/Xq9TUVN17772qra0NarNt2zYNHz5csbGxGjhwoPLz89thj1qH+aEBAAFWhPSePXv01FNPadiwYUHL58yZoz/+8Y966aWXtH37dn322We69dZbnfV1dXWaMGGCqqurtXPnTj3zzDPKz8/XAw884LQ5dOiQJkyYoBtvvFHFxcWaPXu2pk2bpo0bN7bb/l0o5ocGAEgWhPSJEyc0adIk/frXv1ZycrKzvLy8XE8//bQWL16s0aNHa8SIEVq9erV27typXbt2SZJef/11vfvuu3ruued09dVXa/z48frFL36hvLw8VVdXS5JWrlypjIwMPfbYYxo8eLBmzpypf//3f9eSJUtarKmqqkoVFRVBLwAA2pvrIZ2bm6sJEyYoKysraHlRUZFqamqCll9++eXq27evCgsLJUmFhYUaOnSo0tLSnDbZ2dmqqKjQgQMHnDaNt52dne1sozkLFy5UYmKi8+rTp89F7ycAABfK1ZBes2aN/vKXv2jhwoVN1pWUlMjj8SgpKSloeVpamkpKSpw2DQM6sD6w7mxtKioqdPr06Wbrmj9/vsrLy53Xp59+2qr9AwDgYkS79YM//fRT3XPPPdq0aZO6dOniVhnNio2NVWxsrNtlAADCnGtn0kVFRTp69KiGDx+u6OhoRUdHa/v27XriiScUHR2ttLQ0VVdXy+/3B31faWmp0tPTJUnp6elNensH3p+rTUJCguLi4tpo7wAAuHiuhfSYMWO0b98+FRcXO69rrrlGkyZNcr6OiYlRQUGB8z0HDx7U4cOHlZmZKUnKzMzUvn37dPToUafNpk2blJCQoCFDhjhtGm4j0CawDQAAbOXa5e74+HhdeeWVQcu6du2q7t27O8unTp2quXPnKiUlRQkJCZo1a5YyMzN1/fXXS5LGjh2rIUOG6I477tDDDz+skpIS3X///crNzXUuV8+YMUPLly/XfffdpzvvvFNbtmzRiy++qPXr17fvDgMAcIFcC+nzsWTJEkVGRmrixImqqqpSdna2nnzySWd9VFSU1q1bp7vvvluZmZnq2rWrpkyZogULFjhtMjIytH79es2ZM0dLly5V79699Zvf/EbZ2dlu7BIAAOfNqpDetm1b0PsuXbooLy9PeXl5LX5Pv3799Nprr511u6NGjdLevXtDUSIAAO3G9eekAQBA8whpAAAsRUh3IMYYJt4AgDBCSHcgPp9P05avV12jWb4AAJ0TId3BRMcxQxYAhAurenfj/DS87G2McbkaAEBbIaQ7oJrKk8p9do+iY6K0NGe42+UAANoIl7s7KI83XjHehCbLjTEqKyvjDBsAOgFCupPx+XzKeXQtvcABoBMgpDuhGC+dywCgMyCkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYihHHOgBmvwKA8ERIdwA1p88MA1pfc1p1dcyABQDhgsvdHYTHG68YZsACgLBCSAMAYClCGgAASxHSHZgxRn6/3+0yAABthJDuwGpOndCc/B2qq6UzGQB0RoR0Bxft7ep2CQCANkJIAwBgKUIaAABLMZiJhRqOMGaMcbkaAIBbCGkLBUYYi4qO1ILxA9wuBwDgEi53W8rjjZcUeab3NkOBAkBYIqQtR+9tAAhfhDQAAJYipAEAsBQdxzoBeoMDQOdESHcCNZVneoNHx0Rpac7woHWBAE9OTlZERIRLFQIAWoPL3RYIBOmxY8dUVlbWqm14vPGK8SY0We7z+ZTz6FrnTBsA0HFwJm2BwHPR9TWnVVlRHvIe3THe+JBuDwDQPghpS3i88aqrjlYtM1oBAP6Jy90AAFiKkAYAwFKENAAAliKkAQCwFCHdiRhj5Pf73S4DABAihHQnUnPqxJlZs+ghDgCdAiHdyTBrFgB0HoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYKtrtAhB6xhj5fD7n6+bWJScnKyIiwo3yAADniTPpTqim8qRyn92jySsK5Pf7g9b5fD7lPLrWCXEAgL04k3ZRwzPeUPN44xUV0/w/b4w3vk1+JgAgtDiTdpHP59O05etVV1frdikAAAsR0i6LjuOsFgDQPEIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCOlOzBjTZMQxAEDHQUh3YjWnTmhO/g7V1TJYCgB0RIR0Jxft7ep2CQCAViKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQjrMGWNUVlYmY4zbpQAAGiGkw5zP51POo2vl8/ncLgUA0IirIb1ixQoNGzZMCQkJSkhIUGZmpv70pz856ysrK5Wbm6vu3burW7dumjhxokpLS4O2cfjwYU2YMEFer1epqam69957Vdto/uRt27Zp+PDhio2N1cCBA5Wfn98eu9dhxHjj3S4BANAMV0O6d+/eWrRokYqKivT2229r9OjRuvnmm3XgwAFJ0pw5c/THP/5RL730krZv367PPvtMt956q/P9dXV1mjBhgqqrq7Vz504988wzys/P1wMPPOC0OXTokCZMmKAbb7xRxcXFmj17tqZNm6aNGze2+/4CAHAhot384d/+9reD3v/3f/+3VqxYoV27dql37956+umn9fzzz2v06NGSpNWrV2vw4MHatWuXrr/+er3++ut69913tXnzZqWlpenqq6/WL37xC82bN08PPfSQPB6PVq5cqYyMDD322GOSpMGDB+vNN9/UkiVLlJ2d3WxdVVVVqqqqct5XVFS00REAAKBl1tyTrqur05o1a3Ty5EllZmaqqKhINTU1ysrKctpcfvnl6tu3rwoLCyVJhYWFGjp0qNLS0pw22dnZqqiocM7GCwsLg7YRaBPYRnMWLlyoxMRE59WnT59Q7ioAAOfF9ZDet2+funXrptjYWM2YMUOvvPKKhgwZopKSEnk8HiUlJQW1T0tLU0lJiSSppKQkKKAD6wPrztamoqJCp0+fbram+fPnq7y83Hl9+umnodhVAAAuiKuXuyVp0KBBKi4uVnl5uf73f/9XU6ZM0fbt212tKTY2VrGxsa7WAACA6yHt8Xg0cOBASdKIESO0Z88eLV26VLfddpuqq6vl9/uDzqZLS0uVnp4uSUpPT9dbb70VtL1A7++GbRr3CC8tLVVCQoLi4uLaarcAALhorl/ubqy+vl5VVVUaMWKEYmJiVFBQ4Kw7ePCgDh8+rMzMTElSZmam9u3bp6NHjzptNm3apISEBA0ZMsRp03AbgTaBbQAAYCtXz6Tnz5+v8ePHq2/fvjp+/Lief/55bdu2TRs3blRiYqKmTp2quXPnKiUlRQkJCZo1a5YyMzN1/fXXS5LGjh2rIUOG6I477tDDDz+skpIS3X///crNzXUuV8+YMUPLly/XfffdpzvvvFNbtmzRiy++qPXr17u56wAAnJOrIX306FFNnjxZR44cUWJiooYNG6aNGzfqG9/4hiRpyZIlioyM1MSJE1VVVaXs7Gw9+eSTzvdHRUVp3bp1uvvuu5WZmamuXbtqypQpWrBggdMmIyND69ev15w5c7R06VL17t1bv/nNb1p8/AoAAFu4GtJPP/30Wdd36dJFeXl5ysvLa7FNv3799Nprr511O6NGjdLevXtbVSMAAG6x7p40AAA4w/Xe3Wh7xhj5fD7nvwCAjoGQDgM1lSeV++we1decVmVFubzd0879TQAA1xHSYcLjjVdddXSTGcIAAPbinjQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQhoXxBijsrIyGWPcLgUAOj1CGhfE5/Mp59G1DIoCAO2AkMYFi/HGu10CAIQFQhoAAEsR0gAAWIphQcNQw4k26AAGAPYipMNQYMKN6JgoLc0Z7nY5AIAWENJhyuONV1QM//wAYDPuSQMAYClCGgAASxHSAABYipAGAMBShDSCMDY3ANiDkEYQxuYGAHsQ0miCsbkBwA6ENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSYcwYI7/f73YZAIAWMFdhGKs5dUJz8t9XXEqqM3gJI40BgD0I6TAX7e2qmsqTyn12j6JjorQ0Z7jbJQEA/omQhiTJ441XVAy/DgBgE+5JAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYipAGAMBSrQrp/v3769ixY02W+/1+9e/f/6KLgp2MMc6UlgCAtteqkP7kk09UV1fXZHlVVZX+/ve/X3RRsJPP59O05etVV1vrdikAEBYuaG7CP/zhD87XGzduVGJiovO+rq5OBQUFuvTSS0NWHOwTHRfvdgkAEDYuKKRvueUWSVJERISmTJkStC4mJkaXXnqpHnvssZAVh/ZljJHf73e7DADAP11QSNfX10uSMjIytGfPHvXo0aNNioI7ak6d0Jz89+XtnuZ2KQAAXWBIBxw6dCjUdcAS0d6ubpcAAPinVoW0JBUUFKigoEBHjx51zrADVq1addGFwQ6BHt3JyclulwIAYadVvbt//vOfa+zYsSooKNDnn38un88X9ELn4fP5lPPoWv5dAcAFrTqTXrlypfLz83XHHXeEuh5YKMZLj24AcEOrQrq6ulpf/vKXQ10LLNFw0BJjjMvVAED4atXl7mnTpun5558PdS2wRE3lSeU+u0eTVxTwSBYAuKhVZ9KVlZX61a9+pc2bN2vYsGGKiYkJWr948eKQFAf3eLzxioppdb9CAEAItOpT+J133tHVV18tSdq/f3/QuoiIiIsuCgAAtDKkt27dGuo6AABAI0xVCQCApVp1Jn3jjTee9bL2li1bWl0QAAA4o1UhHbgfHVBTU6Pi4mLt37+/ycQbAACgdVoV0kuWLGl2+UMPPaQTJ05cVEEAAOCMkN6T/u53v8u43QAAhEhIQ7qwsFBdunQJ5SYBAAhbrbrcfeuttwa9N8boyJEjevvtt/Wzn/0sJIUBABDuWhXSiYmJQe8jIyM1aNAgLViwQGPHjg1JYQAAhLtWhfTq1atDXQcAAGjkogZnLioq0nvvvSdJuuKKK/SlL30pJEUBAIBWhvTRo0eVk5Ojbdu2KSkpSZLk9/t14403as2aNbrkkktCWSMAAGGpVb27Z82apePHj+vAgQMqKytTWVmZ9u/fr4qKCv3whz8MdY0AAISlVp1Jb9iwQZs3b9bgwYOdZUOGDFFeXh4dxwAACJFWnUnX19c3mUNakmJiYlRfX3/RRQEAgFaG9OjRo3XPPffos88+c5b9/e9/15w5czRmzJiQFQcAQDhrVUgvX75cFRUVuvTSSzVgwAANGDBAGRkZqqio0LJly0JdIwAAYalV96T79Omjv/zlL9q8ebPef/99SdLgwYOVlZUV0uIAAAhnF3QmvWXLFg0ZMkQVFRWKiIjQN77xDc2aNUuzZs3StddeqyuuuEI7duxoq1oBAAgrFxTSjz/+uKZPn66EhIQm6xITE/WDH/xAixcvDllxAACEswsK6b/+9a8aN25ci+vHjh2roqKiiy4KdjDGyO/3u10GAIStCwrp0tLSZh+9CoiOjtY//vGPiy4Kdqg5dUJz8neorrbW7VIAICxdUEh/4Qtf0P79+1tc/84776hnz54XXRTsEe3t6nYJABC2Liikv/nNb+pnP/uZKisrm6w7ffq0HnzwQX3rW98KWXEAAISzC3oE6/7779fLL7+sL37xi5o5c6YGDRokSXr//feVl5enuro6/fSnP22TQgEACDcXFNJpaWnauXOn7r77bs2fP1/GGElSRESEsrOzlZeXp7S0tDYpFO4xxsjn87ldBgCEnQsezKRfv3567bXX5PP59OGHH8oYo8suu0zJycltUR8sUFN5UrnP7lF9zWnV1dUqyu2CACBMtGrEMUlKTk7WtddeG8paYDGPN1511dGqPc4ZNQC0l1aN3R0qCxcu1LXXXqv4+Hilpqbqlltu0cGDB4PaVFZWKjc3V927d1e3bt00ceJElZaWBrU5fPiwJkyYIK/Xq9TUVN17772qbfTY0LZt2zR8+HDFxsZq4MCBys/Pb+vdAwDgorga0tu3b1dubq527dqlTZs2qaamRmPHjtXJkyedNnPmzNEf//hHvfTSS9q+fbs+++wz3Xrrrc76uro6TZgwQdXV1dq5c6eeeeYZ5efn64EHHnDaHDp0SBMmTNCNN96o4uJizZ49W9OmTdPGjRvbdX8BALgQrb7cHQobNmwIep+fn6/U1FQVFRXphhtuUHl5uZ5++mk9//zzGj16tCRp9erVGjx4sHbt2qXrr79er7/+ut59911t3rxZaWlpuvrqq/WLX/xC8+bN00MPPSSPx6OVK1cqIyNDjz32mKQzk4G8+eabWrJkibKzs5vUVVVVpaqqKud9RUVFGx4FAACa5+qZdGPl5eWSpJSUFElSUVGRampqgmbXuvzyy9W3b18VFhZKkgoLCzV06NCgXuXZ2dmqqKjQgQMHnDaNZ+jKzs52ttHYwoULlZiY6Lz69OkTup0EAOA8WRPS9fX1mj17tr7yla/oyiuvlCSVlJTI4/EoKSkpqG1aWppKSkqcNo0f+wq8P1ebiooKnT59ukkt8+fPV3l5ufP69NNPQ7KPAABcCFcvdzeUm5ur/fv3680333S7FMXGxio2NtbtMqzV8Lnp5ORkRUREuFwRAHROVpxJz5w5U+vWrdPWrVvVu3dvZ3l6erqqq6ubzMRUWlqq9PR0p03j3t6B9+dqk5CQoLi4uFDvTqcXeG568ooCBjkBgDbkakgbYzRz5ky98sor2rJlizIyMoLWjxgxQjExMSooKHCWHTx4UIcPH1ZmZqYkKTMzU/v27dPRo0edNps2bVJCQoKGDBnitGm4jUCbwDZw4TzeeMV4m84rDgAIHVcvd+fm5ur555/Xq6++qvj4eOcecmJiouLi4pSYmKipU6dq7ty5SklJUUJCgmbNmqXMzExdf/31ks7MYT1kyBDdcccdevjhh1VSUqL7779fubm5ziXrGTNmaPny5brvvvt05513asuWLXrxxRe1fv161/YdAIBzcfVMesWKFSovL9eoUaPUs2dP5/XCCy84bZYsWaJvfetbmjhxom644Qalp6fr5ZdfdtZHRUVp3bp1ioqKUmZmpr773e9q8uTJWrBggdMmIyND69ev16ZNm3TVVVfpscce029+85tmH78CAMAWrp5JByboOJsuXbooLy9PeXl5LbYJjCd+NqNGjdLevXsvuEYAANxiRccxdFzGGJWVlZ3XH1wAgAtDSOOi+P1+5Ty6ll7eANAGCGlctBhvvNslAECnREgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKk0WrGGPn9frfLAIBOi5BGq9WcOqE5+TtUV1vrdikA0CkR0rgo0d6ubpcAAJ0WIQ0AgKUIaQAALEVII6SMMSorK5Mxxu1SAKDDI6QRUj6fTzmPrpXP53O7FADo8AhphFyMN97tEgCgUyCkAQCwFCENAIClCGkAACxFSAMAYClCGgAAS0W7XQA6PmOM88gVz0cDQOgQ0rhoNZUnlfvsHkXHRGlpznC3ywGAToOQRkh4vPGKiuHXCQBCiXvSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQhphIwxRn6/3+0yAKDTIKQRMjWnTmhO/g7V1da6XQoAdAqENEIq2tvV7RIAoNMgpAEAsBQhjTZljFFZWRljegNAKxDSaFM+n085j651JuAAAJw/QhptLsYb73YJANAhEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCGkAACwV7XYB6HyMMc5z0QxiAgCtR0gj5GoqTyr32T2KjonS0pzhbpcDAB0WIY024fHGKyqGXy8AuBjck0a7YixvADh/hDTaFWN5A8D543qkCwIdq8I1qBjLGwDODyHtAp/Pp8krClR96oTq6mrdLgcAYCkud7skxpugmLjOfUZpjJHf73e7DADosAhptJmaUyc0J3+H6mq5WgAArUFIo01Fe7u6XQIAdFiENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHScBUTbgBAywhpuIoJNwCgZYQ0XMeEGwDQPEIaAABLMQsW2lxgas7A1wCA80NIo83VVJ5U7rN7FB0TpaU5w90uBwA6DEIa7cLjjVdkdBRTVwLABeCeNNoNU1cCwIUhpNGumLoSAM4fIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFIMZoJ2xzChAHB+CGm0O4YJBYDzw+VuuMLjjVeMN6HJcmOMysrKOMMGABHSsIzP51POo2udy+EAEM4IabjGGNPshBsx3vj2LwYALERIwzVMuAEAZ+dqSL/xxhv69re/rV69eikiIkJr164NWm+M0QMPPKCePXsqLi5OWVlZ+uCDD4LalJWVadKkSUpISFBSUpKmTp2qEydOBLV555139LWvfU1dunRRnz599PDDD7f1ruE8MeEGALTM1ZA+efKkrrrqKuXl5TW7/uGHH9YTTzyhlStXavfu3eratauys7NVWVnptJk0aZIOHDigTZs2ad26dXrjjTd01113OesrKio0duxY9evXT0VFRXrkkUf00EMP6Ve/+lWb7x8AABfD1Uewxo8fr/Hjxze7zhijxx9/XPfff79uvvlmSdJvf/tbpaWlae3atcrJydF7772nDRs2aM+ePbrmmmskScuWLdM3v/lNPfroo+rVq5d+97vfqbq6WqtWrZLH49EVV1yh4uJiLV68OCjMG6qqqlJVVZXzvqKiIsR7DgDAuVl7T/rQoUMqKSlRVlaWsywxMVEjR45UYWGhJKmwsFBJSUlOQEtSVlaWIiMjtXv3bqfNDTfcII/H47TJzs7WwYMHW+xBvHDhQiUmJjqvPn36tMUuAgBwVtaGdElJiSQpLS0taHlaWpqzrqSkRKmpqUHro6OjlZKSEtSmuW00/BmNzZ8/X+Xl5c7r008/vfgdAgDgAjHiWDNiY2MVGxvrdhkAgDBn7Zl0enq6JKm0tDRoeWlpqbMuPT1dR48eDVpfW1ursrKyoDbNbaPhz4C7AmN5M9IYAASzNqQzMjKUnp6ugoICZ1lFRYV2796tzMxMSVJmZqb8fr+KioqcNlu2bFF9fb1GjhzptHnjjTdUU1PjtNm0aZMGDRqk5OTkdtobnE1gLO/JKwqaHdwEAMKVqyF94sQJFRcXq7i4WNKZzmLFxcU6fPiwIiIiNHv2bP3yl7/UH/7wB+3bt0+TJ09Wr169dMstt0iSBg8erHHjxmn69Ol666239Oc//1kzZ85UTk6OevXqJUn6zne+I4/Ho6lTp+rAgQN64YUXtHTpUs2dO9elvUZzWhrLGwDCmav3pN9++23deOONzvtAcE6ZMkX5+fm67777dPLkSd11113y+/366le/qg0bNqhLly7O9/zud7/TzJkzNWbMGEVGRmrixIl64oknnPWJiYl6/fXXlZubqxEjRqhHjx564IEHWnz8CgAAW7ga0qNGjTrrPciIiAgtWLBACxYsaLFNSkqKnn/++bP+nGHDhmnHjh2trhMAADdYe08aAIBwR0jDGi3NigUA4YqQhjWYFQsAghHSsAqzYgHAvxDSsJoxhkFOAIQtQhrWaTgCWVlZmXIeXdviZCgA0JkxdjesExiBLDomSktzhivGG+92SQDgCkIaVvJ44xUVw68ngPDG5W4AACxFSAMAYClCGtZicBMA4Y6QhrUY3ARAuCOkYTUGNwEQzghpAAAsRUgDAGApQhoAAEsR0uhQGMsbQDghpNGh+Hw+xvIGEDYYdxHWC0y4EfgvY3kDCBeENKwXmHCjvua0KivK5e2e5nZJANAuCGl0CB5vvOqqo1XLwCYAwgj3pAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIpHsNDhNBzcxBijiIgIpaSkKCIiwu3SACCkCGl0OI0HN4mJ66qXfjJRKSkpbpcGACFFSKNDaji4SUxcN7fLAYA2wT1pAAAsRUijU2EqSwCdCSGNToWpLAF0JtyTRocX6O0d+JqpLAF0FoQ0OrxAb+/omCgtzRnudjkAEDKENDoFjzdeUTH8OgPoXLgnjU7DGCO/3+92GQAQMoQ0Oo2aUyc0J3+H6mpr3S4FAEKCkEanEu3t6nYJABAyhDQAAJYipAEAsBQhDQCApQhpdGqNhwll2FAAHQkhjU6t8TChDBsKoCNh9Ad0OucaJjQ6rpuzPjk5WREREe1eIwCcD0IanU5Lw4QGwrvm9Jn1UdGRWpozXMnJyUpJSSGsAViHy93olDzeeMV4E4KW+Xw+TVu+XnV1tfJ44yVF6s6l6/Qf/+9lLn8DsBIhjU6ruWFCo+MaXfr2dlWMt1s7VgUA54/L3ei0zgwT+r7iUlI5UwbQIRHS6NSivV2de9T1NadVV8e43gA6Di53Iyx4vPGKaXSpGwBsR0gDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACW4hEshL3AcKGBmbEYIhSALQhphL2Gz1HXVtfo1zPGqH///gQ1ANdxuRtQg+eoIyN011NbGKEMgBUIaaCRhmN5G2NUVlbmXAoHgPZESANn4fP5lPPoWs6sAbiCkAbOIcbLcKIA3EHHMaCRQG/vwNcA4BZCGmgk0Ns7OiZKS3OGu10OgDBGSAPN8HjjFRXD/x4A3MU9aQAALEVIAy0wxsjv9zdZxiNZANoLIQ20oObUCc3J36Hamhr5fD6VlZWprKyMR7IAtBtuugFnEe3t6nQki4qO1ILxA3gkC0C74UwaOA8eb7ykSM3J36G62lpnOZe/AbQlQhq4ANHerkHvGZEMQFsipNtZw4Ey0LEF/i25/A2grRDS7czn82na8vWqq6k9d2NYKRDOH3/8saYtXx/UsYzL3gBCiY5jLoiO48yrI2s4/3SEJy5ohLJnZoxWRESEkpOTmY8awEXjTBpoBWf+6Qbvo+Pi9cknnwTdo6ZjGYCLQUgDIRJ4rjrSE+cso2MZgIvB5W4ghBr3/pak6LhuTkhzGRzAhSCkgRALdCwL/LfmdMv3rANtCG8AzSGkgRBr2LGssqJc0d6uivPGKzI6Sp988ol+8r/FWjltlJKTk2WM0e2Pvao1P75FKSkpbpcOwDKENNAGPN541VVHq7bB6GRn7lm/r2hv1/MaZpSzbAB0HAPaUeCedcNhRgPPWR87dkyff/65Pv/8cx07dkwff/yxch5d60zsUVZWpvr6+qDe4vQeBzo3zqQBFzWcwCNweVySvN3TzzyHHdNFn3zyiR7c8LGMMVowfoBzuTwpKUk+n0//ueoNLpcDnRQhDVig8eXxwPvTx32ak79DSX0uU131ac3J3+FcLg+EelxKqtN7PCkpSX6/n0vkQCfB5W7Acg0f62p4uTwmLj7oTHzyigIdOnRItz3yij7++GMugwOdACENdAKBEc/8fr8UEancZ/fojic36+OPPw4Kau5hAx0LIQ10EoERz+rqap2OadNXFuijjz7SsWPHnM5nDc+0G3dEA2AX7kkDnUiTEc8iI3Tn0nXydk93HvkKnGkH3jfsiGaMce5lB4K74b3t5OTkM2frYvQ0oD2EVUjn5eXpkUceUUlJia666iotW7ZM1113ndtlAW0q2tv1nx3R/tXxLK7R+4Yd0QI9yxv3NK+trtEjOSOcnuZP3D5CGRkZTmgHepsHBAI/JSWl1WHecP51/ihAOAqbkH7hhRc0d+5crVy5UiNHjtTjjz+u7OxsHTx4UKmpqW6XB7SLxmfaDTuiBXqXt9TTvLY2uKf59JUFQaG9YPwA/fi5Pys2MfVfo6118erXM8Y4Z+mSWjxTb+693+/X7Bf2yhijpTnDlZSUFLQ+8EeAJOeyfXNn/saYoO1HRESc9Y8K6czkKIHe8oHvDcUfHMnJyU69DbfZ3OA1Df9Iadxz/1zt+aOmcwibkF68eLGmT5+u73//+5KklStXav369Vq1apV+8pOfBLWtqqpSVVWV8768/MwZRUVFxUXXUVFRoUr/P2SMUX1NpSqPn9lmZGQU73lvxftztT3lK3Xez1i2Vok9M1RfU6kZy/YpJrarorqcUn1NpWqrz7zufGL9mfanjkuSvImXXND7wPZv/+UzTdbHdOmqJ2eMkyTNWLZWdXV1zvq6+notvD1TP//DftWcPhm0/ajoKD1405Wa99utiu2W0uz2Zv3qdf3ytpHO99fV1+vJGeOUlJSk1vD7/Zr1q9e17K6xTr2RMV2cbTZcH/gZfr9fc5/9syTpwZuu1P0v7HbWn6v94ju+0upacXahHJMgPj7+7H9MmTBQVVVloqKizCuvvBK0fPLkyeamm25q0v7BBx80knjx4sWLF682fZWXl581v8LiTPrzzz9XXV2d0tLSgpanpaXp/fffb9J+/vz5mjt3rvM+0AO2e/furb58VFFRoT59+ujTTz9VQkJCq7bR3jpizVLHrLsj1ixRd3vqiDVLHbPu9qw5Pr75sfsDwiKkL1RsbKxiY2ODloXqslFCQkKH+UUN6Ig1Sx2z7o5Ys0Td7akj1ix1zLptqDksnpPu0aOHoqKiVFpaGrS8tLRU6enpLlUFAMDZhUVIezwejRgxQgUFBc6y+vp6FRQUKDMz08XKAABoWdhc7p47d66mTJmia665Rtddd50ef/xxnTx50unt3dZiY2P14IMPNrmMbrOOWLPUMevuiDVL1N2eOmLNUses26aaI4wJn/EAly9f7gxmcvXVV+uJJ57QyJEj3S4LAIBmhVVIAwDQkYTFPWkAADoiQhoAAEsR0gAAWIqQBgDAUoR0O8jLy9Oll16qLl26aOTIkXrrrbdcq2XhwoW69tprFR8fr9TUVN1yyy06ePBgUJtRo0Y5MwUFXjNmzAhqc/jwYU2YMEFer1epqam69957nVmT2sJDDz3UpKbLL7/cWV9ZWanc3Fx1795d3bp108SJE5sMXtPeNV966aVNao6IiFBubq4ke47zG2+8oW9/+9vq1auXIiIitHbt2qD1xhg98MAD6tmzp+Li4pSVlaUPPvggqE1ZWZkmTZqkhIQEJSUlaerUqTpx4kRQm3feeUdf+9rX1KVLF/Xp00cPP/xwm9VdU1OjefPmaejQoeratat69eqlyZMn67PPPgvaRnP/RosWLWqzus91rL/3ve81qWfcuHFBbWw71pKa/T2PiIjQI4884rRp72N9Pp91ofrc2LZtm4YPH67Y2FgNHDhQ+fn5ra67iRDNYYEWrFmzxng8HrNq1Spz4MABM336dJOUlGRKS0tdqSc7O9usXr3a7N+/3xQXF5tvfvObpm/fvubEiRNOm69//etm+vTp5siRI86r4SDwtbW15sorrzRZWVlm79695rXXXjM9evQw8+fPb7O6H3zwQXPFFVcE1fSPf/zDWT9jxgzTp08fU1BQYN5++21z/fXXmy9/+cuu1nz06NGgejdt2mQkma1btxpj7DnOr732mvnpT39qXn75ZSOpyUQ0ixYtMomJiWbt2rXmr3/9q7nppptMRkaGOX36tNNm3Lhx5qqrrjK7du0yO3bsMAMHDjS33367s768vNykpaWZSZMmmf3795vf//73Ji4uzjz11FNtUrff7zdZWVnmhRdeMO+//74pLCw01113nRkxYkTQNvr162cWLFgQ9G/Q8P+FUNd9rmM9ZcoUM27cuKB6ysrKgtrYdqyNMUH1HjlyxKxatcpERESYjz76yGnT3sf6fD7rQvG58fHHHxuv12vmzp1r3n33XbNs2TITFRVlNmzY0Kq6GyOk29h1111ncnNznfd1dXWmV69eZuHChS5W9S9Hjx41ksz27dudZV//+tfNPffc0+L3vPbaayYyMtKUlJQ4y1asWGESEhJMVVVVm9T54IMPmquuuqrZdX6/38TExJiXXnrJWfbee+8ZSaawsNC1mhu75557zIABA0x9fb0xxs7j3PgDuL6+3qSnp5tHHnnEWeb3+01sbKz5/e9/b4wx5t133zWSzJ49e5w2f/rTn0xERIT5+9//bowx5sknnzTJyclBdc+bN88MGjSoTepuzltvvWUkmb/97W/Osn79+pklS5a0+D1tWXdLIX3zzTe3+D0d5VjffPPNZvTo0UHL3DzWxjT9rAvV58Z9991nrrjiiqCfddttt5ns7OyQ1M3l7jZUXV2toqIiZWVlOcsiIyOVlZWlwsJCFyv7l8Bc2Y3nR/3d736nHj166Morr9T8+fN16tQpZ11hYaGGDh0aNKtYdna2KioqdODAgTar9YMPPlCvXr3Uv39/TZo0SYcPH5YkFRUVqaamJug4X3755erbt69znN2qOaC6ulrPPfec7rzzzqCZ1Gw8zg0dOnRIJSUlQcc2MTFRI0eODDq2SUlJuuaaa5w2WVlZioyM1O7du502N9xwgzweT9C+HDx4UD6fr132pby8XBEREU0my1m0aJG6d++uL33pS3rkkUeCLmW6Ufe2bduUmpqqQYMG6e6779axY8eC6rH9WJeWlmr9+vWaOnVqk3VuHuvGn3Wh+twoLCwM2kagTag+48NmWFA3XOgUme2tvr5es2fP1le+8hVdeeWVzvLvfOc76tevn3r16qV33nlH8+bN08GDB/Xyyy9LkkpKSprdp8C6tjBy5Ejl5+dr0KBBOnLkiH7+85/ra1/7mvbv36+SkhJ5PJ4mH75paWlOPW7U3NDatWvl9/v1ve99z1lm43FuLPBzmquj4bFNTU0NWh8dHa2UlJSgNhkZGU22EViXnJzcJvUHVFZWat68ebr99tuDZjX64Q9/qOHDhyslJUU7d+7U/PnzdeTIES1evNiVuseNG6dbb71VGRkZ+uijj/Rf//VfGj9+vAoLCxUVFdUhjvUzzzyj+Ph43XrrrUHL3TzWzX3Whepzo6U2FRUVOn36tOLi4lpdt0RIh7Xc3Fzt379fb775ZtDyu+66y/l66NCh6tmzp8aMGaOPPvpIAwYMaO8yJUnjx493vh42bJhGjhypfv366cUXX7zo/wnaw9NPP63x48erV69ezjIbj3NnVFNTo//4j/+QMUYrVqwIWtdw3vhhw4bJ4/HoBz/4gRYuXOjKuM05OTnO10OHDtWwYcM0YMAAbdu2TWPGjGn3elpj1apVmjRpkrp06RK03M1j3dJnXUfA5e42ZPMUmTNnztS6deu0detW9e7d+6xtA+Obf/jhh5Kk9PT0ZvcpsK49JCUl6Ytf/KI+/PBDpaenq7q6Wn6/v0lNgXrcrPlvf/ubNm/erGnTpp21nY3HOfBzzvY7nJ6erqNHjwatr62tVVlZmevHPxDQf/vb37Rp06Zzzg08cuRI1dbW6pNPPnFqc/PfoH///urRo0fQ74Stx1qSduzYoYMHD57zd11qv2Pd0mddqD43WmqTkJAQkhMIQroN2ThFpjFGM2fO1CuvvKItW7Y0ubzUnOLiYklSz549JUmZmZnat29f0IdF4ANwyJAhbVJ3YydOnNBHH32knj17asSIEYqJiQk6zgcPHtThw4ed4+xmzatXr1ZqaqomTJhw1nY2HueMjAylp6cHHduKigrt3r076Nj6/X4VFRU5bbZs2aL6+nrnD4/MzEy98cYbqqmpCdqXQYMGtdnl10BAf/DBB9q8ebO6d+9+zu8pLi5WZGSkc0nZjbob+r//+z8dO3Ys6HfCxmMd8PTTT2vEiBG66qqrztm2rY/1uT7rQvW5kZmZGbSNQJuQfcaHpPsZWrRmzRoTGxtr8vPzzbvvvmvuuusuk5SUFNRbsD3dfffdJjEx0Wzbti3oUYhTp04ZY4z58MMPzYIFC8zbb79tDh06ZF599VXTv39/c8MNNzjbCDyWMHbsWFNcXGw2bNhgLrnkkjZ9nOlHP/qR2bZtmzl06JD585//bLKyskyPHj3M0aNHjTFnHqXo27ev2bJli3n77bdNZmamyczMdLVmY8705u/bt6+ZN29e0HKbjvPx48fN3r17zd69e40ks3jxYrN3716nF/SiRYtMUlKSefXVV80777xjbr755mYfwfrSl75kdu/ebd58801z2WWXBT0W5Pf7TVpamrnjjjvM/v37zZo1a4zX672ox4LOVnd1dbW56aabTO/evU1xcXHQ73qgV+7OnTvNkiVLTHFxsfnoo4/Mc889Zy655BIzefLkNqv7bDUfP37c/PjHPzaFhYXm0KFDZvPmzWb48OHmsssuM5WVlc42bDvWAeXl5cbr9ZoVK1Y0+X43jvW5PuuMCc3nRuARrHvvvde89957Ji8vj0ewOpply5aZvn37Go/HY6677jqza9cu12qR1Oxr9erVxhhjDh8+bG644QaTkpJiYmNjzcCBA829994b9PyuMcZ88sknZvz48SYuLs706NHD/OhHPzI1NTVtVvdtt91mevbsaTwej/nCF75gbrvtNvPhhx8660+fPm3+8z//0yQnJxuv12v+7d/+zRw5csTVmo0xZuPGjUaSOXjwYNBym47z1q1bm/2dmDJlijHmzGNYP/vZz0xaWpqJjY01Y8aMabI/x44dM7fffrvp1q2bSUhIMN///vfN8ePHg9r89a9/NV/96ldNbGys+cIXvmAWLVrUZnUfOnSoxd/1wHPqRUVFZuTIkSYxMdF06dLFDB482PzP//xPUCCGuu6z1Xzq1CkzduxYc8kll5iYmBjTr18/M3369CZ/0Nt2rAOeeuopExcXZ/x+f5Pvd+NYn+uzzpjQfW5s3brVXH311cbj8Zj+/fsH/YyLxVSVAABYinvSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACW+v/XfUaUz6/OiAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sns.displot(tokds[\"train\"][\"length\"])" ] }, { "cell_type": "code", "execution_count": 18, "id": "d11597f9-0441-440c-8214-b9d8b2df6f79", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "46d3909d41c649acb800d4bf00197951", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=16): 0%| | 0/160800 [00:00 1024, num_proc=16)" ] }, { "cell_type": "code", "execution_count": 25, "id": "2b6d57f7-40b7-4417-88bc-83c63b22f153", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "31" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(filttokds[\"test\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "78d391fa-9a57-446b-9007-fe64ef8fc735", "metadata": {}, "outputs": [], "source": [ "tokds = ds.map(lambda x: tokenizer(x['prompt']), num_proc=16)" ] }, { "cell_type": "code", "execution_count": 31, "id": "176dbd05-67c5-45a6-b891-1237deb7d6c9", "metadata": {}, "outputs": [], "source": [ "ds = load_dataset(\"mnoukhov/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144\")" ] }, { "cell_type": "code", "execution_count": 32, "id": "9eab3eaa-55ed-4279-96d6-3c189266ba86", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "911df020ac294d9ca2b360e1a0be3f93", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Filter: 0%| | 0/116722 [00:00