script()
Scripting 一個函數或 nn.Module
。使用 TorchScript 編譯器將其編譯為 TorchScript 程式碼,並返回 ScriptModule 或 ScriptFunction。
Scripting dictionary 或 list 會將其中的資料複製到 TorchScript 實例中,然後在 Python 和 TorchScript 之間可以透過 reference 傳遞,複製 overhead 為零。
script(obj,
optimize=None,
_frames_up=0,
_rcb=None,
example_inputs=None)
參數 | 說明 |
---|---|
obj | (Callable, class, or nn.Module ) – 要編譯的 nn.Module 、函數、class type, dictionary or list。(將程式碼轉換為 TorchScript。) 如果存在 dynamic control flow (例如: 循環或條件),則 scripting process 將捕獲該結構。 |
optimize | 已棄用且沒有任何效果。 |
_frames_up | |
_rcb | |
example_inputs | (List[Tuple] , Dict[Callable, List[Tuple]] , None ) – 提供範例輸入來註解函數或提供 nn.Module 的參數。 這些範例輸入不用於計算,但它們幫助 scripting process 理解函數輸入的形狀和類型。 對於在 scripting process 中,能更有效地產生程式碼和更好地檢查錯誤。 |
返回值:
- 如果
obj
是nn.Module
,則返回 ScriptModule 物件。返回的 ScriptModule 將具有與原始nn.Module
相同的 sub-modules 和參數集。 - 如果
obj
是獨立函數,則會返回 ScriptFunction。 - 如果
obj
是 dict,則返回torch._C.ScriptDict
的實例。 - 如果
obj
是一個 list,則返回torch._C.ScriptList
的實例。
將 PyTorch 程式碼轉換為 TorchScript 有兩種主要機制:
- Tracing:在對樣本輸入執行操作時,捕捉 computational graph。它很簡單,無法捕獲 control-flow 語句。
- Scripting:分析 Python 函數的原始程式碼,並建立 graph。它捕獲控制流但需要 type annotations。
在以下情況下,您應該考慮使用 torch.jit.script
:
torch.jit.trace
),不支援的 dynamic control flows (例如: if
或語句) 。
請記住,TorchScript 的主要目標通常是讓 PyTorch 模型的部署和最佳化變得更容易。始終測試腳本化模型,以確保其在轉換後按預期運行。
Scripting 後,錯誤將引用 TorchScript 程式碼,而不是原始 Python 程式碼。這可能會讓 debugging 變得有點困難。在轉換為 TorchScript 之前,請考慮以 PyTorch 形式徹底 debugging 模型。範例 – Scripting 函數
Decorator 裝飾器將透過編譯函數體來建構 ScriptFunction。
import torch
@torch.jit.script
def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b
# Now `add_tensors` is a TorchScript function, and we can call it like a normal function.
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
result = add_tensors(a, b)
print(result)
print(type(add_tensors))
print(add_tensors.code) # See the compiled graph as Python code
執行結果:
tensor([5., 7., 9.])
<class 'torch.jit.ScriptFunction'>
Output:
def add_tensors(a: Tensor,
b: Tensor) -> Tensor:
return torch.add(a, b)
import torch
@torch.jit.script
def add_tensors(x):
for i in range(x.size(0)):
x += i
return x
x = torch.tensor([1, 2, 3])
result = add_tensors(x)
print(result) # Output: tensor([4, 5, 6])
範例 – Scripting nn.Module
預設情況下,Scripting nn.Module
將編譯 forward
方法,並遞迴編譯 forward
呼叫的任何方法、submodules 和函數。如果 nn.Module
僅使用 TorchScript 支援的功能,則無需變更原始模組程式碼。 否則 script 將建構 ScriptModule,其中包含原始模組的屬性、參數和方法的副本。
import torch
class MyModule(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
# Convert to TorchScript
scripted_module = torch.jit.script(MyModule())
# Use the model just like before
output = scripted_module(torch.tensor([1, -2, 3]))
print(output) # Output: tensor([ 1, -2, 3])
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
model = MyModel()
scripted_model = torch.jit.script(model)
# Save the scripted model
scripted_model.save("model.pt")
# Load and use the model
loaded_model = torch.jit.load("model.pt")
x = torch.rand(1, 10)
output = loaded_model(x)
print(output)
執行結果:
tensor([[-0.0331, 0.4932, 0.2000, 0.0149, 0.7165]],
grad_fn=<AddmmBackward0>)
範例 – Mixing tracing and scripting
TorchScript 是一種從 PyTorch 程式碼建立 serializable 和 optimizable 模型的方法。將 PyTorch 程式碼轉換為 TorchScript 有兩種主要方法:
- Tracing:使用 "example input" 來追蹤對其執行的 operations。它適用於沒有 control flows 的模型。
- Scripting:按原樣轉換程式碼,包括 control flows。適用於具有條件和循環結構的模型。
但有時,您可能有一個很乾脆 (traceable) 的模組,但由於 control flows 的原因,有一些小部分需要 scripting。在這些情況下,您可以結合 tracing 和 scripting。
- Trace the Main Module
- 使用 Control Flow Script Sub-modules
這種混合方法利用了模型簡單部分的追蹤便利性,以及對那些更動態或具有 control flow 部分的 scripting 的強大功能。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.jit as jit
class SomeModule(nn.Module):
def __init__(self):
super(SomeModule, self).__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
# This conditional is hard to capture with tracing
if x.sum() > 0:
return self.linear(x)
else:
return -self.linear(x)
class MainModule(nn.Module):
def __init__(self):
super(MainModule, self).__init__()
self.some_module = SomeModule()
self.other_linear = nn.Linear(10, 5)
def forward(self, x):
x = self.some_module(x)
return self.other_linear(x)
# Convert the sub-module using scripting
scripted_some_module = jit.script(SomeModule())
# Replace the original module with the scripted version in MainModule
main_module = MainModule()
main_module.some_module = scripted_some_module
# Now, trace the main module
example_input = torch.rand(1, 10)
traced_main_module = jit.trace(main_module, example_input)
print(traced_main_module)
# Get the output for some input
output = traced_main_module(example_input)
print(output)
執行結果:
MainModule(
original_name=MainModule
(some_module): RecursiveScriptModule(
original_name=SomeModule
(linear): RecursiveScriptModule(original_name=Linear))
(other_linear): Linear(original_name=Linear))
tensor([[-0.1702, 0.7664, 0.0461, -0.0591, -0.3725]],
grad_fn=<AddmmBackward0>)
trace()
記錄給定一組輸入時執行的 operations,但不會捕捉 control flow (例如: if 語句或循環)。
對於 control flow,應使用 torch.jit.script
。
torch.jit.trace(func,
example_inputs=None,
optimize=None,
check_trace=True,
check_inputs=None,
check_tolerance=1e-05,
strict=True,
_force_outplace=False,
_module_class=None,
_compilation_unit=<torch.jit.CompilationUnit object>,
example_kwarg_inputs=None,
_store_inputs=True)
參數 | 說明 |
---|---|
func | (Callable or torch.nn.Module ) – 您想要 trace 的 Python 函數或 torch.nn.Module ,將與 example_inputs 一起執行的 。func 參數和回傳值必須是 tensors 或包含 tensors 的 tuples。當 module passed torch.jit.trace 時,僅運行和追蹤 forward method。 |
example_inputs | (Tuple or torch.Tensor ) – 輸入 tensors 的元組或單一輸入 tensor,將在 tracing 時,傳遞給函數。用於運行 function/module 並記錄 operations,以進行追蹤。如果您的 function/module 使用多個參數,您應該將它們作為 tuple 傳遞。 當值為 None 時,應指定 example_kwarg_inputs 。 |
optimize | 已棄用且沒有任何效果。 |
check_trace | (bool ) – 檢查 traced code 是否產生正確的輸出。該 function/module 使用提供的 example_inputs 運行,並將結果與 traced code 的輸出進行比較,以確保準確性。如果您的 network 包含 non- deterministic ops (不確定性操作),或者檢查器失敗但您確定 network 是正確的,您可能需要停用此功能。 |
check_inputs | (List of tuples) – 附加輸入,幫助驗證 traced module 在不同輸入下的行為是否正確。 如果您想在多組輸入上測試 traced code 的準確性,這將非常有用。 每個 tuple 相當於在 example_inputs 中指定的一組輸入參數。為了獲得最佳結果,請傳入一組 check_inputs ,代表您希望 network 看到的形狀空間和輸入類型。如果未指定,則使用原始 example_inputs 進行檢查。 |
check_tolerance | (float ) – 確定驗證 traced code 輸出時的容錯能力。如果原始函數的輸出與 traced code 的輸出之間的差異在此容差範圍內,則認為 trac 是準確的。如果由於已知事件 (例如: operator fusion) 導致結果在數值上出現偏差,這可以用來放鬆檢查器的嚴格性。 |
strict | (bool ) – 是否以嚴格模式運行追蹤器;嚴格遵守 TorchScript 架構。如果設定為 False ,function/module 可以使用 Python-only 的功能,但產生的 traced code 可能無法匯出,以部署到 Python 之外。只有當您希望追蹤器記錄可變容器類型 (list/dict),並且您確定您在問題中使用的容器是常數結構,且不會用作 control flow ( if , for ) 的狀況。 |
_force_outplace | 一個 debug 標誌,當設定為 True 時,可確保所有 ops 都不合適。 |
_module_class | 建立 traced script module 時,使用的自訂 torch.nn.Module class 類別。主要供內部使用,不是典型使用者需要修改的內容。 |
_compilation_unit | 將儲存 traced TorchScript 函數的編譯單元。是一個內部參數,典型使用者很少使用。 |
example_kwarg_inputs | (Dict) – example_inputs 的一組關鍵字參數,在追蹤時將傳遞給函數。應指定此參數或 example_inputs 。Dict 將透過追蹤函數的參數名稱進行 unpacking。如果 keys of the dict 與追蹤函數的參數名稱不匹配,則會引發運行時異常。 |
返回值:
- 如果
func
是nn.Module
或forward()
方法,則將返回 ScriptModule 物件。返回的 ScriptModule 將具有與原始 nn.Module 相同的 sub-modules 集和參數集。 - 如果
func
是獨立函數,則傳回 ScriptFunction。
Warning
- 追蹤僅正確記錄不依賴資料的函數和模組 (例如: 對 tensors 中的資料沒有條件),且不具有任何未追蹤的外部依賴項 (例如: 執行 input/output 或存取全域變數)。
- 追蹤僅記錄給定函數在給定 tensors 上運行時完成的 operations。因此,傳回的 ScriptModule 將始終在任何輸入上執行相同的 traced graph。當您的 module 預計根據輸入或 module 狀態運行不同的operations sets 時,這會產生一些重要的影響。例如:
- 追蹤不會記錄任何 control flow (
if
,for
),但有時 control flow 實際上是模型本身的一部分。例如: recurrent network 是輸入 sequence 長度上的循環 (可能是動態的)。 - 在傳回的 ScriptModule 中,無論 ScriptModule 處於哪種模式,
training
和eval
模式下的行為將始終表現得像處於追蹤期間所處的模式一樣。 在這種情況下,追蹤是不合適的,並且torch.jit.script
是更好的選擇。如果追蹤此類模型,您可能會在後續呼叫模型時,默默地得到不正確的結果。當做一些可能導致產生不正確追蹤的事情時,追蹤器將嘗試發出警告。
範例 – Scripting function
import torch
def foo(x, y):
return 2 * x + y
# "traced_foo" can now be run with the TorchScript interpreter
# or saved and loaded in a Python-free environment
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
# Define some input tensors
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.5, 1.5, 2.5])
# Call the traced function with the input tensors
output = traced_foo(x, y)
print(output)
執行結果:
tensor([2.5000, 5.5000, 8.5000])
範例 – Scripting nn.Module
import torch
import torch.nn as nn
class ModelWithActivation(nn.Module):
def __init__(self):
super(ModelWithActivation, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
x = self.fc(x)
return torch.relu(x)
model = ModelWithActivation()
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
print(traced_model)
# Get the output for some input
output = traced_model(example_input)
print(output)
執行結果:
ModelWithActivation(
original_name=ModelWithActivation
(fc): Linear(original_name=Linear))
tensor([[0.0000, 0.0354, 0.0000, 0.3064, 0.0000]], grad_fn=<ReluBackward0>)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet()
# Tracing the model
example_input = torch.rand(1, 10) # Example input tensor
traced_model = torch.jit.trace(model, example_kwarg_inputs={'x': example_input})
test_input = torch.rand(1, 10)
output = traced_model(test_input)
print(output)
執行結果:
tensor([[-0.0592, 0.0850, -0.1666, 0.0536, -0.1865, -0.0310, -0.0330, 0.1313,
-0.0911, -0.1862]], grad_fn=<AddmmBackward0>)
範例 – Model with Nested Modules
import torch
import torch.nn as nn
class ModelWithActivation(nn.Module):
def __init__(self):
super(ModelWithActivation, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
x = self.fc(x)
return torch.relu(x)
class NestedModel(nn.Module):
def __init__(self):
super(NestedModel, self).__init__()
self.submodule = ModelWithActivation()
def forward(self, x):
return self.submodule(x)
model = NestedModel()
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
print(traced_model)
# Get the output for some input
output = traced_model(example_input)
print(output)
執行結果:
NestedModel(
original_name=NestedModel
(submodule): ModelWithActivation(
original_name=ModelWithActivation
(fc): Linear(original_name=Linear)))
tensor([[0.8534, 0.2388, 0.0000, 0.0665, 0.6416]], grad_fn=<ReluBackward0>)
範例 – Control flow (錯誤用法)
import torch
def foo_with_control_flow(x):
if x.sum() > 0:
return x * 20
else:
return x + 20
# Tracing with some input (not be correctly captured)
traced_foo = torch.jit.trace(foo_with_control_flow, torch.tensor([-1.0, -1.0]))
print(traced_foo(torch.tensor([-1.0, -1.0]))) # This will give the expected output
print(traced_foo(torch.tensor([1.0, 1.0]))) # This will NOT give the expected output
執行結果:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
We can't record the data flow of Python values, so this value will be treated as a constant in the future.
This means that the trace might not generalize to other inputs!
if x.sum() > 0:
tensor([19., 19.])
tensor([21., 21.])
script_if_tracing()
在追蹤期間首次調用 fn
時,編譯 fn
函數/方法。
由於許多編譯器內建函數的延遲初始化,script
在首次呼叫時具有不可忽略的啟動時間。因此,您不應在 library 程式碼中使用它。但是,您可能希望 library 的部分內容能夠進行 tracing,即使它們使用 control flow。在這些情況下,您應該使用 @script_if_tracing
來取代 @script
。
script_if_tracing(fn)
Returns:
如果在 tracing 期間調用,則傳回由 torch.jit.script
建立的 ScriptFunction
。否則,傳回原始函數 fn
。
torch.jit.script_if_tracing(fn)
提供了一種僅在 traced 函數時,才使用 TorchScript 有條件地編譯函數的方法。這有助於您在常規執行期間避免 torch.jit.script
overhead (開銷),但仍希望在 tracing 期間,scripted 函數的情況。
torch.jit.script_if_tracing
非常有用。
TorchScript 是一種從 PyTorch 程式碼建立 serializable 和 optimizable models 的方法。將 PyTorch 模型轉換為 TorchScript 主要有兩種方法:
- Tracing (
torch.jit.trace
):透過記錄 forward pass (前饋) 期間執行的 operations,來擷取 computation graph。 - Scripting (
torch.jit.script
):將 Python 程式碼 (包括 control flows) 轉換為 TorchScript。
Example 1: Function with control flow
import torch
# A simple function with control flow, Applying torch.jit.script_if_tracing
@torch.jit.script_if_tracing
def my_function(x):
if x.sum() > 0:
return x * 2
else:
return x / 2
def wrapper_function(x):
return my_function(x)
# Tracing
x = torch.tensor([-1.0, -2.0, 1.0])
traced_wrapper = torch.jit.trace(wrapper_function, (x,))
# Inspecting the graph will show that `my_function` has been inlined into the graph
print(traced_wrapper.graph)
print(traced_wrapper(x))
執行結果:
graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):
%1 : Function = prim::Constant[name="my_function"]()
%2 : Tensor = prim::CallFunction(%1, %x)
return (%2)
tensor([-0.5000, -1.0000, 0.5000])
Example 2: Function with control flow
import torch
import torch.nn as nn
class CustomNetwork(nn.Module):
def __init__(self):
super(CustomNetwork, self).__init__()
self.fc = nn.Linear(10, 5)
@staticmethod
@torch.jit.script_if_tracing
def custom_operation(x):
if x.mean() > 0:
return x * 2
else:
return x / 2
def forward(self, x):
# Apply the scripted custom operation
x = self.custom_operation(x)
return self.fc(x)
# Instantiate and test the network
model = CustomNetwork()
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, (x,))
# Inspecting the graph will show that the custom_operation has been inlined into the graph
print(traced_model.graph)
print(traced_model(x))
執行結果:
graph(%self.1 : __torch__.CustomNetwork,
%x : Float(1, 10, strides=[10, 1], requires_grad=0, device=cpu)):
%fc : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self.1)
%9 : Function = prim::Constant[name="custom_operation"]()
%input : Tensor = prim::CallFunction(%9, %x)
%17 : Tensor = prim::CallMethod[name="forward"](%fc, %input)
return (%17)
tensor([[-0.1302, -1.4549, -2.0248, -1.6586, 0.9151]],
grad_fn=<AddmmBackward0>)
trace_module()
Trace a module 並返回可執行 ScriptModule。
當 module 傳遞到 trace()
時,僅運行和 traced forward
方法。使用 trace_module()
,您可以指定方法名稱的 dictionary,作為 trace 參數的 example inputs。
trace_module(mod,
inputs,
optimize=None,
check_trace=True,
check_inputs=None,
check_tolerance=1e-05,
strict=True,
_force_outplace=False,
_module_class=None,
_compilation_unit=<torch.jit.CompilationUnit object>,
example_inputs_is_kwarg=False,
_store_inputs=True)
參數 | 說明 |
---|---|
mod | (nn.Module ) – 一個 torch.nn.Module ,其中包含名稱在 inputs 中指定的方法。給定的方法將被編譯為單一 ScriptModule 的一部分。 |
inputs | (dict) – 一個 dict,其中包含以 mod 中的方法名稱索引的 sample inputs。輸入在追蹤時將傳遞給名稱與 inputs’ keys 相對應的方法。 {'forward':example_forward_input,'method2':example_method2_input} |
check_trace | (bool )– 檢查 trace 程式碼的相同輸入,是否產生相同的輸出。如果您的 network 包含 non- deterministic ops (不確定性操作),或者檢查器失敗但您確定 network 是正確的,您可能需要停用此功能。 |
check_inputs | (List of tuples) – 附加輸入,幫助驗證 traced module 在不同輸入下的行為是否正確。 如果您想在多組輸入上測試 traced code 的準確性,這將非常有用。 每個 tuple 相當於在 example_inputs 中指定的一組輸入參數。為了獲得最佳結果,請傳入一組 check_inputs ,代表您希望 network 看到的形狀空間和輸入類型。如果未指定,則使用原始 example_inputs 進行檢查。 |
check_tolerance | (float ) – 確定驗證 traced code 輸出時的容錯能力。如果原始函數的輸出與 traced code 的輸出之間的差異在此容差範圍內,則認為 trac 是準確的。如果由於已知事件 (例如: operator fusion) 導致結果在數值上出現偏差,這可以用來放鬆檢查器的嚴格性。 |
example_inputs_is_kwarg | (bool ) – 如果為 True ,則 inputs dictionary 中的輸入,將被視為關鍵字參數。 |
strict | (bool ) – 如果為 True ,強制 module's parameters 和 buffers 在 Python module 和 JIT module 中是相同的實例。 |
optimize | 已被棄用並且沒有任何效果。 |
_force_outplace | 用於 debugging 的內部參數。它迫使 operators out-of-place。通常不用於標準 module tracing。 |
_module_class | 指定要 traced class of the module 的內部參數。一般情況下,使用者無需設定該參數。 |
_store_inputs | 決定是否將 tracing 期間使用的輸入,儲存在 traced module 中。這主要是為了內部使用和 debugging。 |
Returns:
一個 ScriptModule
物件,具有包含 traced 程式碼的單一 forward
方法。當 func
是torch.nn.Module
時,返回的 ScriptModule
將具有與 func
相同的 sub-modules set 和參數集。
當您有一個複雜 module,需要 traced 多個方法時,請使用 trace_module
。它提供的靈活性 trace 更多,不僅僅是 forward
方法。
trace()
,特別是當您只需要 trace forward
方法或單一函數時。它更加簡單,適合快速 tracing tasks。
Example 1: Tracing a Simple Neural Network
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 3)
def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)
def auxiliary_method(self, x):
# An example auxiliary method
return torch.relu(self.fc1(x))
# Initialize the network
net = SimpleNet()
# Example inputs for the methods
example_input_forward = torch.rand(1, 10)
example_input_auxiliary = torch.rand(1, 10)
# Trace the module
traced_net = torch.jit.trace_module(
net,
inputs={'forward': example_input_forward,
'auxiliary_method': example_input_auxiliary})
# Save the traced model
traced_net.save("traced_simple_net.pt")
# Load the traced model
loaded_net = torch.jit.load("traced_simple_net.pt")
# Using the model
sample_input = torch.rand(1, 10)
output = loaded_net(sample_input)
aux_output = loaded_net.auxiliary_method(sample_input)
print(output)
print(aux_output)
執行結果:
tensor([[ 0.0775, 0.0422, -0.3180]], grad_fn=<AddmmBackward0>)
tensor([[0.0000, 0.0000, 0.3084, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
範例 – check_inputs, example_inputs_is_kwarg 參數
Example 1: check_inputs
參數
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc = nn.Linear(16 * 16 * 16, 10) # Assuming input images are 32x32
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(x.size(0), -1) # Flatten the tensor
x = self.fc(x)
return x
# Create a model instance and move it to the appropriate device (CPU or GPU)
model = SimpleCNN().to('cpu')
# Define example inputs
example_input = torch.rand(1, 3, 32, 32) # Example input tensor
# Additional inputs to check, structured as a list of dictionaries
additional_inputs = [
{'forward': (torch.rand(1, 3, 32, 32),)},
{'forward': (torch.rand(2, 3, 32, 32),)},]
# Tracing the model
traced_model = torch.jit.trace_module(
model,
inputs={'forward': example_input},
check_inputs=additional_inputs,)
# Save the traced model
traced_model.save("traced_simple_cnn.pt")
print(traced_model_kwarg(example_input).shape)
# Output: torch.Size([1, 10])
Example 2: example_inputs_is_kwarg
參數example_inputs_is_kwarg
參數用於指定 inputs
參數中,提供的輸入是否為關鍵字參數。當您正在追蹤的方法需要關鍵字參數時,此參數至關重要。
在許多典型的 use cases 中,尤其是像我們上一個範例中的簡單模型一樣,輸入作為位置參數傳遞,並且 example_inputs_is_kwarg
無論設定為 True
還是 False
都沒有什麼區別。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc = nn.Linear(16 * 16 * 16, 10)
def forward(self, input_tensor):
x = self.pool(F.relu(self.conv1(input_tensor)))
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# Create a model instance
model = SimpleCNN().to('cpu')
# Define example inputs as keyword arguments
example_input_kwarg = {'input_tensor': torch.rand(1, 3, 32, 32)}
# Additional inputs to check
additional_inputs_kwarg = [
{'forward': {'input_tensor': torch.rand(1, 3, 32, 32)}},
{'forward': {'input_tensor': torch.rand(2, 3, 32, 32)}},
]
# Tracing the model with keyword arguments
traced_model_kwarg = torch.jit.trace_module(
model,
inputs={'forward': example_input_kwarg},
check_inputs=additional_inputs_kwarg,
example_inputs_is_kwarg=True)
# Save the traced model
traced_model_kwarg.save("traced_simple_cnn_kwarg.pt")
print(traced_model_kwarg(example_input_kwarg['input_tensor']).shape)
# Output: torch.Size([1, 10])
參考資料
torch.jit.script — PyTorch 2.0 documentation
torch.jit.script_if_tracing — PyTorch 2.1 documentation
torch.jit.trace — PyTorch 2.1 documentation
torch.jit.trace_module — PyTorch 2.1 documentation