Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/class/learning/Classification.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import InputPopup from './InputPopup';

export const classes = [
'ai_learning_speech'
];
export const classes = ['ai_learning_speech'];

class Classification {
#type = null;
Expand All @@ -20,9 +18,11 @@ class Classification {

getResult(index) {
const result = this.#popup?.result || [];
const defaultResult = {probability: 0, className: ''};
if(index !== undefined && index > -1) {
return result.find(({className}) => className === this.#labels[index]) || defaultResult;
const defaultResult = { probability: 0, className: '' };
if (index !== undefined && index > -1) {
return (
result.find(({ className }) => className === this.#labels[index]) || defaultResult
);
}
return result[0] || defaultResult;
}
Expand All @@ -36,10 +36,10 @@ class Classification {

openInputPopup() {
this.#popup = new InputPopup({
url: this.#url,
url: this.#url,
labels: this.#labels,
type: this.#type,
recordTime: this.#recordTime
recordTime: this.#recordTime,
});
this.#popup.open();
}
Expand Down
8 changes: 6 additions & 2 deletions src/class/learning/Cluster.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { kmpp } from 'skmeans/kinit';
// import { kmpp } from 'skmeans/dist/node/kinit';
import floor from 'lodash/floor';
import _toNumber from 'lodash/toNumber';
import LearningView from './LearningView';
import Chart from './Chart';
import DataTable from '../DataTable';
Expand Down Expand Up @@ -153,14 +154,17 @@ class Cluster {
this.#trainCallback(1);
this.#isTrained = false;
const { data, select } = this.#table;
const filtered = data.filter(
(row) => !select.flat().some((selected) => !_toNumber(row[selected]))
);
const [attr] = select;

const { centroids, indexes } = kmeans(
data.map((row) => attr.map((i) => parseFloat(row[i]) || 0)),
filtered.map((row) => attr.map((i) => parseFloat(row[i]) || 0)),
this.#trainParam
);
this.#result = {
graphData: convertGraphData(data, centroids, indexes, attr),
graphData: convertGraphData(filtered, centroids, indexes, attr),
centroids,
};
this.#isTrained = true;
Expand Down
11 changes: 7 additions & 4 deletions src/class/learning/DecisionTree.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import _floor from 'lodash/floor';
import _max from 'lodash/max';
import _sum from 'lodash/sum';
import _mean from 'lodash/mean';
import _toNumber from 'lodash/toNumber';
import LearningBase from './LearningBase';
import { DecisionTreeClassifier as DTClassifier } from 'ml-cart';
import Utils from './Utils';
Expand All @@ -24,7 +25,7 @@ class DecisionTree extends LearningBase {
type = 'decisiontree';

init({ name, url, result, table, trainParam }) {
this.name = name;
this.name = name;
this.trainParam = trainParam;
this.result = result;
this.table = table;
Expand Down Expand Up @@ -157,8 +158,10 @@ function getData(testRate = 0.2, data) {
const tempMapCount = {};
const { select = [[0], [1]], data: table, fields } = data;
const [attr, predict] = select;

const dataArray = table
const filtered = table.filter(
(row) => !select.flat().some((selected) => !_toNumber(row[selected]))
);
const dataArray = filtered
.map((row) => ({
x: attr.map((i) => Utils.stringToNumber(i, row[i], tempMap, tempMapCount)),
y: Utils.stringToNumber(predict[0], row[predict[0]], tempMap, tempMapCount),
Expand All @@ -176,7 +179,7 @@ function getData(testRate = 0.2, data) {
select,
fields,
valueMap: { ...tempMap[predict[0]] },
numClass: tempMapCount[predict[0]],
numClass: tempMapCount[predict[0]] || 1,
};
}

Expand Down
7 changes: 5 additions & 2 deletions src/class/learning/LogisticRegression.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import _floor from 'lodash/floor';
import _max from 'lodash/max';
import _sum from 'lodash/sum';
import _mean from 'lodash/mean';
import _toNumber from 'lodash/toNumber';
import LearningBase from './LearningBase';
import Utils from './Utils';

Expand Down Expand Up @@ -182,8 +183,10 @@ function getData(validationRate, testRate, data, trainParam) {
const tempMapCount = {};
const { select = [[0], [1]], data: table, fields } = data;
const [attr, predict] = select;

const dataArray = table
const filtered = table.filter(
(row) => !select.flat().some((selected) => !_toNumber(row[selected]))
);
const dataArray = filtered
.map((row) => ({
x: attr.map((i) => Utils.stringToNumber(i, row[i], tempMap, tempMapCount)),
y: Utils.stringToNumber(predict[0], row[predict[0]], tempMap, tempMapCount),
Expand Down
Loading