diff --git "a/test.ipynb" "b/test.ipynb"
--- "a/test.ipynb"
+++ "b/test.ipynb"
@@ -8,6 +8,9 @@
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
+ "import os\n",
+ "\n",
+ "os.environ['TORCH_LOGS'] = 'dynamic'\n",
"\n",
"import pylab as pl"
]
@@ -33,7 +36,7 @@
"text/html": [
"\n",
" \n",
" "
@@ -73,7 +76,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -88,7 +91,7 @@
"text/html": [
"\n",
" \n",
" "
@@ -160,7 +163,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -171,7 +174,10 @@
"torch.Size([1, 640, 143])\n",
"torch.Size([1, 512, 143])\n",
"torch.Size([1, 143, 348])\n",
- "torch.Size([1, 640, 348])\n",
+ "en.shape=torch.Size([1, 640, 348])\n",
+ "s.shape=torch.Size([1, 128])\n",
+ "en.dtype=torch.float32\n",
+ "s.dtype=torch.float32\n",
"torch.Size([1, 512, 143])\n",
"torch.Size([1, 512, 348])\n"
]
@@ -179,10 +185,10 @@
{
"data": {
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 4,
+ "execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
@@ -202,7 +208,11 @@
"print(d.transpose(-1,-2).shape)\n",
"print(d_en.shape)\n",
"print(pred_aln_trg.unsqueeze(0).shape)\n",
- "print(en.shape)\n",
+ "print(f\"{en.shape=}\")\n",
+ "print(f\"{s.shape=}\")\n",
+ "print(f\"{en.dtype=}\")\n",
+ "print(f\"{s.dtype=}\")\n",
+ "\n",
"print(t_en.shape)\n",
"print(asr.shape)\n",
"pl.imshow(pred_aln_trg[:,:])\n"
@@ -217,25 +227,55 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "W1231 22:39:38.288000 1824615 site-packages/torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n",
- "E1231 22:39:38.291000 1824615 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n"
+ "I0101 18:56:11.998000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0].size()[1] [2, 510] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n",
+ "I0101 18:56:12.007000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s0 = 143 (range_refined_to_singleton) VR[143, 143]\n",
+ "I0101 18:56:12.008000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s0, 143) [guard added] (mp/ipykernel_2488298/2554868606.py:17 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s0, 143)\"\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "I0101 18:56:27.383000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3317] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:390 in local_scalar_dense)\n",
+ "I0101 18:56:27.387000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4957 in arange), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"u0 >= 0\"\n",
+ "W0101 18:56:28.575000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Traceback (most recent call last):\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py\", line 262, in wrapper\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return retlog(fn(*args, **kwargs))\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5122, in evaluate_expr\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5238, in _evaluate_expr\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] raise self._make_data_dependent_error(\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Potential framework code culprit (scroll up for full backtrace):\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return self._op_dk(dk, *args, **kwargs)\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more information, run with TORCH_LOGS=\"dynamic\"\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n",
+ "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
]
},
{
"ename": "GuardOnDataDependentSymNode",
- "evalue": "Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 470, in F0Ntrain\n torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)",
+ "evalue": "Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[23], line 54\u001b[0m\n\u001b[1;32m 51\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m1\u001b[39m:token_len}}\n\u001b[1;32m 53\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 54\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstyle_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
+ "Cell \u001b[0;32mIn[39], line 61\u001b[0m\n\u001b[1;32m 58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:batch, \u001b[38;5;241m1\u001b[39m:token_len}}\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstyle_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)\u001b[39;00m\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 269\u001b[0m )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1011\u001b[0m log_export_usage(\n\u001b[1;32m 1012\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m 1014\u001b[0m message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 1015\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 1016\u001b[0m )\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1019\u001b[0m _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 989\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 992\u001b[0m log_export_usage(\n\u001b[1;32m 993\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 994\u001b[0m metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m 995\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 996\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m 997\u001b[0m )\n",
@@ -256,8 +296,8 @@
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1598\u001b[0m, in \u001b[0;36m_non_strict_export.._tuplify_outputs.._aot_export_non_strict..Wrapper.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1594\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mInterpreter(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_export_root)\u001b[38;5;241m.\u001b[39mrun(\n\u001b[1;32m 1595\u001b[0m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 1596\u001b[0m )\n\u001b[1;32m 1597\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1598\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_export_root\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1599\u001b[0m flat_outs, out_spec \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_flatten(tree_out)\n\u001b[1;32m 1600\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(flat_outs)\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
- "Cell \u001b[0;32mIn[23], line 40\u001b[0m, in \u001b[0;36mStyleTTS2.forward\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 36\u001b[0m pred_aln_trg\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mvstack(pred_aln_trg_list)\n\u001b[1;32m 38\u001b[0m en \u001b[38;5;241m=\u001b[39m d\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 40\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpredictor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 41\u001b[0m t_en \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext_encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minference(tokens)\n\u001b[1;32m 42\u001b[0m asr \u001b[38;5;241m=\u001b[39m t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n",
- "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/models.py:470\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m 468\u001b[0m torch\u001b[38;5;241m.\u001b[39m_check(x1\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShape 2, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mx1\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 469\u001b[0m torch\u001b[38;5;241m.\u001b[39m_check(x\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m, \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShape 2, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mx\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m--> 470\u001b[0m x, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 471\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\u001b[39;00m\n\u001b[1;32m 473\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
+ "Cell \u001b[0;32mIn[39], line 46\u001b[0m, in \u001b[0;36mStyleTTS2.forward\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 42\u001b[0m pred_aln_trg\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mvstack(pred_aln_trg_list)\n\u001b[1;32m 44\u001b[0m en \u001b[38;5;241m=\u001b[39m d\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 46\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpredictor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m t_en \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext_encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minference(tokens)\n\u001b[1;32m 48\u001b[0m asr \u001b[38;5;241m=\u001b[39m t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n",
+ "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/models.py:471\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m 466\u001b[0m x1 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 467\u001b[0m \u001b[38;5;66;03m# torch._check(x1.dim() == 3, lambda: print(f\"Expected 3D tensor, got {x1.dim()}D tensor\"))\u001b[39;00m\n\u001b[1;32m 468\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[1] > 0, lambda: print(f\"Shape 2, got {x1.shape[1]}\"))\u001b[39;00m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[2] > 0, lambda: print(f\"Shape 2, got {x1.shape[2]}\"))\u001b[39;00m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.shape[2]}\"))\u001b[39;00m\n\u001b[0;32m--> 471\u001b[0m x2, _temp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\u001b[39;00m\n\u001b[1;32m 474\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x2\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 1120\u001b[0m hx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpermute_hidden(hx, sorted_indices)\n\u001b[1;32m 1122\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1123\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1124\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1125\u001b[0m \u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1126\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1127\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1128\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1129\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1130\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1132\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1133\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1134\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1135\u001b[0m result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;28minput\u001b[39m,\n\u001b[1;32m 1137\u001b[0m batch_sizes,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional,\n\u001b[1;32m 1145\u001b[0m )\n",
@@ -283,11 +323,17 @@
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py:262\u001b[0m, in \u001b[0;36mrecord_shapeenv_event..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mis_recording: \u001b[38;5;66;03m# type: ignore[has-type]\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \u001b[38;5;66;03m# If ShapeEnv is already recording an event, call the wrapped\u001b[39;00m\n\u001b[1;32m 258\u001b[0m \u001b[38;5;66;03m# function directly.\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# NB: here, we skip the check of whether all ShapeEnv instances\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# are equal, in favor of a faster dispatch.\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retlog(\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# Retrieve an instance of ShapeEnv.\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# Assumption: the collection of args and kwargs may not reference\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# different ShapeEnv instances.\u001b[39;00m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m _extract_shape_env_and_assert_equal(args, kwargs)\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5122\u001b[0m, in \u001b[0;36mShapeEnv.evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m 5117\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(\u001b[38;5;241m256\u001b[39m)\n\u001b[1;32m 5118\u001b[0m \u001b[38;5;129m@record_shapeenv_event\u001b[39m(save_tracked_fakes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5119\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_expr\u001b[39m(\u001b[38;5;28mself\u001b[39m, orig_expr: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msympy.Expr\u001b[39m\u001b[38;5;124m\"\u001b[39m, hint\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, fx_node\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 5120\u001b[0m size_oblivious: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m, forcing_spec: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 5121\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 5122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluate_expr\u001b[49m\u001b[43m(\u001b[49m\u001b[43morig_expr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfx_node\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize_oblivious\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforcing_spec\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforcing_spec\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5123\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m 5124\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 5125\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfailed during evaluate_expr(\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, hint=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, size_oblivious=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, forcing_spec=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 5126\u001b[0m orig_expr, hint, size_oblivious, forcing_spec\n\u001b[1;32m 5127\u001b[0m )\n",
"File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5238\u001b[0m, in \u001b[0;36mShapeEnv._evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m 5236\u001b[0m concrete_val \u001b[38;5;241m=\u001b[39m unsound_result\n\u001b[1;32m 5237\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 5238\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_data_dependent_error(\n\u001b[1;32m 5239\u001b[0m expr\u001b[38;5;241m.\u001b[39mxreplace(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvar_to_val),\n\u001b[1;32m 5240\u001b[0m expr,\n\u001b[1;32m 5241\u001b[0m size_oblivious_result\u001b[38;5;241m=\u001b[39msize_oblivious_result\n\u001b[1;32m 5242\u001b[0m )\n\u001b[1;32m 5243\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 5244\u001b[0m expr \u001b[38;5;241m=\u001b[39m new_expr\n",
- "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 470, in F0Ntrain\n torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)"
+ "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)"
]
}
],
"source": [
+ "os.environ['TORCH_LOGS'] = '+dynamic'\n",
+ "os.environ['TORCH_LOGS'] = '+export'\n",
+ "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"u0 >= 0\"\n",
+ "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP']=\"1\"\n",
+ "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL']=\"u0\"\n",
+ "\n",
"class StyleTTS2(torch.nn.Module):\n",
" def __init__(self, model, voicepack):\n",
" super().__init__()\n",
@@ -338,30 +384,136 @@
"(asr, F0_pred, N_pred, ref_s) = style_model(tokens)\n",
"\n",
"token_len = torch.export.Dim(\"token_len\", min=2, max=510)\n",
- "dynamic_shapes = {\"tokens\":{1:token_len}}\n",
+ "batch = torch.export.Dim(\"batch\")\n",
+ "dynamic_shapes = {\"tokens\":{0:batch, 1:token_len}}\n",
"\n",
"# with torch.no_grad():\n",
- "export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)"
+ "export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)\n",
+ "# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 33,
"metadata": {},
"outputs": [
{
- "data": {
- "text/plain": [
- "LSTM(640, 256, batch_first=True, bidirectional=True)"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "I0101 18:19:15.402000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0]._base.size()[1] [2, int_oo] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n",
+ "I0101 18:19:15.407000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s1 = 143 for L['args'][0][0].size()[0] [2, int_oo] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s1\"\n",
+ "I0101 18:19:15.420000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s1 = 143 (range_refined_to_singleton) VR[143, 143]\n",
+ "I0101 18:19:15.422000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s1, 143) [guard added] (mp/ipykernel_2488298/2011460168.py:16 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s1, 143)\"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 143])\n",
+ "torch.Size([1, s1])\n",
+ "torch.Size([1, 143])\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "I0101 18:19:33.124000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3646] produce_guards\n"
+ ]
+ },
+ {
+ "ename": "UserError",
+ "evalue": "Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mConstraintViolationError\u001b[0m Traceback (most recent call last)",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:670\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 670\u001b[0m \u001b[43mproduce_guards_callback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1655\u001b[0m, in \u001b[0;36m_non_strict_export.._produce_guards_callback\u001b[0;34m(gm)\u001b[0m\n\u001b[1;32m 1654\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_produce_guards_callback\u001b[39m(gm):\n\u001b[0;32m-> 1655\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mproduce_guards_and_solve_constraints\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1656\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1657\u001b[0m \u001b[43m \u001b[49m\u001b[43mgm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransformed_dynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moriginal_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1662\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:305\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m constraint_violation_error:\n\u001b[0;32m--> 305\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m constraint_violation_error\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:270\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 270\u001b[0m \u001b[43mshape_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproduce_guards\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mplaceholders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43msources\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_contexts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_contexts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConstraintViolationError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:4178\u001b[0m, in \u001b[0;36mShapeEnv.produce_guards\u001b[0;34m(self, placeholders, sources, source_ref, guards, input_contexts, equalities_inputs, _simplified, ignore_static)\u001b[0m\n\u001b[1;32m 4177\u001b[0m err \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)\n\u001b[0;32m-> 4178\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConstraintViolationError(\n\u001b[1;32m 4179\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConstraints violated (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdebug_names\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m)! \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4180\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFor more information, run with TORCH_LOGS=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m+dynamic\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 4181\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4182\u001b[0m )\n\u001b[1;32m 4183\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(warn_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
+ "\u001b[0;31mConstraintViolationError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143",
+ "\nDuring handling of the above exception, another exception occurred:\n",
+ "\u001b[0;31mUserError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[33], line 61\u001b[0m\n\u001b[1;32m 58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens0\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:token_len}}\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(test_model, args=( tokens[0,:], ), strict=False).run_decompositions()\u001b[39;00m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mprint\u001b[39m(export_mod)\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 269\u001b[0m )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1011\u001b[0m log_export_usage(\n\u001b[1;32m 1012\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m 1014\u001b[0m message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 1015\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 1016\u001b[0m )\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1019\u001b[0m _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 989\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 992\u001b[0m log_export_usage(\n\u001b[1;32m 993\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 994\u001b[0m metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m 995\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 996\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m 997\u001b[0m )\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/exported_program.py:114\u001b[0m, in \u001b[0;36m_disable_prexisiting_fake_mode..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m unset_fake_temporarily():\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1880\u001b[0m, in \u001b[0;36m_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 1877\u001b[0m \u001b[38;5;66;03m# Call the appropriate export function based on the strictness of tracing.\u001b[39;00m\n\u001b[1;32m 1878\u001b[0m export_func \u001b[38;5;241m=\u001b[39m _strict_export \u001b[38;5;28;01mif\u001b[39;00m strict \u001b[38;5;28;01melse\u001b[39;00m _non_strict_export\n\u001b[0;32m-> 1880\u001b[0m export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43mexport_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_in_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_complex_guards_as_runtime_asserts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1891\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1892\u001b[0m export_graph_signature: ExportGraphSignature \u001b[38;5;241m=\u001b[39m export_artifact\u001b[38;5;241m.\u001b[39maten\u001b[38;5;241m.\u001b[39msig\n\u001b[1;32m 1894\u001b[0m forward_arg_names \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1895\u001b[0m _get_forward_arg_names(mod, args, kwargs) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_torch_jit_trace \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1896\u001b[0m )\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1683\u001b[0m, in \u001b[0;36m_non_strict_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, dispatch_tracing_mode)\u001b[0m\n\u001b[1;32m 1667\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) \u001b[38;5;28;01mas\u001b[39;00m (\n\u001b[1;32m 1668\u001b[0m patched_mod,\n\u001b[1;32m 1669\u001b[0m new_fake_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1672\u001b[0m map_fake_to_real,\n\u001b[1;32m 1673\u001b[0m ):\n\u001b[1;32m 1674\u001b[0m _to_aten_func \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1675\u001b[0m _export_to_aten_ir_make_fx\n\u001b[1;32m 1676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dispatch_tracing_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmake_fx\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1681\u001b[0m )\n\u001b[1;32m 1682\u001b[0m )\n\u001b[0;32m-> 1683\u001b[0m aten_export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43m_to_aten_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1684\u001b[0m \u001b[43m \u001b[49m\u001b[43mpatched_mod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1685\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1686\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1687\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_params_buffers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1688\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_constant_attrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1689\u001b[0m \u001b[43m \u001b[49m\u001b[43mproduce_guards_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_produce_guards_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1690\u001b[0m \u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_tuplify_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1691\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;66;03m# aten_export_artifact.constants contains only fake script objects, we need to map them back\u001b[39;00m\n\u001b[1;32m 1693\u001b[0m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 1694\u001b[0m fqn: map_fake_to_real[obj] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, FakeScriptObject) \u001b[38;5;28;01melse\u001b[39;00m obj\n\u001b[1;32m 1695\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fqn, obj \u001b[38;5;129;01min\u001b[39;00m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 1696\u001b[0m }\n",
+ "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:672\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 670\u001b[0m produce_guards_callback(gm)\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 672\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m UserError(UserErrorType\u001b[38;5;241m.\u001b[39mCONSTRAINT_VIOLATION, \u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;66;03m# noqa: B904\u001b[39;00m\n\u001b[1;32m 674\u001b[0m \u001b[38;5;66;03m# Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.\u001b[39;00m\n\u001b[1;32m 675\u001b[0m \u001b[38;5;66;03m# Overwrite output specs afterwards.\u001b[39;00m\n\u001b[1;32m 676\u001b[0m flat_fake_args \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_leaves((fake_args, fake_kwargs))\n",
+ "\u001b[0;31mUserError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143"
+ ]
}
],
"source": [
- "model[\"predictor\"].shared"
+ "os.environ['TORCH_LOGS'] = '+dynamic'\n",
+ "os.environ['TORCH_LOGS'] = '+export'\n",
+ "class test(torch.nn.Module):\n",
+ " def __init__(self, model, voicepack):\n",
+ " super().__init__()\n",
+ " self.model = model\n",
+ " self.voicepack = voicepack\n",
+ " self.model.text_encoder.lstm.flatten_parameters()\n",
+ " \n",
+ " def forward(self, tokens0):\n",
+ " tokens = tokens0.unsqueeze(0)\n",
+ " print(tokens.shape)\n",
+ " # speed = 1.\n",
+ " # # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n",
+ " # device = tokens.device\n",
+ " input_lengths = torch.LongTensor([tokens0.shape[-1]]).to(device)\n",
+ "\n",
+ " # text_mask = length_to_mask(input_lengths).to(device)\n",
+ " # bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())\n",
+ "\n",
+ " # d_en = self.model[\"bert_encoder\"](bert_dur).transpose(-1, -2)\n",
+ "\n",
+ " # ref_s = self.voicepack[tokens.shape[1]]\n",
+ " # s = ref_s[:, 128:]\n",
+ "\n",
+ " # d = self.model[\"predictor\"].text_encoder.inference(d_en, s)\n",
+ " # x, _ = self.model[\"predictor\"].lstm(d)\n",
+ "\n",
+ " # duration = self.model[\"predictor\"].duration_proj(x)\n",
+ " # duration = torch.sigmoid(duration).sum(axis=-1) / speed\n",
+ " # pred_dur = torch.round(duration).clamp(min=1).long()\n",
+ " \n",
+ " # c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n",
+ " # c_end = c_start + pred_dur[0,:]\n",
+ " # indices = torch.arange(0, pred_dur.sum().item()).long().to(device)\n",
+ "\n",
+ " # pred_aln_trg_list=[]\n",
+ " # for cs, ce in zip(c_start, c_end):\n",
+ " # row = torch.where((indices>=cs) & (indices