1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
|
This patch fixes some scripts for generating source files. For
gen_jit_decompositions.py, gen_mobile_upgraders.py and
gen_jit_shape_functions.py, which depend on the compiled PyTorch library, the
option to generate "dummy" source files is added for the initial build, which
is later corrected. codegen_external.py is patched to avoid duplicate
functions and add the static keyword as in the existing generated file.
diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh
index cc0263dbbf..ac34e84b82 100644
--- a/tools/gen_flatbuffers.sh
+++ b/tools/gen_flatbuffers.sh
@@ -1,13 +1,13 @@
#!/bin/bash
ROOT=$(pwd)
-FF_LOCATION="$ROOT/third_party/flatbuffers"
-cd "$FF_LOCATION" || exit
-mkdir build
-cd build || exit
-cmake ..
-cmake --build . --target flatc
-mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
-./flatc --cpp --gen-mutable --scoped-enums \
+#FF_LOCATION="$ROOT/third_party/flatbuffers"
+#cd "$FF_LOCATION" || exit
+#mkdir build
+#cd build || exit
+#cmake ..
+#cmake --build . --target flatc
+#mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
+flatc --cpp --gen-mutable --scoped-enums \
-o "$ROOT/torch/csrc/jit/serialization" \
-c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs"
echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h"
diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py
index 5dcf1b2840..0e20b0c102 100644
--- a/torch/csrc/jit/tensorexpr/codegen_external.py
+++ b/torch/csrc/jit/tensorexpr/codegen_external.py
@@ -21,9 +21,14 @@ def gen_external(native_functions_path, tags_path, external_path):
native_functions = parse_native_yaml(native_functions_path, tags_path)
func_decls = []
func_registrations = []
- for func in native_functions:
+ done_names = set()
+ for func in native_functions[0]:
schema = func.func
name = schema.name.name.base
+ if name in done_names:
+ continue
+ else:
+ done_names.add(name)
args = schema.arguments
# Only supports extern calls for functions with out variants
if not schema.is_out_fn():
@@ -63,7 +68,7 @@ def gen_external(native_functions_path, tags_path, external_path):
# print(tensor_decls, name, arg_names)
func_decl = f"""\
-void nnc_aten_{name}(
+static void nnc_aten_{name}(
int64_t bufs_num,
void** buf_data,
int64_t* buf_ranks,
diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py
index b42948045c..e1cfc73a5e 100644
--- a/torchgen/decompositions/gen_jit_decompositions.py
+++ b/torchgen/decompositions/gen_jit_decompositions.py
@@ -1,8 +1,12 @@
#!/usr/bin/env python3
import os
from pathlib import Path
+import sys
-from torch.jit._decompositions import decomposition_table
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+ from torch.jit._decompositions import decomposition_table
+else:
+ decomposition_table = {}
# from torchgen.code_template import CodeTemplate
@@ -86,7 +90,7 @@ def write_decomposition_util_file(path: str) -> None:
def main() -> None:
- pytorch_dir = Path(__file__).resolve().parents[3]
+ pytorch_dir = Path(__file__).resolve().parents[2]
upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
write_decomposition_util_file(str(upgrader_path))
diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py
index 362ce427d5..245056f815 100644
--- a/torchgen/operator_versions/gen_mobile_upgraders.py
+++ b/torchgen/operator_versions/gen_mobile_upgraders.py
@@ -6,10 +6,13 @@ import os
from enum import Enum
from operator import itemgetter
from pathlib import Path
+import sys
from typing import Any
-import torch
-from torch.jit.generate_bytecode import generate_upgraders_bytecode
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+ import torch
+ from torch.jit.generate_bytecode import generate_upgraders_bytecode
+
from torchgen.code_template import CodeTemplate
from torchgen.operator_versions.gen_mobile_upgraders_constant import (
MOBILE_UPGRADERS_HEADER_DESCRIPTION,
@@ -265,7 +268,10 @@ def construct_register_size(register_size_from_yaml: int) -> str:
def construct_version_maps(
upgrader_bytecode_function_to_index_map: dict[str, Any]
) -> str:
- version_map = torch._C._get_operator_version_map()
+ if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+ version_map = torch._C._get_operator_version_map()
+ else:
+ version_map = {}
sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return]
sorted_version_map = dict(sorted_version_map_)
@@ -381,7 +387,10 @@ def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
def main() -> None:
- upgrader_list = generate_upgraders_bytecode()
+ if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+ upgrader_list = generate_upgraders_bytecode()
+ else:
+ upgrader_list = []
sorted_upgrader_list = sort_upgrader(upgrader_list)
for up in sorted_upgrader_list:
print("after sort upgrader : ", next(iter(up)))
diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py
index 56a3d8bf0d..490a3ea2e7 100644
--- a/torchgen/shape_functions/gen_jit_shape_functions.py
+++ b/torchgen/shape_functions/gen_jit_shape_functions.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import os
import sys
+import importlib
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain
from pathlib import Path
@@ -18,16 +19,21 @@ you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
raise Exception(err_msg) # noqa: TRY002
-spec = spec_from_file_location(module_name, file_path)
-assert spec is not None
-module = module_from_spec(spec)
-sys.modules[module_name] = module
-assert spec.loader is not None
-assert module is not None
-spec.loader.exec_module(module)
-
-bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
-shape_compute_graph_mapping = module.shape_compute_graph_mapping
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ assert spec is not None
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = module
+ assert spec.loader is not None
+ assert module is not None
+ spec.loader.exec_module(module)
+
+ bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
+ shape_compute_graph_mapping = module.shape_compute_graph_mapping
+
+else:
+ bounded_compute_graph_mapping = {}
+ shape_compute_graph_mapping = {}
SHAPE_HEADER = r"""
|