Implement input validation for string length and index capacity in similarity_search_addon.cc and similarity_search.c; enhance memory management in search_index function.

This commit is contained in:
seb
2025-04-18 08:57:35 +02:00
parent de950fa11c
commit da5e7476a0
3 changed files with 74 additions and 2 deletions

View File

@@ -141,6 +141,8 @@ void generate_random_string(char *string, int max_len) {
// Create a new search index // Create a new search index
SearchIndex* create_search_index(int capacity) { SearchIndex* create_search_index(int capacity) {
if (capacity <= 0) return NULL;
SearchIndex* index = (SearchIndex*)malloc(sizeof(SearchIndex)); SearchIndex* index = (SearchIndex*)malloc(sizeof(SearchIndex));
if (!index) return NULL; if (!index) return NULL;
@@ -151,6 +153,7 @@ SearchIndex* create_search_index(int capacity) {
} }
index->num_strings = 0; index->num_strings = 0;
index->capacity = capacity;
return index; return index;
} }
@@ -158,6 +161,16 @@ SearchIndex* create_search_index(int capacity) {
int add_string_to_index(SearchIndex* index, const char* string) { int add_string_to_index(SearchIndex* index, const char* string) {
if (!index || !string) return -1; if (!index || !string) return -1;
// Check if we've reached capacity
if (index->num_strings >= index->capacity) {
return -1;
}
// Check if string is too long
if (strlen(string) >= MAX_STRING_LEN) {
return -1;
}
index->strings[index->num_strings] = strdup(string); index->strings[index->num_strings] = strdup(string);
if (!index->strings[index->num_strings]) return -1; if (!index->strings[index->num_strings]) return -1;
@@ -181,6 +194,18 @@ void free_search_index(SearchIndex* index) {
SearchResult* search_index(SearchIndex* index, const char* query, float cutoff, int* num_results) { SearchResult* search_index(SearchIndex* index, const char* query, float cutoff, int* num_results) {
if (!index || !query || !num_results) return NULL; if (!index || !query || !num_results) return NULL;
// Validate input string length
if (strlen(query) >= MAX_STRING_LEN) {
*num_results = 0;
return NULL;
}
// Validate cutoff
if (cutoff < 0.0f || cutoff > 1.0f) {
*num_results = 0;
return NULL;
}
// Allocate temporary array for results // Allocate temporary array for results
SearchResult* temp_results = (SearchResult*)malloc(index->num_strings * sizeof(SearchResult)); SearchResult* temp_results = (SearchResult*)malloc(index->num_strings * sizeof(SearchResult));
if (!temp_results) return NULL; if (!temp_results) return NULL;
@@ -192,7 +217,16 @@ SearchResult* search_index(SearchIndex* index, const char* query, float cutoff,
float similarity = calculate_similarity(query, index->strings[i], cutoff); float similarity = calculate_similarity(query, index->strings[i], cutoff);
if (similarity >= cutoff) { if (similarity >= cutoff) {
temp_results[*num_results].string = index->strings[i]; // Store a copy of the string in the result
temp_results[*num_results].string = strdup(index->strings[i]);
if (!temp_results[*num_results].string) {
// Free any already allocated strings on error
for (int j = 0; j < *num_results; j++) {
free(temp_results[j].string);
}
free(temp_results);
return NULL;
}
temp_results[*num_results].similarity = similarity; temp_results[*num_results].similarity = similarity;
(*num_results)++; (*num_results)++;
} }
@@ -204,6 +238,10 @@ SearchResult* search_index(SearchIndex* index, const char* query, float cutoff,
// Allocate final result array with exact size // Allocate final result array with exact size
SearchResult* results = (SearchResult*)malloc(*num_results * sizeof(SearchResult)); SearchResult* results = (SearchResult*)malloc(*num_results * sizeof(SearchResult));
if (!results) { if (!results) {
// Free all strings in temp_results
for (int i = 0; i < *num_results; i++) {
free(temp_results[i].string);
}
free(temp_results); free(temp_results);
return NULL; return NULL;
} }
@@ -217,5 +255,11 @@ SearchResult* search_index(SearchIndex* index, const char* query, float cutoff,
// Free the search results // Free the search results
void free_search_results(SearchResult* results, int num_results) { void free_search_results(SearchResult* results, int num_results) {
if (!results) return;
// Free all strings in the results
for (int i = 0; i < num_results; i++) {
free(results[i].string);
}
free(results); free(results);
} }

View File

@@ -14,6 +14,7 @@ extern "C" {
typedef struct { typedef struct {
char **strings; char **strings;
int num_strings; int num_strings;
int capacity;
} SearchIndex; } SearchIndex;
// Structure to hold a search result // Structure to hold a search result

View File

@@ -66,7 +66,24 @@ Napi::Value SearchIndexWrapper::AddString(const Napi::CallbackInfo& info) {
} }
std::string str = info[0].As<Napi::String>().Utf8Value(); std::string str = info[0].As<Napi::String>().Utf8Value();
// Check if string is too long
if (str.length() >= MAX_STRING_LEN) {
Napi::Error::New(env, "String too long").ThrowAsJavaScriptException();
return env.Null();
}
// Check if we've reached capacity
if (this->index_->num_strings >= this->index_->capacity) {
Napi::Error::New(env, "Search index capacity exceeded").ThrowAsJavaScriptException();
return env.Null();
}
int result = add_string_to_index(this->index_, str.c_str()); int result = add_string_to_index(this->index_, str.c_str());
if (result != 0) {
Napi::Error::New(env, "Failed to add string to index").ThrowAsJavaScriptException();
return env.Null();
}
return Napi::Number::New(env, result); return Napi::Number::New(env, result);
} }
@@ -81,10 +98,20 @@ Napi::Value SearchIndexWrapper::Search(const Napi::CallbackInfo& info) {
} }
std::string query = info[0].As<Napi::String>().Utf8Value(); std::string query = info[0].As<Napi::String>().Utf8Value();
float cutoff = 0.2f; // Default cutoff
// Check if query string is too long
if (query.length() >= MAX_STRING_LEN) {
Napi::Error::New(env, "Query string too long").ThrowAsJavaScriptException();
return env.Null();
}
float cutoff = 0.2f; // Default cutoff
if (info.Length() > 1 && info[1].IsNumber()) { if (info.Length() > 1 && info[1].IsNumber()) {
cutoff = info[1].As<Napi::Number>().FloatValue(); cutoff = info[1].As<Napi::Number>().FloatValue();
if (cutoff < 0.0f || cutoff > 1.0f) {
Napi::Error::New(env, "Cutoff must be between 0 and 1").ThrowAsJavaScriptException();
return env.Null();
}
} }
int num_results = 0; int num_results = 0;