TorchScript Tracing and Scripting

TorchScript Tracing and Scripting

script()


Scripting 一個函數或 nn.Module。使用 TorchScript 編譯器將其編譯為 TorchScript 程式碼,並返回 ScriptModule 或 ScriptFunction。

Scripting dictionary 或 list 會將其中的資料複製到 TorchScript 實例中,然後在 Python 和 TorchScript 之間可以透過 reference 傳遞,複製 overhead 為零。

script(obj, 
	   optimize=None, 
	   _frames_up=0, 
	   _rcb=None, 
	   example_inputs=None)
參數 說明
obj (Callable, class, or nn.Module) – 要編譯的 nn.Module、函數、class type, dictionary or list。(將程式碼轉換為 TorchScript。)
如果存在 dynamic control flow (例如: 循環或條件),則 scripting process 將捕獲該結構。
optimize 已棄用且沒有任何效果。
_frames_up
_rcb
example_inputs (List[Tuple], Dict[Callable, List[Tuple]], None) – 提供範例輸入來註解函數或提供 nn.Module 的參數。
這些範例輸入不用於計算,但它們幫助 scripting process 理解函數輸入的形狀和類型。 對於在 scripting process 中,能更有效地產生程式碼和更好地檢查錯誤。

返回值:

  1. 如果 objnn.Module,則返回 ScriptModule 物件。返回的 ScriptModule 將具有與原始 nn.Module 相同的 sub-modules 和參數集。
  2. 如果 obj 是獨立函數,則會返回 ScriptFunction。
  3. 如果 obj 是 dict,則返回 torch._C.ScriptDict 的實例。
  4. 如果 obj 是一個 list,則返回 torch._C.ScriptList 的實例。

將 PyTorch 程式碼轉換為 TorchScript 有兩種主要機制:

  1. Tracing:在對樣本輸入執行操作時,捕捉 computational graph。它很簡單,無法捕獲 control-flow 語句。
  2. Scripting:分析 Python 函數的原始程式碼,並建立 graph。它捕獲控制流但需要 type annotations。

在以下情況下,您應該考慮使用 torch.jit.script

您的模型包含 tracing alone (torch.jit.trace),不支援的 dynamic control flows (例如: if 或語句) 。

請記住,TorchScript 的主要目標通常是讓 PyTorch 模型的部署和最佳化變得更容易。始終測試腳本化模型,以確保其在轉換後按預期運行。

Scripting 後,錯誤將引用 TorchScript 程式碼,而不是原始 Python 程式碼。這可能會讓 debugging 變得有點困難。在轉換為 TorchScript 之前,請考慮以 PyTorch 形式徹底 debugging 模型。

範例 – Scripting 函數


Decorator 裝飾器將透過編譯函數體來建構 ScriptFunction。

import torch

@torch.jit.script
def add_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b

# Now `add_tensors` is a TorchScript function, and we can call it like a normal function.
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

result = add_tensors(a, b)
print(result)
print(type(add_tensors))
print(add_tensors.code)  # See the compiled graph as Python code

執行結果:

tensor([5., 7., 9.])
<class 'torch.jit.ScriptFunction'>

Output: 
def add_tensors(a: Tensor,
    b: Tensor) -> Tensor:
  return torch.add(a, b)

import torch

@torch.jit.script
def add_tensors(x):
    for i in range(x.size(0)):
        x += i
    return x

x = torch.tensor([1, 2, 3])
result = add_tensors(x)
print(result)  # Output: tensor([4, 5, 6])

範例 – Scripting nn.Module


預設情況下,Scripting nn.Module 將編譯 forward 方法,並遞迴編譯 forward 呼叫的任何方法、submodules 和函數。如果 nn.Module 僅使用 TorchScript 支援的功能,則無需變更原始模組程式碼。 否則 script 將建構 ScriptModule,其中包含原始模組的屬性、參數和方法的副本。

import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

# Convert to TorchScript
scripted_module = torch.jit.script(MyModule())

# Use the model just like before
output = scripted_module(torch.tensor([1, -2, 3]))
print(output)  # Output: tensor([ 1, -2,  3])

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
scripted_model = torch.jit.script(model)

# Save the scripted model
scripted_model.save("model.pt")

# Load and use the model
loaded_model = torch.jit.load("model.pt")
x = torch.rand(1, 10)
output = loaded_model(x)
print(output)

執行結果:

