🐯 Liger GRPO meets TRL



Thanks for your great work.

I tested the Liger loss with DeepSpeed ZeRO-3 using Qwen/Qwen2.5-0.5B-Instruct in bf16.
I ran into a shape mismatch, shown below:

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/temp.py", line 22, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2238, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2553, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3730, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 87, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1187, in compute_loss
[rank0]:     return self.compute_liger_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1160, in compute_liger_loss
[rank0]:     loss, metrics = self.liger_grpo_loss(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/grpo_loss.py", line 249, in forward
[rank0]:     return LigerFusedLinearGRPOFunction.apply(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/grpo_loss.py", line 142, in forward
[rank0]:     return super().forward(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/fused_linear_ppo.py", line 219, in forward
[rank0]:     accumulate_chunk(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/fused_linear_ppo.py", line 132, in accumulate_chunk
[rank0]:     (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
[rank0]:                                                                                            ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
[rank0]:     super().run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
[rank0]:     return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/misc.py", line 1022, in call_function
[rank0]:     return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/misc.py", line 778, in call_method
[rank0]:     .call_function(tx, args, kwargs)
[rank0]:      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
[rank0]:     tensor_variable = wrap_fx_proxy(
[rank0]:                       ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
[rank0]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
[rank0]:     return _wrap_fx_proxy(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
[rank0]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2536, in get_fake_value
[rank0]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
[rank0]:     ret_val = wrap_fake_exception(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
[rank0]:     return fn()
[rank0]:            ^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2604, in run_node
[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
[rank0]:     return node.target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 289, in _fn
[rank0]:     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 4444, in matmul
[rank0]:     return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)
[rank0]:                                        ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
[rank0]:     decomposition_table[func](*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 83, in inner
[rank0]:     r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 4336, in mv
[rank0]:     torch._check(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1656, in _check
[rank0]:     _check_with(RuntimeError, cond, message)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1638, in _check_with
[rank0]:     raise error_type(message_evaluated)
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x7f2e2a41ff00>(*(GradTrackingTensor(lvl=1, value=
[rank0]:     FakeTensor(..., device="cuda:0", size=(1, s0, 896), dtype=torch.bfloat16,
[rank0]:                requires_grad=True)
[rank0]: ), GradTrackingTensor(lvl=1, value=
[rank0]:     FakeTensor(..., device="cuda:0", size=(0,), dtype=torch.bfloat16,
[rank0]:                requires_grad=True)
[rank0]: )), **{}):
[rank0]: size mismatch, got input (s0x896), vec (0)
```
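
The last frames show `torch.matmul` receiving hidden states of shape `(1, s0, 896)` and a weight of shape `(0,)`. One plausible reading, assuming the `lm_head` weight was handed to the Liger loss while still partitioned by ZeRO-3 (DeepSpeed leaves a partitioned parameter's local tensor empty until it is gathered), can be reproduced in plain PyTorch. This is only a sketch: the flat sharding, the `torch.cat` "gather", and the names `project`, `shards`, and `partitioned_local` are hypothetical stand-ins, not DeepSpeed's API; sizes other than 896 are arbitrary, and float32 replaces bf16 so it runs on CPU.

```python
import torch

# Toy sizes (assumption): 896 matches the hidden size in the traceback, the rest are arbitrary.
hidden_size, vocab_size, seq_len, world_size = 896, 16, 4, 2
torch.manual_seed(0)

# Hidden states shaped like the first FakeTensor in the traceback: (1, s0, 896).
hidden = torch.randn(1, seq_len, hidden_size)

# Full lm_head-style weight of shape (vocab_size, hidden_size).
full_weight = torch.randn(vocab_size, hidden_size)

# Toy ZeRO-3 partitioning (hypothetical, not DeepSpeed code): the flattened weight is
# split into one shard per rank, and the parameter's local tensor is left empty.
shards = list(full_weight.flatten().chunk(world_size))
partitioned_local = torch.empty(0)  # shape (0,), like the second FakeTensor


def project(h: torch.Tensor, w: torch.Tensor):
    """Multiply hidden states by w^T; return the RuntimeError instead of raising it."""
    try:
        return torch.matmul(h, w.t())  # t() returns a 1-D tensor unchanged
    except RuntimeError as err:
        return err


# Using the empty local tensor fails with a size mismatch, as in the traceback.
failed = project(hidden, partitioned_local)

# Toy "all-gather": concatenate every rank's shard and restore the logical shape.
gathered = torch.cat(shards).view(vocab_size, hidden_size)
logits = project(hidden, gathered)

print(type(failed).__name__)  # RuntimeError
print(tuple(logits.shape))    # (1, 4, 16)
```

In a real DeepSpeed run, the counterpart of the toy gather is the `deepspeed.zero.GatheredParameters` context manager, which temporarily materializes full parameters inside its `with` block.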

Does Liger GRPO support multi-GPU training with DeepSpeed ZeRO-3?

Posted by Seungwoo Ryu (tryumanshow), last updated 2025-06-10.


