From cf9d4eb7d421ed640a49099922c3b272776566e8 Mon Sep 17 00:00:00 2001 From: zhangyihuiben Date: Fri, 23 Jan 2026 11:27:02 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90master=E3=80=91=E3=80=90bugfix?= =?UTF-8?q?=E3=80=91Tensor=20does=20not=20support=20str=20type,=20will=20r?= =?UTF-8?q?eturn=20numpy.ndarray?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_dataloader/test_adgen_dataloader.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py b/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py index 9e47c6e8c..efa6867c5 100644 --- a/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py +++ b/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py @@ -16,9 +16,9 @@ import os import unittest import tempfile -from mindformers.dataset.dataloader.adgen_dataloader import ADGenDataset, ADGenDataLoader import pytest from tests.st.test_ut.test_dataset.get_test_data import get_adgen_data +from mindformers.dataset.dataloader.adgen_dataloader import ADGenDataset, ADGenDataLoader class TestAdgenDataloader(unittest.TestCase): @@ -26,13 +26,17 @@ class TestAdgenDataloader(unittest.TestCase): @classmethod def setUpClass(cls): - cls.temp_dir = tempfile.TemporaryDirectory() + cls.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with cls.path = cls.temp_dir.name cls.phase = "train" cls.data_path = os.path.join(cls.path, f"{cls.phase}.json") cls.columns = ["content", "summary"] get_adgen_data(cls.path) + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -45,8 +49,8 @@ class TestAdgenDataloader(unittest.TestCase): dataloader = dataloader.batch(1) for item in dataloader: assert len(item) == 2 - assert item[0].asnumpy()[0] == "类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤" - assert item[1].asnumpy()[0] == "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。" + assert item[0][0] == "类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤" + assert item[1][0] == "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。" break @pytest.mark.level1 @@ -84,9 +88,13 @@ class TestAdgenDataSet(unittest.TestCase): @classmethod def setUpClass(cls): - cls.temp_dir = tempfile.TemporaryDirectory() + cls.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with cls.path = cls.temp_dir.name + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -- Gitee