tensor([[-0.0331,  0.4932,  0.2000,  0.0149,  0.7165]],
       grad_fn=<AddmmBackward0>)

範例 – Mixing tracing and scripting


TorchScript 是一種從 PyTorch 程式碼建立 serializable 和 optimizable 模型的方法。將 PyTorch 程式碼轉換為 TorchScript 有兩種主要方法:

  1. Tracing:使用 "example input" 來追蹤對其執行的 operations。它適用於沒有 control flows 的模型。
  2. Scripting:按原樣轉換程式碼,包括 control flows。適用於具有條件和循環結構的模型。

但有時,您可能有一個很乾脆 (traceable) 的模組,但由於 control flows 的原因,有一些小部分需要 scripting。在這些情況下,您可以結合 tracing 和 scripting。

  1. Trace the Main Module
  2. 使用 Control Flow Script Sub-modules

這種混合方法利用了模型簡單部分的追蹤便利性,以及對那些更動態或具有 control flow 部分的 scripting 的強大功能。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.jit as jit

class SomeModule(nn.Module):
    def __init__(self):
        super(SomeModule, self).__init__()
        self.linear = nn.Linear(10, 10)
    
    def forward(self, x):
        # This conditional is hard to capture with tracing
        if x.sum() > 0:
            return self.linear(x)
        else:
            return -self.linear(x)

class MainModule(nn.Module):
    def __init__(self):
        super(MainModule, self).__init__()
        self.some_module = SomeModule()
        self.other_linear = nn.Linear(10, 5)
    
    def forward(self, x):
        x = self.some_module(x)
        return self.other_linear(x)

# Convert the sub-module using scripting
scripted_some_module = jit.script(SomeModule())

# Replace the original module with the scripted version in MainModule
main_module = MainModule()
main_module.some_module = scripted_some_module

# Now, trace the main module
example_input = torch.rand(1, 10)
traced_main_module = jit.trace(main_module, example_input)
print(traced_main_module)

# Get the output for some input
output = traced_main_module(example_input)
print(output)

執行結果:

MainModule(
  original_name=MainModule
  (some_module): RecursiveScriptModule(
    original_name=SomeModule
    (linear): RecursiveScriptModule(original_name=Linear))
  (other_linear): Linear(original_name=Linear))
  
tensor([[-0.1702,  0.7664,  0.0461, -0.0591, -0.3725]],
       grad_fn=<AddmmBackward0>)

trace()


記錄給定一組輸入時執行的 operations,但不會捕捉 control flow (例如: if 語句或循環)。
對於 control flow,應使用 torch.jit.script

torch.jit.trace(func, 
				example_inputs=None, 
				optimize=None, 
				check_trace=True, 
				check_inputs=None, 
				check_tolerance=1e-05, 
				strict=True, 
				_force_outplace=False, 
				_module_class=None, 
				_compilation_unit=<torch.jit.CompilationUnit object>, 
				example_kwarg_inputs=None, 
				_store_inputs=True)
