Skip to content

Commit f20811e

Browse files
committed
Adapt tests
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent be31c28 commit f20811e

File tree

9 files changed

+85
-23
lines changed

9 files changed

+85
-23
lines changed

core/config/model_config.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"fmt"
45
"os"
56
"regexp"
67
"slices"
@@ -475,7 +476,7 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
475476
cfg.syncKnownUsecasesFromString()
476477
}
477478

478-
func (c *ModelConfig) Validate() bool {
479+
func (c *ModelConfig) Validate() (bool, error) {
479480
downloadedFileNames := []string{}
480481
for _, f := range c.DownloadFiles {
481482
downloadedFileNames = append(downloadedFileNames, f.Filename)
@@ -489,17 +490,20 @@ func (c *ModelConfig) Validate() bool {
489490
}
490491
if strings.HasPrefix(n, string(os.PathSeparator)) ||
491492
strings.Contains(n, "..") {
492-
return false
493+
return false, fmt.Errorf("invalid file path: %s", n)
493494
}
494495
}
495496

496497
if c.Backend != "" {
497498
// a regex that checks that is a string name with no special characters, except '-' and '_'
498499
re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)
499-
return re.MatchString(c.Backend)
500+
if !re.MatchString(c.Backend) {
501+
return false, fmt.Errorf("invalid backend name: %s", c.Backend)
502+
}
503+
return true, nil
500504
}
501505

502-
return true
506+
return true, nil
503507
}
504508

505509
func (c *ModelConfig) HasTemplate() bool {

core/config/model_config_loader.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, op
169169
}
170170

171171
for _, cc := range c {
172-
if cc.Validate() {
172+
if valid, _ := cc.Validate(); valid {
173173
bcl.configs[cc.Name] = *cc
174174
}
175175
}
@@ -184,7 +184,7 @@ func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderO
184184
return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
185185
}
186186

