Skip to content

Commit fe9923d

Browse files
authored
INC Bench update (#1366)
1 parent 23c585e commit fe9923d

File tree

106 files changed

+1986
-648
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+1986
-648
lines changed

conda_meta/full/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ requirements:
4141
- sqlalchemy ==1.4.27
4242
- alembic ==1.7.7
4343
- cython
44+
- pywin32 # [win]
4445
test:
4546
imports:
4647
- neural_compressor

neural_compressor/ux/components/configuration_wizard/configuration_parser.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
"""Configuration type parser."""
1616
import json
1717
from collections.abc import Iterable
18+
from copy import deepcopy
1819
from typing import Any, Dict, List, Type, Union
1920

2021
from neural_compressor.ux.utils.exceptions import ClientErrorException
2122
from neural_compressor.ux.utils.hw_info import HWInfo
23+
from neural_compressor.ux.utils.logger import log
2224
from neural_compressor.ux.utils.utils import parse_bool_value
2325

2426

@@ -92,8 +94,9 @@ def __init__(self) -> None:
9294
"bool": bool,
9395
}
9496

95-
def parse(self, data: dict) -> dict:
97+
def parse(self, input_data: dict) -> dict:
9698
"""Parse configuration."""
99+
data = deepcopy(input_data)
97100
transforms_data = data.get("transform", None)
98101
if transforms_data is not None:
99102
data.update({"transform": self.parse_transforms(transforms_data)})
@@ -110,7 +113,9 @@ def parse(self, data: dict) -> dict:
110113

111114
metric_params = data.get("metric_param", None)
112115
if metric_params and isinstance(metric_params, dict):
113-
data["metric_param"] = self.parse_metric(metric_params)
116+
parsed_metric_params = self.parse_metric(metric_params)
117+
118+
data.update({"metric_param": parsed_metric_params})
114119

115120
if "tuning" in data.keys():
116121
data["tuning"] = parse_bool_value(data["tuning"])
@@ -227,13 +232,15 @@ def parse_metric(self, metric_data: dict) -> dict:
227232
for param_name, param_value in metric_data.items():
228233
if isinstance(param_value, dict):
229234
parsed_data.update({param_name: self.parse_metric(param_value)})
230-
elif isinstance(param_value, str):
235+
elif isinstance(param_value, str) or isinstance(param_value, int):
231236
if param_value == "":
232237
continue
233238
param_type = self.get_param_type("metric", param_name)
234239
if param_type is None:
240+
log.debug("Could not find param type.")
235241
continue
236-
parsed_data.update({param_name: self.parse_value(param_value, param_type)})
242+
parsed_value = self.parse_value(param_value, param_type)
243+
parsed_data.update({param_name: parsed_value})
237244
return parsed_data
238245