參數 說明
func (Callable or torch.nn.Module) – 您想要 trace 的 Python 函數或 torch.nn.Module,將與 example_inputs 一起執行的 。
func 參數和回傳值必須是 tensors 或包含 tensors 的 tuples。當 module passed torch.jit.trace 時,僅運行和追蹤 forward method。
example_inputs (Tuple or torch.Tensor) – 輸入 tensors 的元組或單一輸入 tensor,將在 tracing 時,傳遞給函數。
用於運行 function/module 並記錄 operations,以進行追蹤。如果您的 function/module 使用多個參數,您應該將它們作為 tuple 傳遞。
當值為 None 時,應指定 example_kwarg_inputs
optimize 已棄用且沒有任何效果。
check_trace (bool) – 檢查 traced code 是否產生正確的輸出。
該 function/module 使用提供的 example_inputs 運行,並將結果與 traced code 的輸出進行比較,以確保準確性。
如果您的 network 包含 non- deterministic ops (不確定性操作),或者檢查器失敗但您確定 network 是正確的,您可能需要停用此功能。
check_inputs (List of tuples) – 附加輸入,幫助驗證 traced module 在不同輸入下的行為是否正確。
如果您想在多組輸入上測試 traced code 的準確性,這將非常有用。
每個 tuple 相當於在 example_inputs 中指定的一組輸入參數。為了獲得最佳結果,請傳入一組 check_inputs,代表您希望 network 看到的形狀空間和輸入類型。
如果未指定,則使用原始 example_inputs 進行檢查。
check_tolerance (float) – 確定驗證 traced code 輸出時的容錯能力。
如果原始函數的輸出與 traced code 的輸出之間的差異在此容差範圍內,則認為 trac 是準確的。如果由於已知事件 (例如: operator fusion) 導致結果在數值上出現偏差,這可以用來放鬆檢查器的嚴格性。
strict (bool) – 是否以嚴格模式運行追蹤器;嚴格遵守 TorchScript 架構。
如果設定為 False,function/module 可以使用 Python-only 的功能,但產生的 traced code 可能無法匯出,以部署到 Python 之外。
只有當您希望追蹤器記錄可變容器類型 (list/dict),並且您確定您在問題中使用的容器是常數結構,且不會用作 control flow (if, for) 的狀況。
_force_outplace 一個 debug 標誌,當設定為 True 時,可確保所有 ops 都不合適。
_module_class 建立 traced script module 時,使用的自訂 torch.nn.Module class 類別。
主要供內部使用,不是典型使用者需要修改的內容。
_compilation_unit 將儲存 traced TorchScript 函數的編譯單元。是一個內部參數,典型使用者很少使用。
example_kwarg_inputs (Dict) – example_inputs 的一組關鍵字參數,在追蹤時將傳遞給函數。應指定此參數或 example_inputs
Dict 將透過追蹤函數的參數名稱進行 unpacking。如果 keys of the dict 與追蹤函數的參數名稱不匹配,則會引發運行時異常。

返回值:

  • 如果 funcnn.Moduleforward() 方法,則將返回 ScriptModule 物件。返回的 ScriptModule 將具有與原始 nn.Module 相同的 sub-modules 集和參數集。
  • 如果 func 是獨立函數,則傳回 ScriptFunction。
Tracing 非常適合僅對 Tensors 和 Lists, Dictionaries, 和 Tuples of Tensors 進行 operates 的程式碼。
由於追蹤記錄 operations 是基於傳遞的輸入數據,因此可能無法準確追蹤具有副作用的 operations (例如: in-place modifications)。確保模型中的操作沒有 side effect (副作用)。

Warning

  • 追蹤僅正確記錄不依賴資料的函數和模組 (例如: 對 tensors 中的資料沒有條件),且不具有任何未追蹤的外部依賴項 (例如: 執行 input/output 或存取全域變數)。
  • 追蹤僅記錄給定函數在給定 tensors 上運行時完成的 operations。因此,傳回的 ScriptModule 將始終在任何輸入上執行相同的 traced graph。當您的 module 預計根據輸入或 module 狀態運行不同的operations sets 時,這會產生一些重要的影響。例如:
  1. 追蹤不會記錄任何 control flow (if, for),但有時 control flow 實際上是模型本身的一部分。例如: recurrent network 是輸入 sequence 長度上的循環 (可能是動態的)。
  2. 在傳回的 ScriptModule 中,無論 ScriptModule 處於哪種模式,trainingeval 模式下的行為將始終表現得像處於追蹤期間所處的模式一樣。 在這種情況下,追蹤是不合適的,並且 torch.jit.script 是更好的選擇。如果追蹤此類模型,您可能會在後續呼叫模型時,默默地得到不正確的結果。當做一些可能導致產生不正確追蹤的事情時,追蹤器將嘗試發出警告。

範例 – Scripting function


import torch

def foo(x, y):
    return 2 * x + y

# "traced_foo" can now be run with the TorchScript interpreter 
# or saved and loaded in a Python-free environment
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# Define some input tensors
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.5, 1.5, 2.5])

# Call the traced function with the input tensors
output = traced_foo(x, y)
print(output)

執行結果:

tensor([2.5000, 5.5000, 8.5000])

範例 – Scripting nn.Module


import torch
import torch.nn as nn

class ModelWithActivation(nn.Module):
    def __init__(self):
        super(ModelWithActivation, self).__init__()
        self.fc = nn.Linear(10, 5)
    
    def forward(self, x):
        x = self.fc(x)
        return torch.relu(x)

model = ModelWithActivation()
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
print(traced_model)

# Get the output for some input
output = traced_model(example_input)
print(output)

執行結果:

ModelWithActivation(
  original_name=ModelWithActivation
  (fc): Linear(original_name=Linear))

tensor([[0.0000, 0.0354, 0.0000, 0.3064, 0.0000]], grad_fn=<ReluBackward0>)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = SimpleNet()