187-
if c.Validate() {
187+
if valid, _ := c.Validate(); valid {
188188
bcl.configs[c.Name] = *c
189189
} else {
190190
return fmt.Errorf("config is not valid")
@@ -362,7 +362,7 @@ func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...Conf
362362
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file")
363363
continue
364364
}
365-
if c.Validate() {
365+
if valid, _ := c.Validate(); valid {
366366
bcl.configs[c.Name] = *c
367367
} else {
368368
log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")

core/config/model_config_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ known_usecases:
2828
config, err := readModelConfigFromFile(tmp.Name())
2929
Expect(err).To(BeNil())
3030
Expect(config).ToNot(BeNil())
31-
Expect(config.Validate()).To(BeFalse())
31+
valid, err := config.Validate()
32+
Expect(err).To(HaveOccurred())
33+
Expect(valid).To(BeFalse())
3234
Expect(config.KnownUsecases).ToNot(BeNil())
3335
})
3436
It("Test Validate", func() {
@@ -46,7 +48,9 @@ parameters:
4648
Expect(config).ToNot(BeNil())
4749
// two configs in config.yaml
4850
Expect(config.Name).To(Equal("bar-baz"))
49-
Expect(config.Validate()).To(BeTrue())
51+
valid, err := config.Validate()
52+
Expect(err).To(BeNil())
53+
Expect(valid).To(BeTrue())
5054

5155
// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
5256
httpClient := http.Client{}
@@ -63,7 +67,9 @@ parameters:
6367
Expect(config).ToNot(BeNil())
6468
// two configs in config.yaml
6569
Expect(config.Name).To(Equal("hermes-2-pro-mistral"))
66-
Expect(config.Validate()).To(BeTrue())
70+
valid, err = config.Validate()
71+
Expect(err).To(BeNil())
72+
Expect(valid).To(BeTrue())
6773
})
6874
})
6975
It("Properly handles backend usecase matching", func() {

core/gallery/models.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
259259
return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
260260
}
261261

262-
if !modelConfig.Validate() {
263-
return nil, fmt.Errorf("failed to validate updated config YAML")
262+
if valid, err := modelConfig.Validate(); !valid {
263+
return nil, fmt.Errorf("failed to validate updated config YAML: %v", err)
264264
}
265265

266266
err = os.WriteFile(configFilePath, updatedConfigYAML, 0600)

core/http/app_test.go

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,14 +523,19 @@ var _ = Describe("API test", func() {
523523
backend: llama-cpp
524524
description: Test model imported from file URI
525525
parameters:
526-
model: /path/to/model.gguf
526+
model: path/to/model.gguf
527527
temperature: 0.7
528528
`
529529
testYamlFile = filepath.Join(tmpdir, "test-import.yaml")
530530
err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644)
531531
Expect(err).ToNot(HaveOccurred())
532532
})
533533

534+
AfterEach(func() {
535+
err := os.Remove(testYamlFile)
536+
Expect(err).ToNot(HaveOccurred())
537+
})
538+
534539
It("should import model from file:// URI pointing to local YAML config", func() {
535540
importReq := schema.ImportModelRequest{
536541
URI: "file://" + testYamlFile,
@@ -579,6 +584,53 @@ parameters:
579584
Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
580585
})
581586
})
587+
588+
Context("Importing models from URI can't point to absolute paths", func() {
589+
var testYamlFile string
590+
591+
BeforeEach(func() {
592+
// Create a test YAML config file
593+
yamlContent := `name: test-import-model
594+
backend: llama-cpp
595+
description: Test model imported from file URI
596+
parameters:
597+
model: /path/to/model.gguf
598+
temperature: 0.7
599+
`
600+
testYamlFile = filepath.Join(tmpdir, "test-import.yaml")
601+
err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644)
602+
Expect(err).ToNot(HaveOccurred())
603+
})
604+
605+
AfterEach(func() {
606+
err := os.Remove(testYamlFile)
607+
Expect(err).ToNot(HaveOccurred())
608+
})
609+
610+
It("should fail to import model from file:// URI pointing to local YAML config", func() {
611+
importReq := schema.ImportModelRequest{
612+
URI: "file://" + testYamlFile,
613+
Preferences: json.RawMessage(`{}`),
614+
}
615+
616+
var response schema.GalleryResponse
617+
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
618+
Expect(err).ToNot(HaveOccurred())
619+
Expect(response.ID).ToNot(BeEmpty())
620+
621+
uuid := response.ID
622+
resp := map[string]interface{}{}
623+
Eventually(func() bool {
624+
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
625+
resp = response
626+
return response["processed"].(bool)
627+
}, "360s", "10s").Should(Equal(true))
628+
629+
// Check that the model was imported successfully
630+
Expect(resp["message"]).To(ContainSubstring("error"))
631+
Expect(resp["error"]).ToNot(BeNil())
632+
})
633+
})
582634
})
583635

584636
Context("Model gallery", func() {

core/http/endpoints/localai/edit_model.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
135135
}
136136

137137
// Validate the configuration
138-
if !req.Validate() {
138+
if valid, _ := req.Validate(); !valid {
139139
response := ModelResponse{
140140
Success: false,
141141
Error: "Validation failed",

core/http/endpoints/localai/import_model.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
148148
modelConfig.SetDefaults()
149149

150150
// Validate the configuration
151-
if !modelConfig.Validate() {
151+
if valid, _ := modelConfig.Validate(); !valid {
152152
response := ModelResponse{
153153
Success: false,
154154
Error: "Invalid configuration",

core/http/endpoints/openai/realtime_model.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig
112112
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
113113
}
114114

115-
if !cfgVAD.Validate() {
115+
if valid, _ := cfgVAD.Validate(); !valid {
116116
return nil, nil, fmt.Errorf("failed to validate config: %w", err)
117117
}
118118

@@ -128,7 +128,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig
128128
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
129129
}
130130

131-
if !cfgSST.Validate() {
131+
if valid, _ := cfgSST.Validate(); !valid {
132132
return nil, nil, fmt.Errorf("failed to validate config: %w", err)
133133
}
134134

@@ -155,7 +155,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
155155
return nil, fmt.Errorf("failed to load backend config: %w", err)
156156
}
157157

158-
if !cfgVAD.Validate() {
158+
if valid, _ := cfgVAD.Validate(); !valid {
159159
return nil, fmt.Errorf("failed to validate config: %w", err)
160160
}
161161

@@ -172,7 +172,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
172172
return nil, fmt.Errorf("failed to load backend config: %w", err)
173173
}
174174

175-
if !cfgSST.Validate() {
175+
if valid, _ := cfgSST.Validate(); !valid {
176176
return nil, fmt.Errorf("failed to validate config: %w", err)
177177
}
178178

@@ -191,7 +191,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
191191
return nil, fmt.Errorf("failed to load backend config: %w", err)
192192
}
193193

194-
if !cfgAnyToAny.Validate() {
194+
if valid, _ := cfgAnyToAny.Validate(); !valid {
195195
return nil, fmt.Errorf("failed to validate config: %w", err)
196196
}
197197

@@ -218,7 +218,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
218218
return nil, fmt.Errorf("failed to load backend config: %w", err)
219219
}
220220

221-
if !cfgLLM.Validate() {
221+
if valid, _ := cfgLLM.Validate(); !valid {
222222
return nil, fmt.Errorf("failed to validate config: %w", err)
223223
}
224224

@@ -228,7 +228,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model
228228
return nil, fmt.Errorf("failed to load backend config: %w", err)
229229
}
230230

231-
if !cfgTTS.Validate() {
231+
if valid, _ := cfgTTS.Validate(); !valid {
232232
return nil, fmt.Errorf("failed to validate config: %w", err)
233233
}
234234

core/http/middleware/request.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.
475475
}
476476
}
477477

478-
if config.Validate() {
478+
if valid, _ := config.Validate(); valid {
479479
return nil
480480
}
481481
return fmt.Errorf("unable to validate configuration after merging")

0 commit comments

Comments
 (0)