239246
def get_param_type(
@@ -313,6 +320,8 @@ def normalize_string_list(string_list: str, required_type: Union[Type, List[Type
313320
if not isinstance(string_list, str):
314321
return string_list
315322
if isinstance(required_type, list):
323+
string_list = string_list.replace("(", "[")
324+
string_list = string_list.replace(")", "]")
316325
while not string_list.startswith("[["):
317326
string_list = "[" + string_list
318327
while not string_list.endswith("]]"):

neural_compressor/ux/components/configuration_wizard/get_boundary_nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def get_boundary_nodes(data: Dict[str, Any]) -> None:
4848
try:
4949
model = model_repository.get_model(model_path)
5050
except NotFoundException:
51+
log.debug(f"Could not get model instance for {model_path}")
5152
supported_frameworks = model_repository.get_frameworks()
5253
raise ClientErrorException(
5354
f"Framework for specified model is not yet supported. "

neural_compressor/ux/components/db_manager/db_models/model.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""The Model class."""
1616
import json
17-
from typing import Any, List
17+
from typing import Any, List, Optional
1818

1919
from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, String
2020
from sqlalchemy.orm import relationship, session
@@ -54,6 +54,17 @@ class Model(Base):
5454
"Optimization",
5555
back_populates="optimized_model",
5656
primaryjoin="Optimization.optimized_model_id == Model.id",
57+
)
58+
59+
benchmarks: Any = relationship(
60+
"Benchmark",
61+
back_populates="model",
62+
cascade="all, delete",
63+
)
64+
65+
profilings: Any = relationship(
66+
"Profiling",
67+
back_populates="model",
5768
cascade="all, delete",
5869
)
5970

@@ -147,6 +158,26 @@ def list(db_session: session.Session, project_id: int) -> dict:
147158
)
148159
return {"models": models}
149160

161+
@staticmethod
162+
def delete_model(
163+
db_session: session.Session,
164+
model_id: int,
165+
model_name: str,
166+
) -> Optional[int]:
167+
"""Remove model from database."""
168+
model = (
169+
db_session.query(Model)
170+
.filter(Model.id == model_id)
171+
.filter(Model.name == model_name)
172+
.one_or_none()
173+
)
174+
if model is None:
175+
return None
176+
db_session.delete(model)
177+
db_session.flush()
178+
179+
return int(model.id)
180+
150181
@staticmethod
151182
def build_info(model: Any) -> dict:
152183
"""Get model info."""

neural_compressor/ux/components/db_manager/db_models/optimization.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
# pylint: disable=no-member
1616
"""The Optimization class."""
1717
import json
18-
from typing import Any, Dict, List, Optional
18+
from typing import Any, Dict, List, Optional, Union
1919

2020
from sqlalchemy import DDL, Column, DateTime, ForeignKey, Integer, String, event
21+
from sqlalchemy.engine import Connection
2122
from sqlalchemy.orm import relationship, session
2223
from sqlalchemy.sql import func
2324

@@ -554,6 +555,21 @@ def list(db_session: session.Session, project_id: int) -> dict:
554555
optimizations.append(optimization_info)
555556
return {"optimizations": optimizations}
556557

558+
@staticmethod
559+
def unpin_benchmark(
560+
db_connection: Union[session.Session, Connection],
561+
benchmark_id: int,
562+
) -> None:
563+
"""Unpin benchmark from optimization."""
564+
update_queries = [
565+
f"UPDATE optimization SET performance_benchmark_id=null "
566+
f"WHERE performance_benchmark_id={benchmark_id}",
567+
f"UPDATE optimization SET accuracy_benchmark_id=null "
568+
f"WHERE accuracy_benchmark_id={benchmark_id}",
569+
]
570+
for update_query in update_queries:
571+
db_connection.execute(update_query)
572+
557573
@staticmethod
558574
def build_info(
559575
optimization: Any,

neural_compressor/ux/components/db_manager/db_operations/benchmark_api_interface.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
"""INC Bench Benchmark API interface."""
1717
import os
1818
import shutil
19+
from sqlite3 import Connection
1920
from typing import List, Optional, Union
2021

21-
from sqlalchemy.orm import sessionmaker
22+
from sqlalchemy import event
23+
from sqlalchemy.orm import Mapper, sessionmaker
2224

2325
from neural_compressor.ux.components.benchmark import Benchmarks
2426
from neural_compressor.ux.components.configuration_wizard.configuration_parser import (
@@ -27,6 +29,7 @@
2729
from neural_compressor.ux.components.db_manager.db_manager import DBManager
2830
from neural_compressor.ux.components.db_manager.db_models.benchmark import Benchmark
2931
from neural_compressor.ux.components.db_manager.db_models.benchmark_result import BenchmarkResult
32+
from neural_compressor.ux.components.db_manager.db_models.optimization import Optimization
3033
from neural_compressor.ux.components.db_manager.db_operations.project_api_interface import (
3134
ProjectAPIInterface,
3235
)
@@ -63,6 +66,10 @@ def delete_benchmark(data: dict) -> dict:
6366
benchmark_details = Benchmark.details(db_session, benchmark_id)
6467
project_id = benchmark_details["project_id"]
6568
project_details = ProjectAPIInterface.get_project_details({"id": project_id})
69+
Optimization.unpin_benchmark(
70+
db_connection=db_session,
71+
benchmark_id=benchmark_id,
72+
)
6673
removed_benchmark_id = Benchmark.delete_benchmark(
6774
db_session=db_session,
6875
benchmark_id=benchmark_id,
@@ -455,3 +462,16 @@ def clean_status(status_to_clean: ExecutionStatus) -> dict:
455462
status_to_clean=status_to_clean,
456463
)
457464
return response
465+
466+
467+
@event.listens_for(Benchmark, "before_delete")
468+
def before_delete_benchmark_entry(
469+
mapper: Mapper,
470+
connection: Connection,
471+
benchmark: Benchmark,
472+
) -> None:
473+
"""Clean up benchmark data before remove."""
474+
Optimization.unpin_benchmark(
475+
db_connection=connection,
476+
benchmark_id=benchmark.id,
477+
)

neural_compressor/ux/components/db_manager/db_operations/model_api_interface.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,25 @@ def list_models(data: dict) -> dict:
9696

9797
return models_list
9898

99+
@staticmethod
100+
def delete_model(data: dict) -> dict:
101+
"""Delete model from database."""
102+
try:
103+
model_id: int = int(data.get("id", None))
104+
model_name: str = str(data.get("name", None))
105+
except ValueError:
106+
raise ClientErrorException("Could not parse value.")
107+
except TypeError:
108+
raise ClientErrorException("Missing model id or model name.")
109+
with Session.begin() as db_session:
110+
removed_model_id = Model.delete_model(
111+
db_session=db_session,
112+
model_id=model_id,
113+
model_name=model_name,
114+
)
115+
116+
return {"id": removed_model_id}
117+
99118
@staticmethod
100119
def parse_model_data(data: dict) -> ModelAddParamsInterface:
101120
"""Parse input data for model."""

neural_compressor/ux/components/model/onnxrt/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def guard_requirements_installed(self) -> None:
217217
"""Ensure all requirements are installed."""
218218
check_module("onnx")
219219
check_module("onnxruntime")
220-
if sys.version_info < (3,10): # pragma: no cover
220+
if sys.version_info < (3, 10): # pragma: no cover
221221
check_module("onnxruntime_extensions")
222222

223223
@property

neural_compressor/ux/components/model/repository.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from neural_compressor.ux.components.model.tensorflow.meta_graph import MetaGraphModel
2525
from neural_compressor.ux.components.model.tensorflow.saved_model import SavedModelModel
2626
from neural_compressor.ux.utils.exceptions import NotFoundException
27+
from neural_compressor.ux.utils.logger import log
2728

2829

2930
class ModelRepository:
@@ -43,6 +44,8 @@ def __init__(self) -> None:
4344
def get_model(self, path: str) -> Model:
4445
"""Get Model for given path."""
4546
for model_type in self.model_types:
47+
supports_path = model_type.supports_path(path)
48+
log.debug(f"{model_type.__name__}: {supports_path}")
4649
if model_type.supports_path(path):
4750
return model_type(path)
4851

neural_compressor/ux/components/optimization/execute_optimization.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,20 @@ def execute_optimization(data: Dict[str, Any]) -> dict:
233233
if is_pytorch_script:
234234
optimized_model_data["model_path"] = logs[0]
235235

236+
optimization_data = OptimizationAPIInterface.get_optimization_details(
237+
{
238+
"id": optimization_id,
239+
},
240+
)
241+
if optimization_data["optimized_model"] is not None:
242+
existing_model_id = optimization_data["optimized_model"]["id"]
243+
existing_model_name = optimization_data["optimized_model"]["name"]
244+
ModelAPIInterface.delete_model(
245+
{
246+
"id": existing_model_id,
247+
"name": existing_model_name,
248+
},
249+
)
236250
optimized_model_id = ModelAPIInterface.add_model(optimized_model_data)
237251
OptimizationAPIInterface.update_optimized_model(
238252
{

0 commit comments

Comments
 (0)