1
+ import copy
1
2
from typing import Any
2
3
3
4
from pytensor .graph .basic import Variable
4
5
from pytensor .link .basic import JITLinker
6
+ from pytensor .link .utils import unique_name_generator
5
7
6
8
7
9
class PytorchLinker (JITLinker ):
8
10
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
9
11
12
+ def __init__ (self , * args , ** kwargs ):
13
+ super ().__init__ (* args , ** kwargs )
14
+ self .gen_functors = []
15
+
10
16
def input_filter (self , inp : Any ) -> Any :
11
17
from pytensor .link .pytorch .dispatch import pytorch_typify
12
18
@@ -18,14 +24,68 @@ def output_filter(self, var: Variable, out: Any) -> Any:
18
24
def fgraph_convert (self , fgraph , input_storage , storage_map , ** kwargs ):
19
25
from pytensor .link .pytorch .dispatch import pytorch_funcify
20
26
27
+ # We want to have globally unique names
28
+ # across the entire pytensor graph, not
29
+ # just the subgraph
30
+ generator = unique_name_generator (["torch_linker" ])
31
+
32
+ # Ensure that torch is aware of the generated
33
+ # code so we can compile without graph breaks
34
+ def conversion_func_register (* args , ** kwargs ):
35
+ functor = pytorch_funcify (* args , ** kwargs )
36
+ name = kwargs ["unique_name" ](functor )
37
+ self .gen_functors .append ((f"_{ name } " , functor ))
38
+ return functor
39
+
40
+ built_kwargs = {
41
+ "unique_name" : generator ,
42
+ "conversion_func" : conversion_func_register ,
43
+ ** kwargs ,
44
+ }
21
45
return pytorch_funcify (
22
- fgraph , input_storage = input_storage , storage_map = storage_map , ** kwargs
46
+ fgraph , input_storage = input_storage , storage_map = storage_map , ** built_kwargs
23
47
)
24
48
25
49
def jit_compile (self , fn ):
26
50
import torch
27
51
28
- return torch .compile (fn )
52
+ class wrapper :
53
+ """
54
+ Pytorch would fail compiling our method when trying
55
+ to resolve some of the methods returned from dispatch
56
+ calls. We want to be careful to not leak the methods,
57
+ so this class just holds them and provisions the expected
58
+ location accordingly
59
+
60
+ https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
61
+ """
62
+
63
+ def __init__ (self , fn , gen_functors ):
64
+ self .fn = torch .compile (fn )
65
+ self .gen_functors = copy .copy (gen_functors )
66
+
67
+ def __call__ (self , * args , ** kwargs ):
68
+ import pytensor .link .utils
69
+
70
+ # set attrs
71
+ for n , fn in self .gen_functors :
72
+ setattr (pytensor .link .utils , n [1 :], fn )
73
+
74
+ res = self .fn (* args , ** kwargs )
75
+
76
+ # unset attrs
77
+ for n , _ in self .gen_functors :
78
+ if getattr (pytensor .link .utils , n [1 :], False ):
79
+ delattr (pytensor .link .utils , n [1 :])
80
+
81
+ return res
82
+
83
+ def __del__ (self ):
84
+ del self .gen_functors
85
+
86
+ res = wrapper (fn , self .gen_functors )
87
+ self .gen_functors = []
88
+ return res
29
89
30
90
def create_thunk_inputs (self , storage_map ):
31
91
thunk_inputs = []
0 commit comments