# Tracing the model
example_input = torch.rand(1, 10)  # Example input tensor
traced_model = torch.jit.trace(model, example_kwarg_inputs={'x': example_input})

test_input = torch.rand(1, 10)
output = traced_model(test_input)
print(output)

執行結果:

tensor([[-0.0592,  0.0850, -0.1666,  0.0536, -0.1865, -0.0310, -0.0330,  0.1313,
         -0.0911, -0.1862]], grad_fn=<AddmmBackward0>)

範例 – Model with Nested Modules


import torch
import torch.nn as nn

class ModelWithActivation(nn.Module):
    def __init__(self):
        super(ModelWithActivation, self).__init__()
        self.fc = nn.Linear(10, 5)
    
    def forward(self, x):
        x = self.fc(x)
        return torch.relu(x)

class NestedModel(nn.Module):
    def __init__(self):
        super(NestedModel, self).__init__()
        self.submodule = ModelWithActivation()
    
    def forward(self, x):
        return self.submodule(x)

model = NestedModel()

example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
print(traced_model)

# Get the output for some input
output = traced_model(example_input)
print(output)

執行結果:

NestedModel(
  original_name=NestedModel
  (submodule): ModelWithActivation(
    original_name=ModelWithActivation
    (fc): Linear(original_name=Linear)))

tensor([[0.8534, 0.2388, 0.0000, 0.0665, 0.6416]], grad_fn=<ReluBackward0>)

範例 – Control flow (錯誤用法)


import torch

def foo_with_control_flow(x):
    if x.sum() > 0:
        return x * 20
    else:
        return x + 20

# Tracing with some input (not be correctly captured)
traced_foo = torch.jit.trace(foo_with_control_flow, torch.tensor([-1.0, -1.0]))

print(traced_foo(torch.tensor([-1.0, -1.0])))  # This will give the expected output
print(traced_foo(torch.tensor([1.0, 1.0])))    # This will NOT give the expected output

執行結果:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. 
We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs!
  if x.sum() > 0:

tensor([19., 19.])
tensor([21., 21.])

script_if_tracing()


在追蹤期間首次調用 fn 時,編譯 fn 函數/方法。

由於許多編譯器內建函數的延遲初始化,script 在首次呼叫時具有不可忽略的啟動時間。因此,您不應在 library 程式碼中使用它。但是,您可能希望 library 的部分內容能夠進行 tracing,即使它們使用 control flow。在這些情況下,您應該使用 @script_if_tracing 來取代 @script

script_if_tracing(fn)

Returns:
如果在 tracing 期間調用,則傳回由 torch.jit.script 建立的 ScriptFunction。否則,傳回原始函數 fn

torch.jit.script_if_tracing(fn) 提供了一種僅在 traced 函數時,才使用 TorchScript 有條件地編譯函數的方法。這有助於您在常規執行期間避免 torch.jit.script overhead (開銷),但仍希望在 tracing 期間,scripted 函數的情況。

總而言之,如果您希望避免常規函數呼叫期間的 overhead of scripting,但仍希望在 tracing 期間受益於 scripting,特別是在涉及 control flows 時,torch.jit.script_if_tracing 非常有用。

TorchScript 是一種從 PyTorch 程式碼建立 serializable 和 optimizable models 的方法。將 PyTorch 模型轉換為 TorchScript 主要有兩種方法:

  1. Tracing (torch.jit.trace):透過記錄 forward pass (前饋) 期間執行的 operations,來擷取 computation graph。
  2. Scripting (torch.jit.script):將 Python 程式碼 (包括 control flows) 轉換為 TorchScript。

Example 1: Function with control flow

import torch

# A simple function with control flow, Applying torch.jit.script_if_tracing
@torch.jit.script_if_tracing
def my_function(x):
    if x.sum() > 0:
        return x * 2
    else:
        return x / 2


def wrapper_function(x):
    return my_function(x)

# Tracing
x = torch.tensor([-1.0, -2.0, 1.0])
traced_wrapper = torch.jit.trace(wrapper_function, (x,))

# Inspecting the graph will show that `my_function` has been inlined into the graph
print(traced_wrapper.graph)
print(traced_wrapper(x))

執行結果:

graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):
  %1 : Function = prim::Constant[name="my_function"]()
  %2 : Tensor = prim::CallFunction(%1, %x)
  return (%2)

