diff --git a/Common/Common.vcxproj b/Common/Common.vcxproj
index 7e61f89776..e0d85faeaf 100644
--- a/Common/Common.vcxproj
+++ b/Common/Common.vcxproj
@@ -529,6 +529,7 @@
+
diff --git a/Common/Common.vcxproj.filters b/Common/Common.vcxproj.filters
index 65b6eaccab..d701c1d1c6 100644
--- a/Common/Common.vcxproj.filters
+++ b/Common/Common.vcxproj.filters
@@ -415,6 +415,9 @@
GPU\Vulkan
+
+ Thread
+
diff --git a/Common/Thread/Barrier.h b/Common/Thread/Barrier.h
new file mode 100644
index 0000000000..11e55cdb00
--- /dev/null
+++ b/Common/Thread/Barrier.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include
+#include
+
+// Similar to C++20's std::barrier
+class CountingBarrier {
+public:
+ CountingBarrier(size_t count) : threadCount_(count) {}
+
+ void Arrive() {
+ std::unique_lock lk(m);
+ counter++;
+ waiting++;
+ cv.wait(lk, [&] {return counter >= threadCount_; });
+ cv.notify_one();
+ waiting--;
+ if (waiting == 0) {
+ // Reset so it can be re-used.
+ counter = 0;
+ }
+ lk.unlock();
+ }
+
+private:
+ std::mutex m;
+ std::condition_variable cv;
+ size_t counter = 0;
+ size_t waiting = 0;
+ size_t threadCount_;
+};
diff --git a/unittest/TestThreadManager.cpp b/unittest/TestThreadManager.cpp
index 3445a33083..6c0abb90b5 100644
--- a/unittest/TestThreadManager.cpp
+++ b/unittest/TestThreadManager.cpp
@@ -1,11 +1,14 @@
#include "Common/Log.h"
#include "Common/TimeUtil.h"
+#include "Common/Thread/Barrier.h"
#include "Common/Thread/ThreadManager.h"
#include "Common/Thread/Channel.h"
#include "Common/Thread/Promise.h"
#include "Common/Thread/ParallelLoop.h"
#include "Common/Thread/ThreadUtil.h"
+#include "UnitTest.h"
+
struct ResultObject {
bool ok;
};
@@ -56,16 +59,70 @@ bool TestParallelLoop(ThreadManager *threadMan) {
return true;
}
+// This is some ugly stuff but realistic.
+const size_t THREAD_COUNT = 6; // Must match the number of threads in TestMultithreadedScheduling
+const size_t ITERATIONS = 100000;
+
+static std::atomic g_atomicCounter;
+static ThreadManager *g_threadMan;
+static CountingBarrier g_barrier(THREAD_COUNT + 1);
+
+class IncrementTask : public Task {
+public:
+ IncrementTask(TaskType type) : type_(type) {}
+ ~IncrementTask() {}
+ virtual TaskType Type() const { return type_; }
+ virtual void Run() {
+ g_atomicCounter++;
+ }
+private:
+ TaskType type_;
+};
+
+void ThreadFunc() {
+ for (int i = 0; i < ITERATIONS; i++) {
+ g_threadMan->EnqueueTask(new IncrementTask((i & 1) ? TaskType::CPU_COMPUTE : TaskType::IO_BLOCKING));
+ }
+ g_barrier.Arrive();
+}
+
+bool TestMultithreadedScheduling() {
+ g_atomicCounter = 0;
+ std::thread thread1(ThreadFunc);
+ std::thread thread2(ThreadFunc);
+ std::thread thread3(ThreadFunc);
+ std::thread thread4(ThreadFunc);
+ std::thread thread5(ThreadFunc);
+ std::thread thread6(ThreadFunc);
+
+ // Just testing the barrier
+ g_barrier.Arrive();
+ // OK, all are done.
+
+ EXPECT_EQ_INT(g_atomicCounter, THREAD_COUNT * ITERATIONS);
+
+ thread1.join();
+ thread2.join();
+ thread3.join();
+ thread4.join();
+ thread5.join();
+ thread6.join();
+
+ return true;
+}
+
bool TestThreadManager() {
ThreadManager manager;
manager.Init(8, 1);
+ g_threadMan = &manager;
+
Promise *object(Promise::Spawn(&manager, &ResultProducer, TaskType::IO_BLOCKING));
if (!TestParallelLoop(&manager)) {
return false;
}
- sleep_ms(1000);
+ sleep_ms(100);
ResultObject *result = object->BlockUntilReady();
if (result) {
@@ -78,5 +135,9 @@ bool TestThreadManager() {
return false;
}
+ if (!TestMultithreadedScheduling()) {
+ return false;
+ }
+
return true;
}
diff --git a/unittest/UnitTest.h b/unittest/UnitTest.h
index e953925d39..2267a7a87f 100644
--- a/unittest/UnitTest.h
+++ b/unittest/UnitTest.h
@@ -2,7 +2,7 @@
#define EXPECT_TRUE(a) if (!(a)) { printf("%s:%i: Test Fail\n", __FUNCTION__, __LINE__); return false; }
#define EXPECT_FALSE(a) if ((a)) { printf("%s:%i: Test Fail\n", __FUNCTION__, __LINE__); return false; }
-#define EXPECT_EQ_INT(a, b) if ((a) != (b)) { printf("%s:%i: Test Fail\n%d\nvs\n%d\n", __FUNCTION__, __LINE__, a, b); return false; }
+#define EXPECT_EQ_INT(a, b) if ((a) != (b)) { printf("%s:%i: Test Fail\n%d\nvs\n%d\n", __FUNCTION__, __LINE__, (int)(a), (int)(b)); return false; }
#define EXPECT_EQ_HEX(a, b) if ((a) != (b)) { printf("%s:%i: Test Fail\n%x\nvs\n%x\n", __FUNCTION__, __LINE__, a, b); return false; }
#define EXPECT_EQ_FLOAT(a, b) if ((a) != (b)) { printf("%s:%i: Test Fail\n%f\nvs\n%f\n", __FUNCTION__, __LINE__, a, b); return false; }
#define EXPECT_APPROX_EQ_FLOAT(a, b) if (fabsf((a)-(b))>0.00001f) { printf("%s:%i: Test Fail\n%f\nvs\n%f\n", __FUNCTION__, __LINE__, a, b); /*return false;*/ }