TP-68136 | code critic workflow added

This commit is contained in:
Aman Chaturvedi
2024-08-21 19:07:04 +05:30
parent 7a9c4dccfa
commit 9f79c225a2

View File

@@ -33,28 +33,6 @@ def install_packages_from_file(filename: str):
except Exception as e:
print("")
def extract_code(directory: str) -> List[Tuple[str, str]]:
code_snippets = []
include_terms = {'service', 'controller', 'listener', 'scheduler', 'utils', 'client', 'repository', 'dao'}
for root, dirs, files in os.walk(directory):
# Skip test directories
if 'test' in root.lower():
continue
for file in files:
if file.endswith(('.java', '.kt')):
file_path = os.path.join(root, file)
# Only include directories with include terms
if not any(term in file_path.lower() for term in include_terms):
continue
with open(file_path, 'r', encoding='utf-8') as f:
code = f.read()
# Remove import statements
code = re.sub(r'^import .*$', '', code, flags=re.MULTILINE)
code_snippets.append((file_path, code))
return code_snippets
# Function to classify code snippets
def classify_code(code_snippets: List[Tuple[str, str]]) -> List[Tuple[str, str, str]]:
classified_code = []
@@ -84,7 +62,7 @@ def create_system_prompt() -> str:
def create_gpt_prompt(file_path: str, code: str, class_type: str) -> str:
encoded_prompt = os.getenv('ENCODED_CODE_REVIEW_USER_PROMPT', 'Q29kZToKe2NvZGV9')
encoded_prompt = os.getenv('ENCODED_CODE_REVIEW_USER_PROMPT', 'UmV2aWV3IHRoZSBmb2xsb3dpbmcgSlMvSlNYL1RTL1RTWCBjb2RlIGZvciBtYWpvciBwb3RlbnRpYWwgaXNzdWVzIHJlbGF0ZWQgdG8gdGhlc2UgZmlsZXMuIEtlZXAgdGhlc2UgdGhpbmdzIGluIG1pbmQ6IEF2b2lkIG1hZ2ljIG51bWJlcnMvc3RyaW5ncywgdXNlIGNvbnN0YW50cyBvciBlbnVtczsgVXNlIGRlc2NyaXB0aXZlIG5hbWVzIGZvciB2YXJpYWJsZXMvZnVuY3Rpb25zOyBVdGlsaXplIFR5cGVTY3JpcHQgZm9yIHN0cm9uZ2VyIHR5cGluZzsgTWFuYWdlIHotaW5kZXggYW5kIGNvbG9yIHZhbHVlcyBjZW50cmFsbHk7IFVzZSBzZXJ2ZXIgdGltZSBvdmVyIERhdGUubm93KCk7IEFwcGx5IG9wdGlvbmFsIGNoYWluaW5nIGZvciBudWxsaXNoIHZhbHVlczsgRW1icmFjZSBmdW5jdGlvbmFsIHByb2dyYW1taW5nOyBSZXVzZSBleGlzdGluZyBmdW5jdGlvbnMvdXRpbGl0aWVzOyBSZW1vdmUgY2FyZXQgaW4gcGFja2FnZS5qc29uIHRvIGxvY2sgdmVyc2lvbnM7IFVzZSB1c2VNZW1vIGFuZCB1c2VDYWxsYmFjayBpbiBSZWFjdDsgQXZvaWQgZGVmaW5pbmcgY29tcG9uZW50cyB3aXRoaW4gcmVuZGVyIGZ1bmN0aW9uczsgRXh0cmFjdCBmdW5jdGlvbnMgZnJvbSByZW5kZXIgdG8gcHJldmVudCByZS1kZWNsYXJhdGlvbnM7IFVzZSBDU1MgY2xhc3NlcyBvciBzdHlsZWQtY29tcG9uZW50cyBvdmVyIGlubGluZSBzdHlsZXM7IE1vdmUgZXZlbnQgaGFuZGxlcnMgb3V0c2lkZSBKU1g7IEV4dHJhY3QgY29uZGl0aW9ucyBvdXRzaWRlIEpTWCBpZiBjb21wbGV4OyBBdm9pZCAnYW55JyB0eXBlIGluIFR5cGVTY3JpcHQ7IFVzZSAnY29uc3QnIGZvciBjb25zdGFudHMsICdsZXQnIGZvciBtdXRhYmxlIHZhcmlhYmxlczsgSW1wbGVtZW50IGVycm9yIGhhbmRsaW5nOyBCZSBtaW5kZnVsIG9mIG9wdGltaXphdGlvbiwgYXZvaWQgdW5uZWNlc3NhcnkgY29tcHV0YXRpb25zL0RPTSB1cGRhdGVzOyBQcmV2ZW50IG1lbW9yeSBsZWFrcyBieSBjbGVhbmluZyB1cCBldmVudCBsaXN0ZW5lcnMvc3Vic2NyaXB0aW9ucy4uIEJlIGNvbmNyZXRlIGluIHlvdXIgcmVzcG9uc2UgYW5kIGdpdmUgdG8tdGhlLXBvaW50IGRlc2NyaXB0aW9uIGFuZCBmaXhlcyBpbiBtYXggMi0zIGxpbmVzIGZvciBldmVyeSBpc3N1ZS4gSWYgeW91IGRvbid0IGZpbmQgYW55IGlzc3VlcyBpbiB0aGUgY29kZSwganVzdCBnaXZlICJObyBtYWpvciBpc3N1ZXMgZm91bmQiIGFuZCBkb24ndCBnaXZlIGFueSB1bm5lY2Vzc2FyeSBzdWdnZXN0aW9ucyBpbiB0aGF0IGNhc2UuIFRoZSBvdXRwdXQgc2hvdWxkIGJlIGZvcm1hdHRlZCBhcyBhIEdpdEh1YiBQUiBjb21tZW50LgoKRmlsZSBQYXRoOgp7ZmlsZV9wYXRofQoKQ2xhc3MgVHlwZToKe2NsYXNzX3R5cGV9CgpDb2RlOgp7Y29kZX0=')
decoded_bytes = base64.b64decode(encoded_prompt)
user_prompt = decoded_bytes.decode('utf-8')
filled_prompt = user_prompt.format(file_path=file_path, class_type=class_type, code=code)
@@ -145,13 +123,17 @@ def extract_code_from_diff(diff_lines: List[str]) -> List[Tuple[str, str]]:
for line in diff_lines:
if line.startswith('+++ b/'):
if file_path and code:
code_snippets.append((file_path, code))
# Check if the file extension is one of the specified types before adding
if file_path.endswith(('.js', '.jsx', '.ts', '.tsx')):
code_snippets.append((file_path, code))
code = ""
file_path = line[6:]
elif line.startswith('+') and not line.startswith('++'):
code += line[1:] + '\n'
if file_path and code:
code_snippets.append((file_path, code))
# Check if the file extension is one of the specified types before adding
if file_path.endswith(('.js', '.jsx', '.ts', '.tsx')):
code_snippets.append((file_path, code))
return code_snippets
@@ -201,13 +183,10 @@ def run_analysis(directory, output_file, max_workers):
return
print("Running in mode : ", review_mode)
if review_mode == 'pr':
base_branch = os.getenv('BASE_BRANCH', 'master')
diff_lines = get_pr_diff(directory, base_branch)
print("diff ", diff_lines)
code_snippets = extract_code_from_diff(diff_lines)
else:
code_snippets = extract_code(directory)
base_branch = os.getenv('BASE_BRANCH', 'master')
diff_lines = get_pr_diff(directory, base_branch)
print("diff ", diff_lines)
code_snippets = extract_code_from_diff(diff_lines)
print("Identified code snippets list of size ", len(code_snippets))
classified_code = classify_code(code_snippets)