tensor([-0.5000, -1.0000,  0.5000])

Example 2: Function with control flow

import torch
import torch.nn as nn

class CustomNetwork(nn.Module):
    def __init__(self):
        super(CustomNetwork, self).__init__()
        self.fc = nn.Linear(10, 5)

    @staticmethod
    @torch.jit.script_if_tracing
    def custom_operation(x):
        if x.mean() > 0:
            return x * 2
        else:
            return x / 2

    def forward(self, x):
        # Apply the scripted custom operation
        x = self.custom_operation(x)
        return self.fc(x)

# Instantiate and test the network
model = CustomNetwork()
x = torch.randn(1, 10)
traced_model = torch.jit.trace(model, (x,))

# Inspecting the graph will show that the custom_operation has been inlined into the graph
print(traced_model.graph)
print(traced_model(x))

執行結果:

graph(%self.1 : __torch__.CustomNetwork,
      %x : Float(1, 10, strides=[10, 1], requires_grad=0, device=cpu)):
  %fc : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self.1)
  %9 : Function = prim::Constant[name="custom_operation"]()
  %input : Tensor = prim::CallFunction(%9, %x)
  %17 : Tensor = prim::CallMethod[name="forward"](%fc, %input)
  return (%17)

tensor([[-0.1302, -1.4549, -2.0248, -1.6586,  0.9151]],
       grad_fn=<AddmmBackward0>)

trace_module()


Trace a module 並返回可執行 ScriptModule。
當 module 傳遞到 trace() 時,僅運行和 traced forward 方法。使用 trace_module(),您可以指定方法名稱的 dictionary,作為 trace 參數的 example inputs。

trace_module(mod, 
			 inputs, 
			 optimize=None, 
			 check_trace=True, 
			 check_inputs=None, 
			 check_tolerance=1e-05, 
			 strict=True, 
			 _force_outplace=False, 
			 _module_class=None, 
			 _compilation_unit=<torch.jit.CompilationUnit object>, 
			 example_inputs_is_kwarg=False, 
			 _store_inputs=True)
參數 說明
mod (nn.Module) – 一個 torch.nn.Module,其中包含名稱在 inputs 中指定的方法。
給定的方法將被編譯為單一 ScriptModule 的一部分。
inputs (dict) – 一個 dict,其中包含以 mod 中的方法名稱索引的 sample inputs。
輸入在追蹤時將傳遞給名稱與 inputs’ keys 相對應的方法。 {'forward':example_forward_input,'method2':example_method2_input}
check_trace (bool)– 檢查 trace 程式碼的相同輸入,是否產生相同的輸出。
如果您的 network 包含 non- deterministic ops (不確定性操作),或者檢查器失敗但您確定 network 是正確的,您可能需要停用此功能。
check_inputs (List of tuples) – 附加輸入,幫助驗證 traced module 在不同輸入下的行為是否正確。
如果您想在多組輸入上測試 traced code 的準確性,這將非常有用。
每個 tuple 相當於在 example_inputs 中指定的一組輸入參數。為了獲得最佳結果,請傳入一組 check_inputs,代表您希望 network 看到的形狀空間和輸入類型。
如果未指定,則使用原始 example_inputs 進行檢查。
check_tolerance (float) – 確定驗證 traced code 輸出時的容錯能力。
如果原始函數的輸出與 traced code 的輸出之間的差異在此容差範圍內,則認為 trac 是準確的。如果由於已知事件 (例如: operator fusion) 導致結果在數值上出現偏差,這可以用來放鬆檢查器的嚴格性。
example_inputs_is_kwarg (bool) – 如果為 True,則 inputs dictionary 中的輸入,將被視為關鍵字參數。
strict (bool) – 如果為 True,強制 module's parameters 和 buffers 在 Python module 和 JIT module 中是相同的實例。
optimize 已被棄用並且沒有任何效果。
_force_outplace 用於 debugging 的內部參數。它迫使 operators out-of-place。通常不用於標準 module tracing。
_module_class 指定要 traced class of the module 的內部參數。一般情況下,使用者無需設定該參數。
_store_inputs 決定是否將 tracing 期間使用的輸入,儲存在 traced module 中。這主要是為了內部使用和 debugging。

Returns:
一個 ScriptModule 物件,具有包含 traced 程式碼的單一 forward 方法。當 functorch.nn.Module 時,返回的 ScriptModule 將具有與 func 相同的 sub-modules set 和參數集。

