diff --git "a/test.ipynb" "b/test.ipynb" --- "a/test.ipynb" +++ "b/test.ipynb" @@ -8,6 +8,9 @@ "source": [ "%load_ext autoreload\n", "%autoreload 2\n", + "import os\n", + "\n", + "os.environ['TORCH_LOGS'] = 'dynamic'\n", "\n", "import pylab as pl" ] @@ -33,7 +36,7 @@ "text/html": [ "\n", " \n", " " @@ -73,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -88,7 +91,7 @@ "text/html": [ "\n", " \n", " " @@ -160,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -171,7 +174,10 @@ "torch.Size([1, 640, 143])\n", "torch.Size([1, 512, 143])\n", "torch.Size([1, 143, 348])\n", - "torch.Size([1, 640, 348])\n", + "en.shape=torch.Size([1, 640, 348])\n", + "s.shape=torch.Size([1, 128])\n", + "en.dtype=torch.float32\n", + "s.dtype=torch.float32\n", "torch.Size([1, 512, 143])\n", "torch.Size([1, 512, 348])\n" ] @@ -179,10 +185,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, @@ -202,7 +208,11 @@ "print(d.transpose(-1,-2).shape)\n", "print(d_en.shape)\n", "print(pred_aln_trg.unsqueeze(0).shape)\n", - "print(en.shape)\n", + "print(f\"{en.shape=}\")\n", + "print(f\"{s.shape=}\")\n", + "print(f\"{en.dtype=}\")\n", + "print(f\"{s.dtype=}\")\n", + "\n", "print(t_en.shape)\n", "print(asr.shape)\n", "pl.imshow(pred_aln_trg[:,:])\n" @@ -217,25 +227,55 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "W1231 22:39:38.288000 1824615 site-packages/torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n", - "E1231 22:39:38.291000 1824615 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n" + "I0101 18:56:11.998000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0].size()[1] [2, 510] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n", + "I0101 18:56:12.007000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s0 = 143 (range_refined_to_singleton) VR[143, 143]\n", + "I0101 18:56:12.008000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s0, 143) [guard added] (mp/ipykernel_2488298/2554868606.py:17 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s0, 143)\"\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0101 18:56:27.383000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3317] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:390 in local_scalar_dense)\n", + "I0101 18:56:27.387000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4957 in arange), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"u0 >= 0\"\n", + "W0101 18:56:28.575000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Traceback (most recent call last):\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py\", line 262, in wrapper\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return retlog(fn(*args, **kwargs))\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5122, in evaluate_expr\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5238, in _evaluate_expr\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] raise self._make_data_dependent_error(\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Potential framework code culprit (scroll up for full backtrace):\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return self._op_dk(dk, *args, **kwargs)\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more information, run with TORCH_LOGS=\"dynamic\"\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n", + "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" ] }, { "ename": "GuardOnDataDependentSymNode", - "evalue": "Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 470, in F0Ntrain\n torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)", + "evalue": "Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[23], line 54\u001b[0m\n\u001b[1;32m 51\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m1\u001b[39m:token_len}}\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstyle_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[39], line 61\u001b[0m\n\u001b[1;32m 58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:batch, \u001b[38;5;241m1\u001b[39m:token_len}}\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstyle_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)\u001b[39;00m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 269\u001b[0m )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1011\u001b[0m log_export_usage(\n\u001b[1;32m 1012\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m 1014\u001b[0m message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 1015\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 1016\u001b[0m )\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1019\u001b[0m _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 989\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 992\u001b[0m log_export_usage(\n\u001b[1;32m 993\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 994\u001b[0m metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m 995\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 996\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m 997\u001b[0m )\n", @@ -256,8 +296,8 @@ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1598\u001b[0m, in \u001b[0;36m_non_strict_export.._tuplify_outputs.._aot_export_non_strict..Wrapper.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1594\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mInterpreter(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_export_root)\u001b[38;5;241m.\u001b[39mrun(\n\u001b[1;32m 1595\u001b[0m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 1596\u001b[0m )\n\u001b[1;32m 1597\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1598\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_export_root\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1599\u001b[0m flat_outs, out_spec \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_flatten(tree_out)\n\u001b[1;32m 1600\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(flat_outs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", - "Cell \u001b[0;32mIn[23], line 40\u001b[0m, in \u001b[0;36mStyleTTS2.forward\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 36\u001b[0m pred_aln_trg\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mvstack(pred_aln_trg_list)\n\u001b[1;32m 38\u001b[0m en \u001b[38;5;241m=\u001b[39m d\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 40\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpredictor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m t_en \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext_encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minference(tokens)\n\u001b[1;32m 42\u001b[0m asr \u001b[38;5;241m=\u001b[39m t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n", - "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/models.py:470\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m 468\u001b[0m torch\u001b[38;5;241m.\u001b[39m_check(x1\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShape 2, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mx1\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 469\u001b[0m torch\u001b[38;5;241m.\u001b[39m_check(x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShape 2, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mx\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m--> 470\u001b[0m x, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 471\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\u001b[39;00m\n\u001b[1;32m 473\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n", + "Cell \u001b[0;32mIn[39], line 46\u001b[0m, in \u001b[0;36mStyleTTS2.forward\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 42\u001b[0m pred_aln_trg\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mvstack(pred_aln_trg_list)\n\u001b[1;32m 44\u001b[0m en \u001b[38;5;241m=\u001b[39m d\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 46\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpredictor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m t_en \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext_encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minference(tokens)\n\u001b[1;32m 48\u001b[0m asr \u001b[38;5;241m=\u001b[39m t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n", + "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/models.py:471\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m 466\u001b[0m x1 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 467\u001b[0m \u001b[38;5;66;03m# torch._check(x1.dim() == 3, lambda: print(f\"Expected 3D tensor, got {x1.dim()}D tensor\"))\u001b[39;00m\n\u001b[1;32m 468\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[1] > 0, lambda: print(f\"Shape 2, got {x1.shape[1]}\"))\u001b[39;00m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[2] > 0, lambda: print(f\"Shape 2, got {x1.shape[2]}\"))\u001b[39;00m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.shape[2]}\"))\u001b[39;00m\n\u001b[0;32m--> 471\u001b[0m x2, _temp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\u001b[39;00m\n\u001b[1;32m 474\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x2\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 1120\u001b[0m hx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpermute_hidden(hx, sorted_indices)\n\u001b[1;32m 1122\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1123\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1124\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1125\u001b[0m \u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1126\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1127\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1128\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1129\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1130\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1132\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1133\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1134\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1135\u001b[0m result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;28minput\u001b[39m,\n\u001b[1;32m 1137\u001b[0m batch_sizes,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional,\n\u001b[1;32m 1145\u001b[0m )\n", @@ -283,11 +323,17 @@ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py:262\u001b[0m, in \u001b[0;36mrecord_shapeenv_event..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mis_recording: \u001b[38;5;66;03m# type: ignore[has-type]\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \u001b[38;5;66;03m# If ShapeEnv is already recording an event, call the wrapped\u001b[39;00m\n\u001b[1;32m 258\u001b[0m \u001b[38;5;66;03m# function directly.\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# NB: here, we skip the check of whether all ShapeEnv instances\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# are equal, in favor of a faster dispatch.\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retlog(\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# Retrieve an instance of ShapeEnv.\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# Assumption: the collection of args and kwargs may not reference\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# different ShapeEnv instances.\u001b[39;00m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m _extract_shape_env_and_assert_equal(args, kwargs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5122\u001b[0m, in \u001b[0;36mShapeEnv.evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m 5117\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(\u001b[38;5;241m256\u001b[39m)\n\u001b[1;32m 5118\u001b[0m \u001b[38;5;129m@record_shapeenv_event\u001b[39m(save_tracked_fakes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5119\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_expr\u001b[39m(\u001b[38;5;28mself\u001b[39m, orig_expr: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msympy.Expr\u001b[39m\u001b[38;5;124m\"\u001b[39m, hint\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, fx_node\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 5120\u001b[0m size_oblivious: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m, forcing_spec: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 5121\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 5122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluate_expr\u001b[49m\u001b[43m(\u001b[49m\u001b[43morig_expr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfx_node\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize_oblivious\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforcing_spec\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforcing_spec\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5123\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m 5124\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 5125\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfailed during evaluate_expr(\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, hint=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, size_oblivious=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, forcing_spec=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 5126\u001b[0m orig_expr, hint, size_oblivious, forcing_spec\n\u001b[1;32m 5127\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5238\u001b[0m, in \u001b[0;36mShapeEnv._evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m 5236\u001b[0m concrete_val \u001b[38;5;241m=\u001b[39m unsound_result\n\u001b[1;32m 5237\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 5238\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_data_dependent_error(\n\u001b[1;32m 5239\u001b[0m expr\u001b[38;5;241m.\u001b[39mxreplace(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvar_to_val),\n\u001b[1;32m 5240\u001b[0m expr,\n\u001b[1;32m 5241\u001b[0m size_oblivious_result\u001b[38;5;241m=\u001b[39msize_oblivious_result\n\u001b[1;32m 5242\u001b[0m )\n\u001b[1;32m 5243\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 5244\u001b[0m expr \u001b[38;5;241m=\u001b[39m new_expr\n", - "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 470, in F0Ntrain\n torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)" + "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)" ] } ], "source": [ + "os.environ['TORCH_LOGS'] = '+dynamic'\n", + "os.environ['TORCH_LOGS'] = '+export'\n", + "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"u0 >= 0\"\n", + "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP']=\"1\"\n", + "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL']=\"u0\"\n", + "\n", "class StyleTTS2(torch.nn.Module):\n", " def __init__(self, model, voicepack):\n", " super().__init__()\n", @@ -338,30 +384,136 @@ "(asr, F0_pred, N_pred, ref_s) = style_model(tokens)\n", "\n", "token_len = torch.export.Dim(\"token_len\", min=2, max=510)\n", - "dynamic_shapes = {\"tokens\":{1:token_len}}\n", + "batch = torch.export.Dim(\"batch\")\n", + "dynamic_shapes = {\"tokens\":{0:batch, 1:token_len}}\n", "\n", "# with torch.no_grad():\n", - "export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)" + "export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)\n", + "# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 33, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "LSTM(640, 256, batch_first=True, bidirectional=True)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "I0101 18:19:15.402000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0]._base.size()[1] [2, int_oo] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n", + "I0101 18:19:15.407000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s1 = 143 for L['args'][0][0].size()[0] [2, int_oo] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s1\"\n", + "I0101 18:19:15.420000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s1 = 143 (range_refined_to_singleton) VR[143, 143]\n", + "I0101 18:19:15.422000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s1, 143) [guard added] (mp/ipykernel_2488298/2011460168.py:16 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s1, 143)\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 143])\n", + "torch.Size([1, s1])\n", + "torch.Size([1, 143])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0101 18:19:33.124000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3646] produce_guards\n" + ] + }, + { + "ename": "UserError", + "evalue": "Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mConstraintViolationError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:670\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 670\u001b[0m \u001b[43mproduce_guards_callback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1655\u001b[0m, in \u001b[0;36m_non_strict_export.._produce_guards_callback\u001b[0;34m(gm)\u001b[0m\n\u001b[1;32m 1654\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_produce_guards_callback\u001b[39m(gm):\n\u001b[0;32m-> 1655\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mproduce_guards_and_solve_constraints\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1656\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1657\u001b[0m \u001b[43m \u001b[49m\u001b[43mgm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransformed_dynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moriginal_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1662\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:305\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m constraint_violation_error:\n\u001b[0;32m--> 305\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m constraint_violation_error\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:270\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 270\u001b[0m \u001b[43mshape_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproduce_guards\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mplaceholders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43msources\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_contexts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_contexts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConstraintViolationError \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:4178\u001b[0m, in \u001b[0;36mShapeEnv.produce_guards\u001b[0;34m(self, placeholders, sources, source_ref, guards, input_contexts, equalities_inputs, _simplified, ignore_static)\u001b[0m\n\u001b[1;32m 4177\u001b[0m err \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)\n\u001b[0;32m-> 4178\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConstraintViolationError(\n\u001b[1;32m 4179\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConstraints violated (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdebug_names\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m)! \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4180\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFor more information, run with TORCH_LOGS=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m+dynamic\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 4181\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4182\u001b[0m )\n\u001b[1;32m 4183\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(warn_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", + "\u001b[0;31mConstraintViolationError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mUserError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[33], line 61\u001b[0m\n\u001b[1;32m 58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens0\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:token_len}}\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(test_model, args=( tokens[0,:], ), strict=False).run_decompositions()\u001b[39;00m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mprint\u001b[39m(export_mod)\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 269\u001b[0m )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1011\u001b[0m log_export_usage(\n\u001b[1;32m 1012\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m 1014\u001b[0m message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 1015\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 1016\u001b[0m )\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1019\u001b[0m _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 989\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 992\u001b[0m log_export_usage(\n\u001b[1;32m 993\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 994\u001b[0m metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m 995\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 996\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m 997\u001b[0m )\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/exported_program.py:114\u001b[0m, in \u001b[0;36m_disable_prexisiting_fake_mode..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m unset_fake_temporarily():\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1880\u001b[0m, in \u001b[0;36m_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 1877\u001b[0m \u001b[38;5;66;03m# Call the appropriate export function based on the strictness of tracing.\u001b[39;00m\n\u001b[1;32m 1878\u001b[0m export_func \u001b[38;5;241m=\u001b[39m _strict_export \u001b[38;5;28;01mif\u001b[39;00m strict \u001b[38;5;28;01melse\u001b[39;00m _non_strict_export\n\u001b[0;32m-> 1880\u001b[0m export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43mexport_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_in_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_complex_guards_as_runtime_asserts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1891\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1892\u001b[0m export_graph_signature: ExportGraphSignature \u001b[38;5;241m=\u001b[39m export_artifact\u001b[38;5;241m.\u001b[39maten\u001b[38;5;241m.\u001b[39msig\n\u001b[1;32m 1894\u001b[0m forward_arg_names \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1895\u001b[0m _get_forward_arg_names(mod, args, kwargs) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_torch_jit_trace \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1896\u001b[0m )\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1683\u001b[0m, in \u001b[0;36m_non_strict_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, dispatch_tracing_mode)\u001b[0m\n\u001b[1;32m 1667\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) \u001b[38;5;28;01mas\u001b[39;00m (\n\u001b[1;32m 1668\u001b[0m patched_mod,\n\u001b[1;32m 1669\u001b[0m new_fake_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1672\u001b[0m map_fake_to_real,\n\u001b[1;32m 1673\u001b[0m ):\n\u001b[1;32m 1674\u001b[0m _to_aten_func \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1675\u001b[0m _export_to_aten_ir_make_fx\n\u001b[1;32m 1676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dispatch_tracing_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmake_fx\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1681\u001b[0m )\n\u001b[1;32m 1682\u001b[0m )\n\u001b[0;32m-> 1683\u001b[0m aten_export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43m_to_aten_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1684\u001b[0m \u001b[43m \u001b[49m\u001b[43mpatched_mod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1685\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1686\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1687\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_params_buffers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1688\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_constant_attrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1689\u001b[0m \u001b[43m \u001b[49m\u001b[43mproduce_guards_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_produce_guards_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1690\u001b[0m \u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_tuplify_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1691\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;66;03m# aten_export_artifact.constants contains only fake script objects, we need to map them back\u001b[39;00m\n\u001b[1;32m 1693\u001b[0m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 1694\u001b[0m fqn: map_fake_to_real[obj] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, FakeScriptObject) \u001b[38;5;28;01melse\u001b[39;00m obj\n\u001b[1;32m 1695\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fqn, obj \u001b[38;5;129;01min\u001b[39;00m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 1696\u001b[0m }\n", + "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:672\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 670\u001b[0m produce_guards_callback(gm)\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 672\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m UserError(UserErrorType\u001b[38;5;241m.\u001b[39mCONSTRAINT_VIOLATION, \u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;66;03m# noqa: B904\u001b[39;00m\n\u001b[1;32m 674\u001b[0m \u001b[38;5;66;03m# Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.\u001b[39;00m\n\u001b[1;32m 675\u001b[0m \u001b[38;5;66;03m# Overwrite output specs afterwards.\u001b[39;00m\n\u001b[1;32m 676\u001b[0m flat_fake_args \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_leaves((fake_args, fake_kwargs))\n", + "\u001b[0;31mUserError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143" + ] } ], "source": [ - "model[\"predictor\"].shared" + "os.environ['TORCH_LOGS'] = '+dynamic'\n", + "os.environ['TORCH_LOGS'] = '+export'\n", + "class test(torch.nn.Module):\n", + " def __init__(self, model, voicepack):\n", + " super().__init__()\n", + " self.model = model\n", + " self.voicepack = voicepack\n", + " self.model.text_encoder.lstm.flatten_parameters()\n", + " \n", + " def forward(self, tokens0):\n", + " tokens = tokens0.unsqueeze(0)\n", + " print(tokens.shape)\n", + " # speed = 1.\n", + " # # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n", + " # device = tokens.device\n", + " input_lengths = torch.LongTensor([tokens0.shape[-1]]).to(device)\n", + "\n", + " # text_mask = length_to_mask(input_lengths).to(device)\n", + " # bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())\n", + "\n", + " # d_en = self.model[\"bert_encoder\"](bert_dur).transpose(-1, -2)\n", + "\n", + " # ref_s = self.voicepack[tokens.shape[1]]\n", + " # s = ref_s[:, 128:]\n", + "\n", + " # d = self.model[\"predictor\"].text_encoder.inference(d_en, s)\n", + " # x, _ = self.model[\"predictor\"].lstm(d)\n", + "\n", + " # duration = self.model[\"predictor\"].duration_proj(x)\n", + " # duration = torch.sigmoid(duration).sum(axis=-1) / speed\n", + " # pred_dur = torch.round(duration).clamp(min=1).long()\n", + " \n", + " # c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n", + " # c_end = c_start + pred_dur[0,:]\n", + " # indices = torch.arange(0, pred_dur.sum().item()).long().to(device)\n", + "\n", + " # pred_aln_trg_list=[]\n", + " # for cs, ce in zip(c_start, c_end):\n", + " # row = torch.where((indices>=cs) & (indices