diff --git a/similarity_search.c b/similarity_search.c index 00e3fe8..9cb8b33 100644 --- a/similarity_search.c +++ b/similarity_search.c @@ -141,6 +141,8 @@ void generate_random_string(char *string, int max_len) { // Create a new search index SearchIndex* create_search_index(int capacity) { + if (capacity <= 0) return NULL; + SearchIndex* index = (SearchIndex*)malloc(sizeof(SearchIndex)); if (!index) return NULL; @@ -151,6 +153,7 @@ SearchIndex* create_search_index(int capacity) { } index->num_strings = 0; + index->capacity = capacity; return index; } @@ -158,6 +161,16 @@ SearchIndex* create_search_index(int capacity) { int add_string_to_index(SearchIndex* index, const char* string) { if (!index || !string) return -1; + // Check if we've reached capacity + if (index->num_strings >= index->capacity) { + return -1; + } + + // Check if string is too long + if (strlen(string) >= MAX_STRING_LEN) { + return -1; + } + index->strings[index->num_strings] = strdup(string); if (!index->strings[index->num_strings]) return -1; @@ -181,6 +194,18 @@ void free_search_index(SearchIndex* index) { SearchResult* search_index(SearchIndex* index, const char* query, float cutoff, int* num_results) { if (!index || !query || !num_results) return NULL; + // Validate input string length + if (strlen(query) >= MAX_STRING_LEN) { + *num_results = 0; + return NULL; + } + + // Validate cutoff + if (cutoff < 0.0f || cutoff > 1.0f) { + *num_results = 0; + return NULL; + } + // Allocate temporary array for results SearchResult* temp_results = (SearchResult*)malloc(index->num_strings * sizeof(SearchResult)); if (!temp_results) return NULL; @@ -192,7 +217,16 @@ SearchResult* search_index(SearchIndex* index, const char* query, float cutoff, float similarity = calculate_similarity(query, index->strings[i], cutoff); if (similarity >= cutoff) { - temp_results[*num_results].string = index->strings[i]; + // Store a copy of the string in the result + temp_results[*num_results].string = strdup(index->strings[i]); + if (!temp_results[*num_results].string) { + // Free any already allocated strings on error + for (int j = 0; j < *num_results; j++) { + free(temp_results[j].string); + } + free(temp_results); + return NULL; + } temp_results[*num_results].similarity = similarity; (*num_results)++; } @@ -204,6 +238,10 @@ SearchResult* search_index(SearchIndex* index, const char* query, float cutoff, // Allocate final result array with exact size SearchResult* results = (SearchResult*)malloc(*num_results * sizeof(SearchResult)); if (!results) { + // Free all strings in temp_results + for (int i = 0; i < *num_results; i++) { + free(temp_results[i].string); + } free(temp_results); return NULL; } @@ -217,5 +255,11 @@ SearchResult* search_index(SearchIndex* index, const char* query, float cutoff, // Free the search results void free_search_results(SearchResult* results, int num_results) { + if (!results) return; + + // Free all strings in the results + for (int i = 0; i < num_results; i++) { + free(results[i].string); + } free(results); } \ No newline at end of file diff --git a/similarity_search.h b/similarity_search.h index 9f12c33..613dd99 100644 --- a/similarity_search.h +++ b/similarity_search.h @@ -14,6 +14,7 @@ extern "C" { typedef struct { char **strings; int num_strings; + int capacity; } SearchIndex; // Structure to hold a search result diff --git a/similarity_search_addon.cc b/similarity_search_addon.cc index 6d68880..362ad84 100644 --- a/similarity_search_addon.cc +++ b/similarity_search_addon.cc @@ -66,7 +66,24 @@ Napi::Value SearchIndexWrapper::AddString(const Napi::CallbackInfo& info) { } std::string str = info[0].As().Utf8Value(); + + // Check if string is too long + if (str.length() >= MAX_STRING_LEN) { + Napi::Error::New(env, "String too long").ThrowAsJavaScriptException(); + return env.Null(); + } + + // Check if we've reached capacity + if (this->index_->num_strings >= this->index_->capacity) { + Napi::Error::New(env, "Search index capacity exceeded").ThrowAsJavaScriptException(); + return env.Null(); + } + int result = add_string_to_index(this->index_, str.c_str()); + if (result != 0) { + Napi::Error::New(env, "Failed to add string to index").ThrowAsJavaScriptException(); + return env.Null(); + } return Napi::Number::New(env, result); } @@ -81,10 +98,20 @@ Napi::Value SearchIndexWrapper::Search(const Napi::CallbackInfo& info) { } std::string query = info[0].As().Utf8Value(); - float cutoff = 0.2f; // Default cutoff + // Check if query string is too long + if (query.length() >= MAX_STRING_LEN) { + Napi::Error::New(env, "Query string too long").ThrowAsJavaScriptException(); + return env.Null(); + } + + float cutoff = 0.2f; // Default cutoff if (info.Length() > 1 && info[1].IsNumber()) { cutoff = info[1].As().FloatValue(); + if (cutoff < 0.0f || cutoff > 1.0f) { + Napi::Error::New(env, "Cutoff must be between 0 and 1").ThrowAsJavaScriptException(); + return env.Null(); + } } int num_results = 0;