當您有一個複雜 module,需要 traced 多個方法時,請使用 trace_module。它提供的靈活性 trace 更多,不僅僅是 forward 方法。

對於更簡單的用例,選擇 trace() ,特別是當您只需要 trace forward 方法或單一函數時。它更加簡單,適合快速 tracing tasks。

Example 1: Tracing a Simple Neural Network

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 3)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    def auxiliary_method(self, x):
        # An example auxiliary method
        return torch.relu(self.fc1(x))


# Initialize the network
net = SimpleNet()

# Example inputs for the methods
example_input_forward = torch.rand(1, 10)
example_input_auxiliary = torch.rand(1, 10)

# Trace the module
traced_net = torch.jit.trace_module(
    net, 
    inputs={'forward': example_input_forward, 
            'auxiliary_method': example_input_auxiliary})

# Save the traced model
traced_net.save("traced_simple_net.pt")

# Load the traced model
loaded_net = torch.jit.load("traced_simple_net.pt")

# Using the model
sample_input = torch.rand(1, 10)
output = loaded_net(sample_input)
aux_output = loaded_net.auxiliary_method(sample_input)

print(output)
print(aux_output)

執行結果:

tensor([[ 0.0775,  0.0422, -0.3180]], grad_fn=<AddmmBackward0>)
tensor([[0.0000, 0.0000, 0.3084, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)

範例 – check_inputs, example_inputs_is_kwarg 參數


Example 1: check_inputs 參數

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(16 * 16 * 16, 10)  # Assuming input images are 32x32

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x

# Create a model instance and move it to the appropriate device (CPU or GPU)
model = SimpleCNN().to('cpu')

# Define example inputs
example_input = torch.rand(1, 3, 32, 32)  # Example input tensor

# Additional inputs to check, structured as a list of dictionaries
additional_inputs = [
    {'forward': (torch.rand(1, 3, 32, 32),)},
    {'forward': (torch.rand(2, 3, 32, 32),)},]

# Tracing the model
traced_model = torch.jit.trace_module(
    model,
    inputs={'forward': example_input},
    check_inputs=additional_inputs,)

# Save the traced model
traced_model.save("traced_simple_cnn.pt")
print(traced_model_kwarg(example_input).shape)
# Output: torch.Size([1, 10])

Example 2: example_inputs_is_kwarg 參數
example_inputs_is_kwarg 參數用於指定 inputs 參數中,提供的輸入是否為關鍵字參數。當您正在追蹤的方法需要關鍵字參數時,此參數至關重要。
在許多典型的 use cases 中,尤其是像我們上一個範例中的簡單模型一樣,輸入作為位置參數傳遞,並且 example_inputs_is_kwarg 無論設定為 True 還是 False 都沒有什麼區別。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(16 * 16 * 16, 10)

    def forward(self, input_tensor):
        x = self.pool(F.relu(self.conv1(input_tensor)))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create a model instance
model = SimpleCNN().to('cpu')

# Define example inputs as keyword arguments
example_input_kwarg = {'input_tensor': torch.rand(1, 3, 32, 32)}

# Additional inputs to check
additional_inputs_kwarg = [
    {'forward': {'input_tensor': torch.rand(1, 3, 32, 32)}},
    {'forward': {'input_tensor': torch.rand(2, 3, 32, 32)}},
]

# Tracing the model with keyword arguments
traced_model_kwarg = torch.jit.trace_module(
    model,
    inputs={'forward': example_input_kwarg},
    check_inputs=additional_inputs_kwarg,
    example_inputs_is_kwarg=True)

# Save the traced model
traced_model_kwarg.save("traced_simple_cnn_kwarg.pt")
print(traced_model_kwarg(example_input_kwarg['input_tensor']).shape)
# Output: torch.Size([1, 10])

參考資料


torch.jit.script — PyTorch 2.0 documentation

torch.jit.script_if_tracing — PyTorch 2.1 documentation

torch.jit.trace — PyTorch 2.1 documentation

torch.jit.trace_module — PyTorch 2.1 documentation

Torchscript language reference — PyTorch 2.0 documentation

TorchScript Classes — PyTorch 2.0 documentation

About the author
熾焰小迅風

Great! You’ve successfully signed up.

Welcome back! You've successfully signed in.

You've successfully subscribed to XiWind 西風之劍.

Success! Check your email for magic link to sign-in.

Success! Your billing info has been updated.

Your billing was not updated.